Commit 2fbab86b authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

n-files

parent cca06172
...@@ -42,6 +42,7 @@ SHUFFLE_EPOCH = True ...@@ -42,6 +42,7 @@ SHUFFLE_EPOCH = True
NET_ARCH = 3 # overwrite with argv? NET_ARCH = 3 # overwrite with argv?
#DEBUG_PACK_TILES = True #DEBUG_PACK_TILES = True
SUFFIX=str(NET_ARCH)+ (["R","A"][ABSOLUTE_DISPARITY]) SUFFIX=str(NET_ARCH)+ (["R","A"][ABSOLUTE_DISPARITY])
MAX_TRAIN_FILES_TFR = 1
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python #http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
class bcolors: class bcolors:
HEADER = '\033[95m' HEADER = '\033[95m'
...@@ -126,6 +127,8 @@ if os.path.isdir(train_filenameTFR): ...@@ -126,6 +127,8 @@ if os.path.isdir(train_filenameTFR):
else: else:
train_filesTFR = [train_filenameTFR] train_filesTFR = [train_filenameTFR]
print("Train tfrecords: "+str(train_filesTFR))
# tfrecords' paths for testing # tfrecords' paths for testing
try: try:
test_filenameTFR = sys.argv[2] test_filenameTFR = sys.argv[2]
...@@ -139,12 +142,29 @@ if os.path.isdir(test_filenameTFR): ...@@ -139,12 +142,29 @@ if os.path.isdir(test_filenameTFR):
else: else:
test_filesTFR = [test_filenameTFR] test_filesTFR = [test_filenameTFR]
print("Test tfrecords: "+str(test_filesTFR))
# Now we are left with 2 lists - train and test list
n_allowed_train_filesTFR = min(MAX_TRAIN_FILES_TFR,len(train_filesTFR))
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
print_time("Importing training data... ", end="") print_time("Importing training data... ", end="")
corr2d_train, target_disparity_train, gt_ds_train = readTFRewcordsEpoch(train_filenameTFR)
corr2d_trains = [None]*n_allowed_train_filesTFR
target_disparity_trains = [None]*n_allowed_train_filesTFR
gt_ds_trains = [None]*n_allowed_train_filesTFR
# Load maximum files from the list
for i in range(n_allowed_train_filesTFR):
corr2d_trains[i], target_disparity_trains[i], gt_ds_trains[i] = readTFRewcordsEpoch(train_filesTFR[i])
corr2d_train = corr2d_trains[0]
target_disparity_train = target_disparity_trains[0]
gt_ds_train = gt_ds_trains[0]
print_time(" Done") print_time(" Done")
corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape) corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape)
...@@ -390,14 +410,19 @@ shutil.rmtree(TEST_PATH, ignore_errors=True) ...@@ -390,14 +410,19 @@ shutil.rmtree(TEST_PATH, ignore_errors=True)
# threading # threading
from threading import Thread from threading import Thread
thr_result = []
def read_new_tfrecord_file(filename,result): def read_new_tfrecord_file(filename,result):
global thr_result
a,b,c = readTFRewcordsEpoch(filename) a,b,c = readTFRewcordsEpoch(filename)
result = [a,b,c] #result = [a,b,c]
result.append(a)
result.append(b)
result.append(c)
print("Loaded new tfrecord file: "+str(filename)) print("Loaded new tfrecord file: "+str(filename))
tfrecord_filename = train_filenameTFR train_record_index_counter = 0
train_file_index = 0
tfrecord_file_counter = 0
with tf.Session() as sess: with tf.Session() as sess:
...@@ -418,26 +443,44 @@ with tf.Session() as sess: ...@@ -418,26 +443,44 @@ with tf.Session() as sess:
for epoch in range(EPOCHS_TO_RUN): for epoch in range(EPOCHS_TO_RUN):
if epoch%30==0: train_file_index = epoch%n_allowed_train_filesTFR
print_time("Time to begin loading a new tfrecord file") if epoch%10==0:
# if there are more files than python3 memory allows
if (n_allowed_train_filesTFR<len(train_filesTFR)):
# circular loading?
tmp_train_index = (n_allowed_train_filesTFR+train_record_index_counter)%len(train_filesTFR)
# wait for old thread # wait for old thread
if epoch!=0: if epoch!=0:
if thr.is_alive():
print_time("Waiting until tfrecord gets loaded")
thr.join() thr.join()
# do replacement
## remove the first
corr2d_trains.pop(0)
target_disparity_trains.pop(0)
gt_ds_trains.pop(0)
## append
corr2d_trains.append(thr_result[0])
target_disparity_trains.append(thr_result[1])
gt_ds_trains.append(thr_result[2])
print_time("Time to begin loading a new tfrecord file")
# new thread # new thread
thr_result = [] thr_result = []
thr = Thread(target=read_new_tfrecord_file, args=(tfrecord_filename,thr_result)) thr = Thread(target=read_new_tfrecord_file, args=(train_filesTFR[tmp_train_index],thr_result))
# start # start
thr.start() thr.start()
train_record_index_counter += 1
# if SHUFFLE_EPOCH: # if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000) # dataset_train = dataset_train.shuffle(buffer_size=10000)
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_train,
target_disparity_train_placeholder: target_disparity_train, sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_trains[train_file_index],
gt_ds_train_placeholder: gt_ds_train}) target_disparity_train_placeholder: target_disparity_trains[train_file_index],
gt_ds_train_placeholder: gt_ds_trains[train_file_index]})
for i in range(dataset_train_size): for i in range(dataset_train_size):
try: try:
train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run( train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment