Commit 17304955 authored by Andrey Filippov's avatar Andrey Filippov

making pydev happy

parent 92d4bc4e
...@@ -55,7 +55,22 @@ Temporarily for backward compatibility ...@@ -55,7 +55,22 @@ Temporarily for backward compatibility
""" """
if not "SLOSS_CLIP" in parameters: if not "SLOSS_CLIP" in parameters:
parameters['SLOSS_CLIP'] = 0.5 parameters['SLOSS_CLIP'] = 0.5
print ("Old config, setting SLOSS_CLIP=",SLOSS_CLIP) print ("Old config, setting SLOSS_CLIP=", parameters['SLOSS_CLIP'])
"""
Defined in config file
"""
TILE_SIDE, TILE_LAYERS, TWO_TRAINS, NET_ARCH1, NET_ARCH2 = [None]*5
ABSOLUTE_DISPARITY,SYM8_SUB, WLOSS_LAMBDA, SLOSS_LAMBDA, SLOSS_CLIP = [None]*5
SPREAD_CONVERGENCE, INTER_CONVERGENCE, HOR_FLIP, DISP_DIFF_CAP, DISP_DIFF_SLOPE = [None]*5
CLUSTER_RADIUS = None
PARTIALS_WEIGHTS, MAX_IMGS_IN_MEM, MAX_FILES_PER_GROUP, BATCH_WEIGHTS, ONLY_TILE = [None] * 5
USE_CONFIDENCE, WBORDERS_ZERO, EPOCHS_TO_RUN, FILE_UPDATE_EPOCHS = [None] * 4
LR600,LR400,LR200,LR100,LR = [None]*5
SHUFFLE_FILES, EPOCHS_FULL_TEST, SAVE_TIFFS = [None] * 3
globals().update(parameters) globals().update(parameters)
...@@ -178,7 +193,7 @@ def debug_gt_variance( ...@@ -178,7 +193,7 @@ def debug_gt_variance(
gt_ds_batch # [?:9:2] gt_ds_batch # [?:9:2]
): ):
with tf.name_scope("Debug_GT_Variance"): with tf.name_scope("Debug_GT_Variance"):
tf_num_tiles = tf.shape(gt_ds_batch)[0] # tf_num_tiles = tf.shape(gt_ds_batch)[0]
d_gt_this = tf.reshape(gt_ds_batch[:,2 * indx],[-1], name = "d_this") d_gt_this = tf.reshape(gt_ds_batch[:,2 * indx],[-1], name = "d_this")
d_gt_center = tf.reshape(gt_ds_batch[:,2 * center_indx],[-1], name = "d_center") d_gt_center = tf.reshape(gt_ds_batch[:,2 * center_indx],[-1], name = "d_center")
d_gt_diff = tf.subtract(d_gt_this, d_gt_center, name = "d_diff") d_gt_diff = tf.subtract(d_gt_this, d_gt_center, name = "d_diff")
...@@ -401,7 +416,8 @@ with tf.Session() as sess: ...@@ -401,7 +416,8 @@ with tf.Session() as sess:
img_gain_test9 = 1.0 img_gain_test9 = 1.0
num_train_variants = len(datasets_train) num_train_variants = len(datasets_train)
thr=None; thr=None
thr_result = None
trains_to_update = [train_next[n_train]['files'] > train_next[n_train]['slots'] for n_train in range(len(train_next))] trains_to_update = [train_next[n_train]['files'] > train_next[n_train]['slots'] for n_train in range(len(train_next))]
for epoch in range (EPOCHS_TO_RUN): for epoch in range (EPOCHS_TO_RUN):
""" """
......
...@@ -10,7 +10,8 @@ import tensorflow as tf ...@@ -10,7 +10,8 @@ import tensorflow as tf
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import time import time
import imagej_tiffwriter import imagej_tiffwriter
TIME_LAST = 0
TIME_START = 0
class bcolors: class bcolors:
HEADER = '\033[95m' HEADER = '\033[95m'
...@@ -32,8 +33,6 @@ def print_time(txt="",end="\n"): ...@@ -32,8 +33,6 @@ def print_time(txt="",end="\n"):
def parseXmlConfig(conf_file, root_dir): def parseXmlConfig(conf_file, root_dir):
tree = ET.parse(conf_file) tree = ET.parse(conf_file)
root = tree.getroot() root = tree.getroot()
directories = root.find('directories')
files = root.find('files')
parameters = {} parameters = {}
for p in root.find('parameters'): for p in root.find('parameters'):
parameters[p.tag]=eval(p.text.strip()) parameters[p.tag]=eval(p.text.strip())
...@@ -141,7 +140,7 @@ def getMoreFiles(fpaths,rslt, cluster_radius, hor_flip, tile_layers, tile_side): ...@@ -141,7 +140,7 @@ def getMoreFiles(fpaths,rslt, cluster_radius, hor_flip, tile_layers, tile_side):
rslt.append(dataset) rslt.append(dataset)
#from http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ #from http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
def read_and_decode(filename_queue): def read_and_decode(filename_queue, featrures_per_tile):
reader = tf.TFRecordReader() reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) _, serialized_example = reader.read(filename_queue)
...@@ -149,7 +148,7 @@ def read_and_decode(filename_queue): ...@@ -149,7 +148,7 @@ def read_and_decode(filename_queue):
serialized_example, serialized_example,
# Defaults are not specified since both keys are required. # Defaults are not specified since both keys are required.
features={ features={
'corr2d': tf.FixedLenFeature([FEATURES_PER_TILE],tf.float32), #string), 'corr2d': tf.FixedLenFeature([featrures_per_tile],tf.float32), #string),
'target_disparity': tf.FixedLenFeature([1], tf.float32), #.string), 'target_disparity': tf.FixedLenFeature([1], tf.float32), #.string),
'gt_ds': tf.FixedLenFeature([2], tf.float32) #.string) 'gt_ds': tf.FixedLenFeature([2], tf.float32) #.string)
}) })
...@@ -488,19 +487,28 @@ def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True): ...@@ -488,19 +487,28 @@ def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True):
""" """
tiff_path = npy_path.replace('.npy','.tiff') tiff_path = npy_path.replace('.npy','.tiff')
data = np.load(npy_path) #(324,242,4) [nn_disp, target_disp,gt_disp, gt_conf] data = np.load(npy_path) #(324,242,4) [nn_disp, target_disp,gt_disp, gt_conf]
nn_out = 0
target_disparity = 1
gt_disparity = 2
gt_strength = 3
if not absolute: if not absolute:
if fix_nan: if fix_nan:
data[...,0] += np.nan_to_num(data[...,1], copy=True) data[...,nn_out] += np.nan_to_num(data[...,1], copy=True)
else: else:
data[...,0] += data[...,1] data[...,nn_out] += data[...,1]
if insert_deltas: if insert_deltas:
data = np.concatenate([data[...,0:4],data[...,0:2],data[...,4:]], axis = 2) np.nan_to_num(data[...,gt_strength], copy=False)
data[...,4] -= data[...,2] data = np.concatenate([data[...,0:4],data[...,0:2],data[...,0:2],data[...,4:]], axis = 2)
data[...,5] -= data[...,2] data[...,6] -= data[...,gt_disparity]
np.nan_to_num(data[...,3], copy=False) data[...,7] -= data[...,gt_disparity]
data[...,4] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,4]]) for l in [4,5,6,7]:
data[...,5] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,5]]) data[...,l] = np.select([data[...,gt_strength]==0.0, data[...,gt_strength]>0.0], [np.nan,data[...,l]])
# All other layers - mast too
for l in range(8,data.shape[2]):
data[...,l] = np.select([data[...,gt_strength]==0.0, data[...,gt_strength]>0.0], [np.nan,data[...,l]])
# data[...,4] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,4]])
# data[...,5] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,5]])
data = data.transpose(2,0,1) data = data.transpose(2,0,1)
print("Saving results to TIFF: "+tiff_path) print("Saving results to TIFF: "+tiff_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment