Commit c43f15ba authored by Andrey Filippov's avatar Andrey Filippov

Merge branch 'master' of git@git.elphel.com:Elphel/python3-imagej-tiff.git

parents 2b5c84c2 a8911582
......@@ -30,7 +30,8 @@ FILES_PER_SCENE = 5 # number of random offset files for the scene to select f
#MIN_BATCH_CHOICES = 10 # minimal number of tiles in a file for each bin to select from
#MAX_BATCH_FILES = 10 #maximal number of files to use in a batch
MAX_EPOCH = 500
LR = 1e-4 # learning rate
#LR = 1e-4 # learning rate
LR = 1e-3 # learning rate
USE_CONFIDENCE = False
ABSOLUTE_DISPARITY = False # True # False
DEBUG_PLT_LOSS = True
......@@ -42,6 +43,7 @@ SHUFFLE_EPOCH = True
NET_ARCH = 3 # overwrite with argv?
#DEBUG_PACK_TILES = True
SUFFIX=str(NET_ARCH)+ (["R","A"][ABSOLUTE_DISPARITY])
MAX_TRAIN_FILES_TFR = 4
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
class bcolors:
HEADER = '\033[95m'
......@@ -111,22 +113,63 @@ def read_and_decode(filename_queue):
#http://adventuresinmachinelearning.com/introduction-tensorflow-queuing/
#Main code
# Main code
# tfrecords' paths for training
try:
train_filenameTFR = sys.argv[1]
except IndexError:
train_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/train.tfrecords"
# if the path is a directory
if os.path.isdir(train_filenameTFR):
train_filesTFR = glob.glob(train_filenameTFR+"/*train-*.tfrecords")
train_filenameTFR = train_filesTFR[0]
else:
train_filesTFR = [train_filenameTFR]
train_filesTFR.sort()
print("Train tfrecords: "+str(train_filesTFR))
# tfrecords' paths for testing
try:
test_filenameTFR = sys.argv[2]
except IndexError:
test_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/test.tfrecords"
#FILES_PER_SCENE
# if the path is a directory
if os.path.isdir(test_filenameTFR):
test_filesTFR = glob.glob(test_filenameTFR+"/test_*.tfrecords")
test_filenameTFR = test_filesTFR[0]
else:
test_filesTFR = [test_filenameTFR]
test_filesTFR.sort()
print("Test tfrecords: "+str(test_filesTFR))
# Now we are left with 2 lists - train and test list
n_allowed_train_filesTFR = min(MAX_TRAIN_FILES_TFR,len(train_filesTFR))
import tensorflow as tf
import tensorflow.contrib.slim as slim
print_time("Importing training data... ", end="")
corr2d_train, target_disparity_train, gt_ds_train = readTFRewcordsEpoch(train_filenameTFR)
#print_time("Importing training data... ", end="")
print_time("Importing training data... ")
corr2d_trains = [None]*n_allowed_train_filesTFR
target_disparity_trains = [None]*n_allowed_train_filesTFR
gt_ds_trains = [None]*n_allowed_train_filesTFR
# Load maximum files from the list
for i in range(n_allowed_train_filesTFR):
corr2d_trains[i], target_disparity_trains[i], gt_ds_trains[i] = readTFRewcordsEpoch(train_filesTFR[i])
print_time("Parsed "+train_filesTFR[i])
corr2d_train = corr2d_trains[0]
target_disparity_train = target_disparity_trains[0]
gt_ds_train = gt_ds_trains[0]
print_time(" Done")
corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape)
......@@ -141,7 +184,7 @@ dataset_train_size = len(corr2d_train)
print_time("dataset_train.output_types "+str(dataset_train.output_types)+", dataset_train.output_shapes "+str(dataset_train.output_shapes)+", number of elements="+str(dataset_train_size))
dataset_train = dataset_train.batch(BATCH_SIZE)
#dataset_train = dataset_train.prefetch(BATCH_SIZE)
dataset_train = dataset_train.prefetch(BATCH_SIZE)
dataset_train_size //= BATCH_SIZE
print("dataset_train.output_types "+str(dataset_train.output_types)+", dataset_train.output_shapes "+str(dataset_train.output_shapes)+", number of elements="+str(dataset_train_size))
......@@ -177,11 +220,17 @@ def network_fc_simple(input, arch = 0):
fc = []
for i, num_outs in enumerate (layout):
if num_outs:
if fc:
inp = fc[-1]
else:
if fc:
inp = fc[-1]
else:
inp = input
fc.append(slim.fully_connected(inp, num_outs, activation_fn=lrelu,scope='g_fc'+str(i)))
fc.append(slim.fully_connected(inp, num_outs, activation_fn=lrelu,scope='g_fc'+str(i)))
with tf.variable_scope('g_fc'+str(i)+'/fully_connected',reuse=tf.AUTO_REUSE):
w = tf.get_variable('weights',shape=[inp.shape[1],num_outs])
b = tf.get_variable('weights',shape=[inp.shape[1],num_outs])
tf.summary.histogram("weights",w)
tf.summary.histogram("biases",b)
"""
# fc1 = slim.fully_connected(input, 256, activation_fn=lrelu,scope='g_fc1')
# fc2 = slim.fully_connected(fc1, 128, activation_fn=lrelu,scope='g_fc2')
......@@ -195,8 +244,21 @@ def network_fc_simple(input, arch = 0):
if USE_CONFIDENCE:
fc_out = slim.fully_connected(fc[-1], 2, activation_fn=lrelu,scope='g_fc_out')
with tf.variable_scope('g_fc_out',reuse=tf.AUTO_REUSE):
w = tf.get_variable('weights',shape=[fc[-1].shape[1],2])
b = tf.get_variable('biases',shape=[fc[-1].shape[1],2])
tf.summary.histogram("weights",w)
tf.summary.histogram("biases",b)
else:
fc_out = slim.fully_connected(fc[-1], 1, activation_fn=None,scope='g_fc_out')
with tf.variable_scope('g_fc_out',reuse=tf.AUTO_REUSE):
w = tf.get_variable('weights',shape=[fc[-1].shape[1],1])
b = tf.get_variable('biases',shape=[1])
tf.summary.histogram("weights",w)
tf.summary.histogram("biases",b)
#If using residual disparity, split last layer into 2 or remove activation and add rectifier to confidence only
return fc_out
......@@ -357,7 +419,8 @@ with tf.name_scope('epoch_average'):
t_vars=tf.trainable_variables()
lr=tf.placeholder(tf.float32)
G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
#G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(_cost1)
saver=tf.train.Saver()
......@@ -372,12 +435,19 @@ shutil.rmtree(TEST_PATH, ignore_errors=True)
# threading
from threading import Thread
thr_result = []
def read_new_tfrecord_file(filename,result):
global thr_result
a,b,c = readTFRewcordsEpoch(filename)
result = [a,b,c]
#result = [a,b,c]
result.append(a)
result.append(b)
result.append(c)
print("Loaded new tfrecord file: "+str(filename))
tfrecord_filename = train_filenameTFR
train_record_index_counter = 0
train_file_index = 0
with tf.Session() as sess:
......@@ -398,26 +468,45 @@ with tf.Session() as sess:
for epoch in range(EPOCHS_TO_RUN):
if epoch%30==0:
print_time("Time to begin loading a new tfrecord file")
# wait for old thread
if epoch!=0:
thr.join()
train_file_index = epoch%n_allowed_train_filesTFR
print("train_file_index: "+str(train_file_index))
if epoch%10==0:
# new thread
thr_result = []
thr = Thread(target=read_new_tfrecord_file, args=(tfrecord_filename,thr_result))
# if there are more files than python3 memory allows
if (n_allowed_train_filesTFR<len(train_filesTFR)):
# circular loading?
tmp_train_index = (n_allowed_train_filesTFR+train_record_index_counter)%len(train_filesTFR)
# wait for old thread
if epoch!=0:
if thr.is_alive():
print_time("Waiting until tfrecord gets loaded")
thr.join()
# do replacement
## remove the first
corr2d_trains.pop(0)
target_disparity_trains.pop(0)
gt_ds_trains.pop(0)
## append
corr2d_trains.append(thr_result[0])
target_disparity_trains.append(thr_result[1])
gt_ds_trains.append(thr_result[2])
# start
thr.start()
print_time("Time to begin loading a new tfrecord file")
# new thread
thr_result = []
thr = Thread(target=read_new_tfrecord_file, args=(train_filesTFR[tmp_train_index],thr_result))
# start
thr.start()
train_record_index_counter += 1
# if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000)
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_train,
target_disparity_train_placeholder: target_disparity_train,
gt_ds_train_placeholder: gt_ds_train})
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_trains[train_file_index],
target_disparity_train_placeholder: target_disparity_trains[train_file_index],
gt_ds_train_placeholder: gt_ds_trains[train_file_index]})
for i in range(dataset_train_size):
try:
train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment