Commit 2fbab86b authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

n-files

parent cca06172
......@@ -42,6 +42,7 @@ SHUFFLE_EPOCH = True
NET_ARCH = 3 # overwrite with argv?
#DEBUG_PACK_TILES = True
SUFFIX=str(NET_ARCH)+ (["R","A"][ABSOLUTE_DISPARITY])
MAX_TRAIN_FILES_TFR = 1
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
class bcolors:
HEADER = '\033[95m'
......@@ -125,6 +126,8 @@ if os.path.isdir(train_filenameTFR):
train_filenameTFR = train_filesTFR[0]
else:
train_filesTFR = [train_filenameTFR]
print("Train tfrecords: "+str(train_filesTFR))
# tfrecords' paths for testing
try:
......@@ -138,13 +141,30 @@ if os.path.isdir(test_filenameTFR):
test_filenameTFR = test_filesTFR[0]
else:
test_filesTFR = [test_filenameTFR]
print("Test tfrecords: "+str(test_filesTFR))
# Now we are left with 2 lists - train and test list
n_allowed_train_filesTFR = min(MAX_TRAIN_FILES_TFR,len(train_filesTFR))
import tensorflow as tf
import tensorflow.contrib.slim as slim
print_time("Importing training data... ", end="")
corr2d_train, target_disparity_train, gt_ds_train = readTFRewcordsEpoch(train_filenameTFR)
corr2d_trains = [None]*n_allowed_train_filesTFR
target_disparity_trains = [None]*n_allowed_train_filesTFR
gt_ds_trains = [None]*n_allowed_train_filesTFR
# Load maximum files from the list
for i in range(n_allowed_train_filesTFR):
corr2d_trains[i], target_disparity_trains[i], gt_ds_trains[i] = readTFRewcordsEpoch(train_filesTFR[i])
corr2d_train = corr2d_trains[0]
target_disparity_train = target_disparity_trains[0]
gt_ds_train = gt_ds_trains[0]
print_time(" Done")
corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape)
......@@ -390,14 +410,19 @@ shutil.rmtree(TEST_PATH, ignore_errors=True)
# threading
from threading import Thread
thr_result = []
def read_new_tfrecord_file(filename,result):
global thr_result
a,b,c = readTFRewcordsEpoch(filename)
result = [a,b,c]
#result = [a,b,c]
result.append(a)
result.append(b)
result.append(c)
print("Loaded new tfrecord file: "+str(filename))
tfrecord_filename = train_filenameTFR
tfrecord_file_counter = 0
train_record_index_counter = 0
train_file_index = 0
with tf.Session() as sess:
......@@ -418,26 +443,44 @@ with tf.Session() as sess:
for epoch in range(EPOCHS_TO_RUN):
if epoch%30==0:
print_time("Time to begin loading a new tfrecord file")
# wait for old thread
if epoch!=0:
thr.join()
train_file_index = epoch%n_allowed_train_filesTFR
if epoch%10==0:
# new thread
thr_result = []
thr = Thread(target=read_new_tfrecord_file, args=(tfrecord_filename,thr_result))
# if there are more files than python3 memory allows
if (n_allowed_train_filesTFR<len(train_filesTFR)):
# circular loading?
tmp_train_index = (n_allowed_train_filesTFR+train_record_index_counter)%len(train_filesTFR)
# wait for old thread
if epoch!=0:
if thr.is_alive():
print_time("Waiting until tfrecord gets loaded")
thr.join()
# do replacement
## remove the first
corr2d_trains.pop(0)
target_disparity_trains.pop(0)
gt_ds_trains.pop(0)
## append
corr2d_trains.append(thr_result[0])
target_disparity_trains.append(thr_result[1])
gt_ds_trains.append(thr_result[2])
# start
thr.start()
print_time("Time to begin loading a new tfrecord file")
# new thread
thr_result = []
thr = Thread(target=read_new_tfrecord_file, args=(train_filesTFR[tmp_train_index],thr_result))
# start
thr.start()
train_record_index_counter += 1
# if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000)
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_train,
target_disparity_train_placeholder: target_disparity_train,
gt_ds_train_placeholder: gt_ds_train})
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_trains[train_file_index],
target_disparity_train_placeholder: target_disparity_trains[train_file_index],
gt_ds_train_placeholder: gt_ds_trains[train_file_index]})
for i in range(dataset_train_size):
try:
train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment