Commit cca06172 authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

Merge branch 'master' of git.elphel.com:Elphel/python3-imagej-tiff

parents 1d988652 c64289c7
This diff is collapsed.
...@@ -36,6 +36,7 @@ ABSOLUTE_DISPARITY = False # True # False ...@@ -36,6 +36,7 @@ ABSOLUTE_DISPARITY = False # True # False
DEBUG_PLT_LOSS = True DEBUG_PLT_LOSS = True
FEATURES_PER_TILE = 324 FEATURES_PER_TILE = 324
EPOCHS_TO_RUN = 10000 #0 EPOCHS_TO_RUN = 10000 #0
EPOCHS_SAME_FILE = 20
RUN_TOT_AVG = 100 # last batches to average. Epoch is 307 training batches RUN_TOT_AVG = 100 # last batches to average. Epoch is 307 training batches
BATCH_SIZE = 1000 # Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches BATCH_SIZE = 1000 # Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
SHUFFLE_EPOCH = True SHUFFLE_EPOCH = True
...@@ -115,12 +116,13 @@ def read_and_decode(filename_queue): ...@@ -115,12 +116,13 @@ def read_and_decode(filename_queue):
try: try:
train_filenameTFR = sys.argv[1] train_filenameTFR = sys.argv[1]
except IndexError: except IndexError:
train_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/train.tfrecords" train_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/train_00.tfrecords"
try: try:
test_filenameTFR = sys.argv[2] test_filenameTFR = sys.argv[2]
except IndexError: except IndexError:
test_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/test.tfrecords" test_filenameTFR = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/test.tfrecords"
#FILES_PER_SCENE #FILES_PER_SCENE
train_filenameTFR1 = "/mnt/dde6f983-d149-435e-b4a2-88749245cc6c/home/eyesis/x3d_data/data_sets/tf_data/train_01.tfrecords"
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
...@@ -128,6 +130,13 @@ import tensorflow.contrib.slim as slim ...@@ -128,6 +130,13 @@ import tensorflow.contrib.slim as slim
print_time("Importing training data... ", end="") print_time("Importing training data... ", end="")
corr2d_train, target_disparity_train, gt_ds_train = readTFRewcordsEpoch(train_filenameTFR) corr2d_train, target_disparity_train, gt_ds_train = readTFRewcordsEpoch(train_filenameTFR)
print_time(" Done") print_time(" Done")
print_time("Importing second training data... ", end="")
corr2d_train1, target_disparity_train1, gt_ds_train1 = readTFRewcordsEpoch(train_filenameTFR1)
print_time(" Done")
corr2d_trains = [corr2d_train, corr2d_train1]
target_disparity_trains = [target_disparity_train, target_disparity_train1]
gt_ds_trains = [gt_ds_train, gt_ds_train1]
corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape) corr2d_train_placeholder = tf.placeholder(corr2d_train.dtype, (None,324)) # corr2d_train.shape)
target_disparity_train_placeholder = tf.placeholder(target_disparity_train.dtype, (None,1)) #target_disparity_train.shape) target_disparity_train_placeholder = tf.placeholder(target_disparity_train.dtype, (None,1)) #target_disparity_train.shape)
...@@ -382,14 +391,14 @@ with tf.Session() as sess: ...@@ -382,14 +391,14 @@ with tf.Session() as sess:
train2_avg = 0.0 train2_avg = 0.0
test_avg = 0.0 test_avg = 0.0
test2_avg = 0.0 test2_avg = 0.0
for epoch in range (EPOCHS_TO_RUN):
for epoch in range(EPOCHS_TO_RUN): # file_index = (epoch // 20) % 2
file_index = (epoch // 1) % 2
# if SHUFFLE_EPOCH: # if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000) # dataset_train = dataset_train.shuffle(buffer_size=10000)
sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_train, sess.run(iterator_train.initializer, feed_dict={corr2d_train_placeholder: corr2d_trains[file_index],
target_disparity_train_placeholder: target_disparity_train, target_disparity_train_placeholder: target_disparity_trains[file_index],
gt_ds_train_placeholder: gt_ds_train}) gt_ds_train_placeholder: gt_ds_trains[file_index]})
for i in range(dataset_train_size): for i in range(dataset_train_size):
try: try:
train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run( train_summary,_, G_loss_trained, output, disp_slice, d_gt_slice, out_diff, out_diff2, w_norm, out_wdiff2, out_cost1, corr2d325_out = sess.run(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment