Commit d939f570 authored by Andrey Filippov's avatar Andrey Filippov

combining low/high disparity cluster variance into the same batch

parent ded0d987
...@@ -46,8 +46,8 @@ RUN_TOT_AVG = 100 # last batches to average. Epoch is 307 training batche ...@@ -46,8 +46,8 @@ RUN_TOT_AVG = 100 # last batches to average. Epoch is 307 training batche
#BATCH_SIZE = 1080//9 # == 120 Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches #BATCH_SIZE = 1080//9 # == 120 Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
BATCH_SIZE = 2*1080//9 # == 120 Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches BATCH_SIZE = 2*1080//9 # == 120 Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
SHUFFLE_EPOCH = True SHUFFLE_EPOCH = True
NET_ARCH1 = 0 # 4 # 3 # overwrite with argv? NET_ARCH1 = 6 #0 # 4 # 3 # overwrite with argv?
NET_ARCH2 = 0 # 3 # overwrite with argv? NET_ARCH2 = 6 # 0 # 3 # overwrite with argv?
ONLY_TILE = None # 4 # None # 0 # 4# None # (remove all but center tile data), put None here for normal operation) ONLY_TILE = None # 4 # None # 0 # 4# None # (remove all but center tile data), put None here for normal operation)
...@@ -60,7 +60,9 @@ NN_LAYOUTS = {0:[0, 0, 0, 32, 20, 16], ...@@ -60,7 +60,9 @@ NN_LAYOUTS = {0:[0, 0, 0, 32, 20, 16],
1:[0, 0, 0, 256, 128, 64], 1:[0, 0, 0, 256, 128, 64],
2:[0, 128, 32, 32, 32, 16], 2:[0, 128, 32, 32, 32, 16],
3:[0, 0, 40, 32, 20, 16], 3:[0, 0, 40, 32, 20, 16],
4:[0, 0, 0, 0, 16, 16] 4:[0, 0, 0, 0, 16, 16],
5:[0, 0, 64, 32, 32, 16],
6:[0, 0, 32, 16, 16, 16],
} }
NN_LAYOUT1 = NN_LAYOUTS[NET_ARCH1] NN_LAYOUT1 = NN_LAYOUTS[NET_ARCH1]
NN_LAYOUT2 = NN_LAYOUTS[NET_ARCH2] NN_LAYOUT2 = NN_LAYOUTS[NET_ARCH2]
...@@ -609,12 +611,14 @@ G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) ...@@ -609,12 +611,14 @@ G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
saver=tf.train.Saver() saver=tf.train.Saver()
ROOT_PATH = './attic/nn_ds_neibs_graph'+SUFFIX+"/" ROOT_PATH = './attic/nn_ds_neibs_graph'+SUFFIX+"/"
TRAIN_PATH = ROOT_PATH + 'train' TRAIN_PATH = ROOT_PATH + 'train'
TEST_PATH = ROOT_PATH + 'test' TEST_PATH = ROOT_PATH + 'test'
TEST_PATH1 = ROOT_PATH + 'test1'
# CLEAN OLD STAFF # CLEAN OLD STAFF
shutil.rmtree(TRAIN_PATH, ignore_errors=True) shutil.rmtree(TRAIN_PATH, ignore_errors=True)
shutil.rmtree(TEST_PATH, ignore_errors=True) shutil.rmtree(TEST_PATH, ignore_errors=True)
shutil.rmtree(TEST_PATH1, ignore_errors=True)
with tf.Session() as sess: with tf.Session() as sess:
...@@ -624,6 +628,7 @@ with tf.Session() as sess: ...@@ -624,6 +628,7 @@ with tf.Session() as sess:
merged = tf.summary.merge_all() merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(TRAIN_PATH, sess.graph) train_writer = tf.summary.FileWriter(TRAIN_PATH, sess.graph)
test_writer = tf.summary.FileWriter(TEST_PATH, sess.graph) test_writer = tf.summary.FileWriter(TEST_PATH, sess.graph)
test_writer1 = tf.summary.FileWriter(TEST_PATH1, sess.graph)
loss_train_hist= np.empty(dataset_train_size, dtype=np.float32) loss_train_hist= np.empty(dataset_train_size, dtype=np.float32)
loss_test_hist= np.empty(dataset_test_size, dtype=np.float32) loss_test_hist= np.empty(dataset_test_size, dtype=np.float32)
loss2_train_hist= np.empty(dataset_train_size, dtype=np.float32) loss2_train_hist= np.empty(dataset_train_size, dtype=np.float32)
...@@ -682,14 +687,16 @@ with tf.Session() as sess: ...@@ -682,14 +687,16 @@ with tf.Session() as sess:
train2_avg = np.average(loss2_train_hist).astype(np.float32) train2_avg = np.average(loss2_train_hist).astype(np.float32)
gtvar_train_avg = np.average(gtvar_train_hist).astype(np.float32) gtvar_train_avg = np.average(gtvar_train_hist).astype(np.float32)
for dataset_test in datasets_test: test_summaries = [0.0]*len(datasets_test)
for ntest,dataset_test in enumerate(datasets_test):
sess.run(iterator_tt.initializer, feed_dict={corr2d_train_placeholder: dataset_test['corr2d'], sess.run(iterator_tt.initializer, feed_dict={corr2d_train_placeholder: dataset_test['corr2d'],
target_disparity_train_placeholder: dataset_test['target_disparity'], target_disparity_train_placeholder: dataset_test['target_disparity'],
gt_ds_train_placeholder: dataset_test['gt_ds']}) gt_ds_train_placeholder: dataset_test['gt_ds']})
for i in range(dataset_test_size): for i in range(dataset_test_size):
try: try:
test_summary, G_loss_tested, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, gt_variance = sess.run( # test_summary, G_loss_tested, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, gt_variance = sess.run(
test_summaries[ntest], G_loss_tested, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, gt_variance = sess.run(
[merged, [merged,
G_loss, G_loss,
out, out,
...@@ -717,13 +724,15 @@ with tf.Session() as sess: ...@@ -717,13 +724,15 @@ with tf.Session() as sess:
# _,_=sess.run([tf_ph_G_loss,tf_ph_sq_diff],feed_dict={tf_ph_G_loss:test_avg, tf_ph_sq_diff:test2_avg}) # _,_=sess.run([tf_ph_G_loss,tf_ph_sq_diff],feed_dict={tf_ph_G_loss:test_avg, tf_ph_sq_diff:test2_avg})
train_writer.add_summary(train_summary, epoch) train_writer.add_summary(train_summary, epoch)
test_writer.add_summary(test_summary, epoch) test_writer.add_summary(test_summaries[0], epoch)
test_writer1.add_summary(test_summaries[1], epoch)
print_time("%d:%d -> %f %f (%f %f) dbg:%f %f"%(epoch,i,train_avg, test_avg,train2_avg, test2_avg, gtvar_train_avg, gtvar_test_avg)) print_time("%d:%d -> %f %f (%f %f) dbg:%f %f"%(epoch,i,train_avg, test_avg,train2_avg, test2_avg, gtvar_train_avg, gtvar_test_avg))
# Close writers # Close writers
train_writer.close() train_writer.close()
test_writer.close() test_writer.close()
test_writer1.close()
#reports error: Exception ignored in: <bound method BaseSession.__del__ of <tensorflow.python.client.session.Session object at 0x7efc5f720ef0>> if there is no print before exit() #reports error: Exception ignored in: <bound method BaseSession.__del__ of <tensorflow.python.client.session.Session object at 0x7efc5f720ef0>> if there is no print before exit()
print("All done") print("All done")
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment