Commit 4ecd26fe authored by Oleg Dzhimiev

Merge branch 'master' of git.elphel.com:Elphel/python3-imagej-tiff

parents be49b8cc 3519f5ec
#!/usr/bin/env python3
from numpy import float64
from tensorflow.contrib.losses.python.metric_learning.metric_loss_ops import npairs_loss
from debian.deb822 import PdiffIndex
##from numpy import float64
##from tensorflow.contrib.losses.python.metric_learning.metric_loss_ops import npairs_loss
##from debian.deb822 import PdiffIndex
__copyright__ = "Copyright 2018, Elphel, Inc."
__license__ = "GPL-3.0+"
__email__ = "andrey@elphel.com"
from PIL import Image
##from PIL import Image
import os
import sys
import glob
##import glob
import numpy as np
import itertools
##import itertools
import time
import matplotlib.pyplot as plt
##import matplotlib.pyplot as plt
import shutil
from threading import Thread
@@ -49,7 +49,7 @@ except IndexError:
root_dir = os.path.dirname(conf_file)
print ("Configuration file: " + conf_file)
parameters, dirs, files = qsf.parseXmlConfig(conf_file, root_dir)
parameters, dirs, files, _ = qsf.parseXmlConfig(conf_file, root_dir)
"""
Temporarily for backward compatibility
"""
@@ -221,7 +221,7 @@ if SPREAD_CONVERGENCE:
else:
outs, inp_weights = qcstereo_network.networks_siam(
input= corr2d_Nx325,
input_tensor= corr2d_Nx325,
input_global = None,
layout1 = NN_LAYOUT1,
layout2 = NN_LAYOUT2,
@@ -247,7 +247,7 @@ G_losses[0], _disp_slice, _d_gt_slice, _out_diff, _out_diff2, _w_norm, _out_wdif
absolute_disparity = ABSOLUTE_DISPARITY,
use_confidence = USE_CONFIDENCE, # True,
lambda_conf_avg = 0.01,
lambda_conf_pwr = 0.1,
## lambda_conf_pwr = 0.1,
conf_pwr = 2.0,
gt_conf_offset = 0.08,
gt_conf_pwr = 2.0,
@@ -268,7 +268,7 @@ for n in range (1,len(partials)):
absolute_disparity = ABSOLUTE_DISPARITY,
use_confidence = USE_CONFIDENCE, # True,
lambda_conf_avg = 0.01,
lambda_conf_pwr = 0.1,
# lambda_conf_pwr = 0.1,
conf_pwr = 2.0,
gt_conf_offset = 0.08,
gt_conf_pwr = 2.0,
......
#!/usr/bin/env python3
__copyright__ = "Copyright 2018, Elphel, Inc."
__license__ = "GPL-3.0+"
__email__ = "andrey@elphel.com"
from PIL import Image
import os
import sys
import glob
import numpy as np
import time
import matplotlib.pyplot as plt
import qcstereo_functions as qsf
#import xml.etree.ElementTree as ET
qsf.TIME_START = time.time()
qsf.TIME_LAST = qsf.TIME_START
IMG_WIDTH = 324 # tiles per image row
DEBUG_LEVEL= 1
try:
conf_file = sys.argv[1]
except IndexError:
print("Configuration path is required as a first argument. Optional second argument specifies root directory for data files")
exit(1)
try:
root_dir = sys.argv[2]
except IndexError:
root_dir = os.path.dirname(conf_file)
print ("Configuration file: " + conf_file)
parameters, dirs, files, dbg_parameters = qsf.parseXmlConfig(conf_file, root_dir)
"""
Temporarily for backward compatibility
"""
if not "SLOSS_CLIP" in parameters:
parameters['SLOSS_CLIP'] = 0.5
print ("Old config, setting SLOSS_CLIP=", parameters['SLOSS_CLIP'])
"""
Defined in config file
"""
TILE_SIDE, TILE_LAYERS, TWO_TRAINS, NET_ARCH1, NET_ARCH2 = [None]*5
ABSOLUTE_DISPARITY,SYM8_SUB, WLOSS_LAMBDA, SLOSS_LAMBDA, SLOSS_CLIP = [None]*5
SPREAD_CONVERGENCE, INTER_CONVERGENCE, HOR_FLIP, DISP_DIFF_CAP, DISP_DIFF_SLOPE = [None]*5
CLUSTER_RADIUS,ABSOLUTE_DISPARITY = [None]*2
globals().update(parameters)
#exit(0)
TILE_SIZE = TILE_SIDE* TILE_SIDE # == 81
FEATURES_PER_TILE = TILE_LAYERS * TILE_SIZE# == 324
BATCH_SIZE = ([1,2][TWO_TRAINS])*2*1000//25 # == 80 Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
SUFFIX=(str(NET_ARCH1)+'-'+str(NET_ARCH2)+
(["R","A"][ABSOLUTE_DISPARITY]) +
(["NS","S8"][SYM8_SUB])+
"WLAM"+str(WLOSS_LAMBDA)+
"SLAM"+str(SLOSS_LAMBDA)+
"SCLP"+str(SLOSS_CLIP)+
(['_nG','_G'][SPREAD_CONVERGENCE])+
(['_nI','_I'][INTER_CONVERGENCE]) +
(['_nHF',"_HF"][HOR_FLIP]) +
('_CP'+str(DISP_DIFF_CAP)) +
('_S'+str(DISP_DIFF_SLOPE))
)
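# Illustration (assumed parameter values, not from an actual config): with
# NET_ARCH1=0, NET_ARCH2=0, ABSOLUTE_DISPARITY=False, SYM8_SUB=True,
# WLOSS_LAMBDA=0.5, SLOSS_LAMBDA=0.1, SLOSS_CLIP=0.2, SPREAD_CONVERGENCE=False,
# INTER_CONVERGENCE=False, HOR_FLIP=True, DISP_DIFF_CAP=0.3, DISP_DIFF_SLOPE=0.03,
# the suffix becomes "0-0RS8WLAM0.5SLAM0.1SCLP0.2_nG_nI_HF_CP0.3_S0.03",
# so every results file name encodes the full training configuration.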
##############################################################################
cluster_size = (2 * CLUSTER_RADIUS + 1) * (2 * CLUSTER_RADIUS + 1)
center_tile_index = 2 * CLUSTER_RADIUS * (CLUSTER_RADIUS + 1)
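# Worked example, assuming CLUSTER_RADIUS = 2:
# cluster_size = (2*2 + 1) * (2*2 + 1) = 25 (a 5x5 tile cluster), and
# center_tile_index = 2*2*(2 + 1) = 12, the middle of the row-major 5x5 grid.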
qsf.prepareFiles(dirs, files, suffix = SUFFIX)
#import tensorflow.contrib.slim as slim
NN_DISP = 0
HEUR_DISP = 1
GT_DISP = 2
GT_CONF = 3
NN_NAN = 4
HEUR_NAN = 5
NN_DIFF = 6
HEUR_DIFF = 7
CONF_MAX = 0.7
ERR_AMPL = 0.3
TIGHT_TOP = 0.95
TIGHT_HPAD = 1.0
TIGHT_WPAD = 1.0
FIGSIZE = [8.5,11.0]
WOI_COLOR = "red"
#dbg_parameters
def get_fig_params(disparity_ranges):
fig_params = []
for dr in disparity_ranges:
if dr[-1][0]=='-':
fig_params.append(None)
else:
subs = []
for s in dr[:-1]:
mm = s[:2]
try:
lims = s[2]
except IndexError:
lims = None
subs.append({'lim_val':mm, 'lim_xy':lims})
fig_params.append({'name':dr[-1],'ranges':subs})
return fig_params
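# Sketch of the 'disparity_ranges' structure get_fig_params() expects, inferred
# from the parsing above (values are hypothetical): each file entry is a list of
# [min_disp, max_disp, optional [x0, x1, y0, y1] window] ranges, ending with the
# figure name; a name starting with '-' skips that file.
#   disparity_ranges = [
#       [[0.0, 5.0, [10, 100, 20, 80]], [0.0, 2.0], "scene-01"],
#       [[0.0, 5.0], "-skipped-scene"]]
#   get_fig_params(disparity_ranges)[1] is None, and [0]['ranges'][0] ==
#       {'lim_val': [0.0, 5.0], 'lim_xy': [10, 100, 20, 80]}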
#try:
fig_params = get_fig_params(dbg_parameters['disparity_ranges'])
pass
figs = []
def setlimsxy(lim_xy):
if not lim_xy is None:
plt.xlim(min(lim_xy[:2]),max(lim_xy[:2]))
plt.ylim(max(lim_xy[2:]),min(lim_xy[2:]))
for nfile, fpars in enumerate(fig_params):
if not fpars is None:
data = qsf.result_npy_prepare(files['result'][nfile], ABSOLUTE_DISPARITY, fix_nan=True, insert_deltas=True)
for rng in fpars['ranges']:
lim_val = rng['lim_val']
lim_xy = rng['lim_xy']
fig = plt.figure(figsize=FIGSIZE)
fig.canvas.set_window_title(fpars['name'])
fig.suptitle(fpars['name'])
ax_conf=plt.subplot(322)
ax_conf.set_title("Ground truth confidence")
# fig.suptitle("Groud truth confidence")
plt.imshow(data[...,GT_CONF], vmin=0, vmax=CONF_MAX, cmap='gray')
if not lim_xy is None:
pass # show frame
xdata=[min(lim_xy[:2]),max(lim_xy[:2]),max(lim_xy[:2]),min(lim_xy[:2]),min(lim_xy[:2])]
ydata=[min(lim_xy[2:]),min(lim_xy[2:]),max(lim_xy[2:]),max(lim_xy[2:]),min(lim_xy[2:])]
plt.plot(xdata,ydata,color=WOI_COLOR)
# setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
ax_gtd=plt.subplot(321)
ax_gtd.set_title("Ground truth disparity map")
plt.imshow(data[...,GT_DISP], vmin=lim_val[0], vmax=lim_val[1])
setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
ax_hed=plt.subplot(323)
ax_hed.set_title("Heuristic disparity map")
plt.imshow(data[...,HEUR_NAN], vmin=lim_val[0], vmax=lim_val[1])
setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
ax_nnd=plt.subplot(325)
ax_nnd.set_title("Network disparity output")
plt.imshow(data[...,NN_NAN], vmin=lim_val[0], vmax=lim_val[1])
setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
ax_hee=plt.subplot(324)
ax_hee.set_title("Heuristic disparity error")
plt.imshow(data[...,HEUR_DIFF], vmin=-ERR_AMPL, vmax=ERR_AMPL)
setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
ax_nne=plt.subplot(326)
ax_nne.set_title("Network disparity error")
plt.imshow(data[...,NN_DIFF], vmin=-ERR_AMPL, vmax=ERR_AMPL)
setlimsxy(lim_xy)
plt.colorbar(orientation='vertical') # location='bottom')
plt.tight_layout(rect =[0,0,1,TIGHT_TOP], h_pad = TIGHT_HPAD, w_pad = TIGHT_WPAD)
figs.append(fig)
pass
# how to allow adjustment before applying tight_layout?
pass
for fig in figs:
fig.tight_layout(rect =[0,0,1,TIGHT_TOP], h_pad = TIGHT_HPAD, w_pad = TIGHT_WPAD)
plt.show()
#qsf.evaluateAllResults(result_files = files['result'],
# absolute_disparity = ABSOLUTE_DISPARITY,
# cluster_radius = CLUSTER_RADIUS)
print("All done")
exit (0)
@@ -30,13 +30,14 @@ def print_time(txt="",end="\n"):
txt +=" "
print(("%s"+bcolors.BOLDWHITE+"at %.4fs (+%.4fs)"+bcolors.ENDC)%(txt,t-TIME_START,t-TIME_LAST), end = end, flush=True)
TIME_LAST = t
def parseXmlConfig(conf_file, root_dir):
tree = ET.parse(conf_file)
root = tree.getroot()
parameters = {}
for p in root.find('parameters'):
parameters[p.tag]=eval(p.text.strip())
globals
# globals
dirs={}
for p in root.find('directories'):
dirs[p.tag]=eval(p.text.strip())
@@ -46,7 +47,11 @@ def parseXmlConfig(conf_file, root_dir):
for p in root.find('files'):
files[p.tag]=eval(p.text.strip())
# globals().update(parameters)
return parameters, dirs, files
dbg_parameters = {}
for p in root.find('dbg_parameters'):
dbg_parameters[p.tag]=eval(p.text.strip())
return parameters, dirs, files, dbg_parameters
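# Minimal sketch of the XML layout this parser expects (tag values are eval()'ed
# Python literals; the root tag name and all values here are assumptions):
# <properties>
#   <parameters>     <SLOSS_CLIP>0.2</SLOSS_CLIP> ...        </parameters>
#   <directories>    <result>"results"</result> ...          </directories>
#   <files>          <result>["model_rslt.npy"]</result> ... </files>
#   <dbg_parameters>
#     <disparity_ranges>[[[0.0, 5.0], "scene"]]</disparity_ranges>
#   </dbg_parameters>
# </properties>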
@@ -84,7 +89,8 @@ def readTFRewcordsEpoch(train_filename):
npy_dir_name = "npy"
dirname = os.path.dirname(train_filename)
npy_dir = os.path.join(dirname, npy_dir_name)
filebasename, file_extension = os.path.splitext(train_filename)
# filebasename, file_extension = os.path.splitext(train_filename)
filebasename, _ = os.path.splitext(train_filename)
filebasename = os.path.basename(filebasename)
file_corr2d = os.path.join(npy_dir,filebasename + '_corr2d.npy')
file_target_disparity = os.path.join(npy_dir,filebasename + '_target_disparity.npy')
@@ -179,7 +185,7 @@ def add_neibs(npa_ext,radius):
height = npa_ext.shape[0]-2*radius
width = npa_ext.shape[1]-2*radius
side = 2 * radius + 1
size = side * side
# size = side * side
npa_neib = np.empty((height, width, side, side, npa_ext.shape[2]), dtype = npa_ext.dtype)
for dy in range (side):
for dx in range (side):
@@ -187,8 +193,8 @@ def add_neibs(npa_ext,radius):
return npa_neib.reshape(height, width, -1)
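# Shape bookkeeping (illustrative, assuming radius = 2 and a 242 x 324 x C tile
# grid): npa_ext is padded to 246 x 328 x C, npa_neib gathers the 5 x 5
# neighborhood of every tile, and the final reshape returns (242, 324, 25*C) -
# each tile position now carries the data of its whole cluster.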
def extend_img_to_clusters(datasets_img,radius, width): # = 324):
side = 2 * radius + 1
size = side * side
# side = 2 * radius + 1
# size = side * side
if len(datasets_img) ==0:
return
num_tiles = datasets_img[0]['corr2d'].shape[0]
@@ -210,7 +216,7 @@ def reformat_to_clusters(datasets_data, cluster_radius):
def flip_horizontal(datasets_data, cluster_radius, tile_layers, tile_side):
cluster_side = 2 * cluster_radius + 1
cluster_size = cluster_side * cluster_side
# cluster_size = cluster_side * cluster_side
"""
TILE_LAYERS = 4
TILE_SIDE = 9 # 7
@@ -238,8 +244,8 @@ TILE_SIZE = TILE_SIDE* TILE_SIDE # == 81
rec['target_disparity'] = target_disparity.reshape((target_disparity.shape[0],-1))
rec['gt_ds'] = gt_ds.reshape((gt_ds.shape[0],-1))
def replace_nan(datasets_data, cluster_radius):
cluster_size = (2 * cluster_radius + 1) * (2 * cluster_radius + 1)
def replace_nan(datasets_data): # , cluster_radius):
# cluster_size = (2 * cluster_radius + 1) * (2 * cluster_radius + 1)
# Reformat input data
for rec in datasets_data:
if not rec is None:
@@ -259,7 +265,7 @@ def permute_to_swaps(perm):
def shuffle_in_place(datasets_data, indx, period):
swaps = permute_to_swaps(np.random.permutation(len(datasets_data)))
num_entries = datasets_data[0]['corr2d'].shape[0] // period
# num_entries = datasets_data[0]['corr2d'].shape[0] // period
for swp in swaps:
ds0 = datasets_data[swp[0]]
ds1 = datasets_data[swp[1]]
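# permute_to_swaps() is defined above this hunk; a compatible sketch (an
# assumption, not the committed code) decomposes a random permutation into
# transpositions so the dataset records can be swapped pairwise in place:
def permute_to_swaps(perm):
    perm = list(perm)
    pairs = []
    for i in range(len(perm)):
        while perm[i] != i:      # cycle-chasing: sort perm back to identity,
            j = perm[i]          # recording each transposition along the way
            pairs.append((i, j))
            perm[i], perm[j] = perm[j], perm[i]
    return pairs
# Applying the returned swaps in order shuffles a list equivalently to perm
# (up to the order convention, which does not matter for a random shuffle).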
@@ -279,9 +285,10 @@ def shuffle_chunks_in_place(datasets_data, tiles_groups_per_chunk):
"""
Improve shuffling by preserving indices inside batches (0 <-> 0, ..., 39 <-> 39 for 40-tile-group batches)
"""
num_files = len(datasets_data)
# num_files = len(datasets_data)
#chunks_per_file = datasets_data[0]['target_disparity']
for nf, ds in enumerate(datasets_data):
# for nf, ds in enumerate(datasets_data):
for ds in datasets_data:
groups_per_file = ds['corr2d'].shape[0]
chunks_per_file = groups_per_file//tiles_groups_per_chunk
permut = np.random.permutation(chunks_per_file)
@@ -327,7 +334,8 @@ def zip_lvar_hvar(datasets_all_data, del_src = True):
'target_disparity': np.empty((recs[0]['target_disparity'].shape[0]*num_sets_to_combine,recs[0]['target_disparity'].shape[1]),dtype=np.float32),
'gt_ds': np.empty((recs[0]['gt_ds'].shape[0]*num_sets_to_combine, recs[0]['gt_ds'].shape[1]),dtype=np.float32)}
for nset, reci in enumerate(recs):
# for nset, reci in enumerate(recs):
for nset, _ in enumerate(recs):
rec['corr2d'] [nset::num_sets_to_combine] = recs[nset]['corr2d']
rec['target_disparity'][nset::num_sets_to_combine] = recs[nset]['target_disparity']
rec['gt_ds'] [nset::num_sets_to_combine] = recs[nset]['gt_ds']
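# The nset::num_sets_to_combine strided writes interleave the source records:
# combining two equal-length sets A and B yields [A0, B0, A1, B1, ...], so
# low- and high-variance tiles alternate within every combined batch
# (an illustrative reading of the code above, not an additional guarantee).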
@@ -356,10 +364,10 @@ def initTrainTestData(
max_files_per_group, # shuffling buffer for files
two_trains,
train_next):
datasets_train_lvar = []
datasets_train_hvar = []
datasets_train_lvar1 = []
datasets_train_hvar1 = []
# datasets_train_lvar = []
# datasets_train_hvar = []
# datasets_train_lvar1 = []
# datasets_train_hvar1 = []
datasets_train_all = [[],[],[],[]]
for n_train, f_train in enumerate(files['train']):
if len(f_train) and ((n_train<2) or two_trains):
@@ -445,7 +453,8 @@ def readImageData(image_data,
cluster_radius,
width)
if replace_nans:
replace_nan([image_data[indx]], cluster_radius)
# replace_nan([image_data[indx]], cluster_radius)
replace_nan([image_data[indx]])
return image_data[indx]
@@ -477,7 +486,7 @@ def evaluateAllResults(result_files, absolute_disparity, cluster_radius):
def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True):
def result_npy_prepare(npy_path, absolute, fix_nan, insert_deltas=True):
"""
@param npy_path full path to the npy file with 4-layer data (242,324,4) - nn_disparity(offset), target_disparity, gt disparity, gt strength
@@ -485,10 +494,9 @@ def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True):
@param absolute - True - the first layer contains absolute disparity, False - difference from target_disparity
@param fix_nan - replace nan in target_disparity with 0 to apply offset, target_disparity will still contain nan
"""
tiff_path = npy_path.replace('.npy','.tiff')
data = np.load(npy_path) #(324,242,4) [nn_disp, target_disp,gt_disp, gt_conf]
nn_out = 0
target_disparity = 1
# target_disparity = 1
gt_disparity = 2
gt_strength = 3
if not absolute:
@@ -501,20 +509,28 @@ def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True):
data = np.concatenate([data[...,0:4],data[...,0:2],data[...,0:2],data[...,4:]], axis = 2)
data[...,6] -= data[...,gt_disparity]
data[...,7] -= data[...,gt_disparity]
for l in [4,5,6,7]:
for l in [2, 4, 5, 6, 7]:
data[...,l] = np.select([data[...,gt_strength]==0.0, data[...,gt_strength]>0.0], [np.nan,data[...,l]])
# All other layers - mask too
for l in range(8,data.shape[2]):
data[...,l] = np.select([data[...,gt_strength]==0.0, data[...,gt_strength]>0.0], [np.nan,data[...,l]])
return data
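# The np.select masking above is equivalent (for non-negative gt_strength) to:
#   gt = data[..., gt_strength]
#   data[..., l] = np.where(gt > 0.0, data[..., l], np.nan)
# i.e. every layer is hidden wherever there is no ground-truth confidence.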
def result_npy_to_tiff(npy_path, absolute, fix_nan, insert_deltas=True):
# data[...,4] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,4]])
# data[...,5] = np.select([data[...,3]==0.0, data[...,3]>0.0], [np.nan,data[...,5]])
"""
@param npy_path full path to the npy file with 4-layer data (242,324,4) - nn_disparity(offset), target_disparity, gt disparity, gt strength
data will be written as 4-layer tiff, extension '.npy' replaced with '.tiff'
@param absolute - True - the first layer contains absolute disparity, False - difference from target_disparity
@param fix_nan - replace nan in target_disparity with 0 to apply offset, target_disparity will still contain nan
"""
data = result_npy_prepare(npy_path, absolute, fix_nan, insert_deltas)
tiff_path = npy_path.replace('.npy','.tiff')
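# (height, width, layers) -> (layers, height, width); the trailing np.newaxis
# below adds a single-channel axis, presumably matching the (pages, h, w, c)
# layout that imagej_tiffwriter.save() expects.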
data = data.transpose(2,0,1)
print("Saving results to TIFF: "+tiff_path)
imagej_tiffwriter.save(tiff_path,data[...,np.newaxis])
def eval_results(rslt_path, absolute,
min_disp = -0.1, #minimal GT disparity
max_disp = 20.0, # maximal GT disparity
......
@@ -4,7 +4,7 @@ __license__ = "GPL-3.0+"
__email__ = "andrey@elphel.com"
#from numpy import float64
import numpy as np
#import numpy as np
import tensorflow as tf
def smoothLoss(out_batch, # [batch_size,(1..2)] tf_result
@@ -76,7 +76,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
absolute_disparity = False, #when false there should be no activation on disparity output !
use_confidence = False,
lambda_conf_avg = 0.01,
lambda_conf_pwr = 0.1,
## lambda_conf_pwr = 0.1,
conf_pwr = 2.0,
gt_conf_offset = 0.08,
gt_conf_pwr = 1.0,
@@ -90,14 +90,14 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
Here confidence should be after relU. Disparity - may be also if absolute, but no activation if output is residual disparity
"""
tf_lambda_conf_avg = tf.constant(lambda_conf_avg, dtype=tf.float32, name="tf_lambda_conf_avg")
tf_lambda_conf_pwr = tf.constant(lambda_conf_pwr, dtype=tf.float32, name="tf_lambda_conf_pwr")
tf_conf_pwr = tf.constant(conf_pwr, dtype=tf.float32, name="tf_conf_pwr")
## tf_lambda_conf_pwr = tf.constant(lambda_conf_pwr, dtype=tf.float32, name="tf_lambda_conf_pwr")
## tf_conf_pwr = tf.constant(conf_pwr, dtype=tf.float32, name="tf_conf_pwr")
tf_gt_conf_offset = tf.constant(gt_conf_offset, dtype=tf.float32, name="tf_gt_conf_offset")
tf_gt_conf_pwr = tf.constant(gt_conf_pwr, dtype=tf.float32, name="tf_gt_conf_pwr")
tf_num_tiles = tf.shape(gt_ds_batch)[0]
tf_0f = tf.constant(0.0, dtype=tf.float32, name="tf_0f")
tf_1f = tf.constant(1.0, dtype=tf.float32, name="tf_1f")
tf_maxw = tf.constant(1.0, dtype=tf.float32, name="tf_maxw")
## tf_maxw = tf.constant(1.0, dtype=tf.float32, name="tf_maxw")
tf_disp_diff_cap2= tf.constant(disp_diff_cap*disp_diff_cap, dtype=tf.float32, name="disp_diff_cap2")
tf_disp_diff_slope= tf.constant(disp_diff_slope, dtype=tf.float32, name="disp_diff_slope")
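# A sketch of how the cap and slope constants are typically combined (the
# actual loss code lies outside this hunk, so this is an assumption): quadratic
# penalty up to the cap, then linear growth with the given slope so gross
# disparity outliers cannot dominate the batch loss:
#   d2 = tf.square(out_disp - gt_disp)
#   penalty = tf.where(d2 <= tf_disp_diff_cap2, d2,
#       tf_disp_diff_cap2 + (tf.sqrt(d2) - disp_diff_cap) * tf_disp_diff_slope)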
......
@@ -4,7 +4,7 @@ __license__ = "GPL-3.0+"
__email__ = "andrey@elphel.com"
#from numpy import float64
import numpy as np
#import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
@@ -13,15 +13,16 @@ def lrelu(x):
return tf.maximum(x*0.2,x)
# return tf.nn.relu(x)
def sym_inputs8(inp):
def sym_inputs8(inp, cluster_radius = 2):
"""
get input vector [?:4*9*9+1] (last being target_disparity) and reorder for horizontal flip,
vertical flip and transpose (8 variants, mode + 1 - hor, +2 - vert, +4 - transpose)
return same length, reordered
"""
tile_side = 2 * cluster_radius + 1
with tf.name_scope("sym_inputs8"):
td = inp[:,-1:] # tf.reshape(inp,[-1], name = "td")[-1]
inp_corr = tf.reshape(inp[:,:-1],[-1,4,TILE_SIDE,TILE_SIDE], name = "inp_corr")
inp_corr = tf.reshape(inp[:,:-1],[-1,4,tile_side,tile_side], name = "inp_corr")
inp_corr_h = tf.stack([-inp_corr [:,0,:,-1::-1], inp_corr [:,1,:,-1::-1], -inp_corr [:,3,:,-1::-1], -inp_corr [:,2,:,-1::-1]], axis=1, name = "inp_corr_h")
inp_corr_v = tf.stack([ inp_corr [:,0,-1::-1,:],-inp_corr [:,1,-1::-1,:], inp_corr [:,3,-1::-1,:], inp_corr [:,2,-1::-1,:]], axis=1, name = "inp_corr_v")
inp_corr_hv = tf.stack([ inp_corr_h[:,0,-1::-1,:],-inp_corr_h[:,1,-1::-1,:], inp_corr_h[:,3,-1::-1,:], inp_corr_h[:,2,-1::-1,:]], axis=1, name = "inp_corr_hv")
@@ -52,7 +53,7 @@ def sym_inputs8(inp):
tf.concat([tf.reshape(inp_corr_vt, [inp_corr.shape[0],-1]),td], axis=1,name = "out_corr_vt"),
tf.concat([tf.reshape(inp_corr_hvt,[inp_corr.shape[0],-1]),td], axis=1,name = "out_corr_hvt")]
"""
cl = 4 * TILE_SIDE * TILE_SIDE
cl = 4 * tile_side * tile_side
return [tf.concat([tf.reshape(inp_corr, [-1,cl]),td], axis=1,name = "out_corr"),
tf.concat([tf.reshape(inp_corr_h, [-1,cl]),td], axis=1,name = "out_corr_h"),
tf.concat([tf.reshape(inp_corr_v, [-1,cl]),td], axis=1,name = "out_corr_v"),
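# The 8 variants form the dihedral group of the square; a NumPy sketch of the
# same index gymnastics on one (side, side) tile (illustrative only - the code
# above additionally permutes and negates the four correlation layers):
import numpy as np
def dihedral8(tile):
    t = tile.T                                   # transpose
    return [tile,          tile[:, ::-1],        # identity, horizontal flip
            tile[::-1, :], tile[::-1, ::-1],     # vertical flip, 180-deg rotation
            t,             t[:, ::-1],           # transpose, transpose + h-flip
            t[::-1, :],    t[::-1, ::-1]]        # transpose + v-flip, + both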
@@ -64,12 +65,13 @@ def sym_inputs8(inp):
# inp_corr_h, inp_corr_v, inp_corr_hv, inp_corr_t, inp_corr_ht, inp_corr_vt, inp_corr_hvt]
def network_sub(input,
def network_sub(input_tensor,
input_global, #add to all layers (but first) if not None
layout,
reuse,
sym8 = False):
last_indx = None;
sym8 = False,
cluster_radius = 2):
# last_indx = None;
fc = []
inp_weights = []
for i, num_outs in enumerate (layout):
@@ -81,9 +83,9 @@ def network_sub(input,
inp = tf.concat([fc[-1], input_global], axis = 1)
fc.append(slim.fully_connected(inp, num_outs, activation_fn=lrelu, scope='g_fc_sub'+str(i), reuse = reuse))
else:
inp = input
inp = input_tensor
if sym8:
inp8 = sym_inputs8(inp)
inp8 = sym_inputs8(inp, cluster_radius)
num_non_sum = num_outs % len(inp8) # if number of first layer outputs is not multiple of 8
num_sym8 = num_outs // len(inp8) # number of symmetrical groups
fc_sym = []
@@ -111,12 +113,12 @@ def network_sub(input,
return fc[-1], inp_weights
def network_inter(input,
def network_inter(input_tensor,
input_global, #add to all layers (but first) if not None
layout,
reuse=False,
use_confidence=False):
last_indx = None;
#last_indx = None;
fc = []
for i, num_outs in enumerate (layout):
if num_outs:
@@ -126,7 +128,7 @@ def network_inter(input,
else:
inp = tf.concat([fc[-1], input_global], axis = 1)
else:
inp = input
inp = input_tensor
fc.append(slim.fully_connected(inp, num_outs, activation_fn=lrelu, scope='g_fc_inter'+str(i), reuse = reuse))
if use_confidence:
fc_out = slim.fully_connected(fc[-1], 2, activation_fn=lrelu, scope='g_fc_inter_out', reuse = reuse)
@@ -135,7 +137,7 @@ def network_inter(input,
#If using residual disparity, split last layer into 2 or remove activation and add rectifier to confidence only
return fc_out
def networks_siam(input, # now [?,9,325]-> [?,25,325]
def networks_siam(input_tensor, # now [?,9,325]-> [?,25,325]
input_global, # add to all layers (but first) if not None
layout1,
layout2,
@@ -143,12 +145,13 @@ def networks_siam(input, # now [?,9,325]-> [?,25,325]
sym8 = False,
only_tile = None, # just for debugging - feed only data from the center sub-network
partials = None,
use_confidence=False):
use_confidence=False,
cluster_radius = 2):
center_index = (input.shape[1] - 1) // 2
center_index = (input_tensor.shape[1] - 1) // 2
with tf.name_scope("Siam_net"):
inp_weights = []
num_legs = input.shape[1] # == 25
num_legs = input_tensor.shape[1] # == 25
if partials is None:
partials = [[True] * num_legs]
inter_lists = [[] for _ in partials]
@@ -159,11 +162,12 @@ def networks_siam(input, # now [?,9,325]-> [?,25,325]
ig = None
else:
ig =input_global[:,i,:]
ns, ns_weights = network_sub(input[:,i,:],
ns, ns_weights = network_sub(input_tensor[:,i,:],
ig, # input_global[:,i,:],
layout= layout1,
reuse= reuse,
sym8 = sym8)
sym8 = sym8,
cluster_radius = cluster_radius)
for n, partial in enumerate(partials):
if partial[i]:
inter_lists[n].append(ns)
@@ -178,7 +182,7 @@ def networks_siam(input, # now [?,9,325]-> [?,25,325]
else:
ig =input_global[:,center_index,:]
outs.append(network_inter (input = tf.concat(inter_lists[n],
outs.append(network_inter (input_tensor = tf.concat(inter_lists[n],
axis=1,
name='inter_tensor'+str(n)),
input_global = [None, ig][inter_convergence], # optionally feed all convergence values (from each tile of a cluster)
......