Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
python3-imagej-tiff
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Elphel
python3-imagej-tiff
Commits
2fbab86b
Commit
2fbab86b
authored
Aug 07, 2018
by
Oleg Dzhimiev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
n-files
parent
cca06172
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
20 deletions
+63
-20
nn_ds_inmem4_tmp.py
nn_ds_inmem4_tmp.py
+63
-20
No files found.
nn_ds_inmem4_tmp.py
View file @
2fbab86b
...
...
@@ -42,6 +42,7 @@ SHUFFLE_EPOCH = True
NET_ARCH
=
3
# overwrite with argv?
#DEBUG_PACK_TILES = True
SUFFIX
=
str
(
NET_ARCH
)
+
([
"R"
,
"A"
][
ABSOLUTE_DISPARITY
])
MAX_TRAIN_FILES_TFR
=
1
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
class
bcolors
:
HEADER
=
'
\033
[95m'
...
...
@@ -126,6 +127,8 @@ if os.path.isdir(train_filenameTFR):
else
:
train_filesTFR
=
[
train_filenameTFR
]
print
(
"Train tfrecords: "
+
str
(
train_filesTFR
))
# tfrecords' paths for testing
try
:
test_filenameTFR
=
sys
.
argv
[
2
]
...
...
@@ -139,12 +142,29 @@ if os.path.isdir(test_filenameTFR):
else
:
test_filesTFR
=
[
test_filenameTFR
]
print
(
"Test tfrecords: "
+
str
(
test_filesTFR
))
# Now we are left with 2 lists - train and test list
n_allowed_train_filesTFR
=
min
(
MAX_TRAIN_FILES_TFR
,
len
(
train_filesTFR
))
import
tensorflow
as
tf
import
tensorflow.contrib.slim
as
slim
print_time
(
"Importing training data... "
,
end
=
""
)
corr2d_train
,
target_disparity_train
,
gt_ds_train
=
readTFRewcordsEpoch
(
train_filenameTFR
)
corr2d_trains
=
[
None
]
*
n_allowed_train_filesTFR
target_disparity_trains
=
[
None
]
*
n_allowed_train_filesTFR
gt_ds_trains
=
[
None
]
*
n_allowed_train_filesTFR
# Load maximum files from the list
for
i
in
range
(
n_allowed_train_filesTFR
):
corr2d_trains
[
i
],
target_disparity_trains
[
i
],
gt_ds_trains
[
i
]
=
readTFRewcordsEpoch
(
train_filesTFR
[
i
])
corr2d_train
=
corr2d_trains
[
0
]
target_disparity_train
=
target_disparity_trains
[
0
]
gt_ds_train
=
gt_ds_trains
[
0
]
print_time
(
" Done"
)
corr2d_train_placeholder
=
tf
.
placeholder
(
corr2d_train
.
dtype
,
(
None
,
324
))
# corr2d_train.shape)
...
...
@@ -390,14 +410,19 @@ shutil.rmtree(TEST_PATH, ignore_errors=True)
# threading
from
threading
import
Thread
thr_result
=
[]
def
read_new_tfrecord_file
(
filename
,
result
):
global
thr_result
a
,
b
,
c
=
readTFRewcordsEpoch
(
filename
)
result
=
[
a
,
b
,
c
]
#result = [a,b,c]
result
.
append
(
a
)
result
.
append
(
b
)
result
.
append
(
c
)
print
(
"Loaded new tfrecord file: "
+
str
(
filename
))
tfrecord_filename
=
train_filenameTFR
tfrecord_file_counter
=
0
train_record_index_counter
=
0
train_file_index
=
0
with
tf
.
Session
()
as
sess
:
...
...
@@ -418,26 +443,44 @@ with tf.Session() as sess:
for
epoch
in
range
(
EPOCHS_TO_RUN
):
if
epoch
%
30
==
0
:
train_file_index
=
epoch
%
n_allowed_train_filesTFR
print_time
(
"Time to begin loading a new tfrecord file"
)
if
epoch
%
10
==
0
:
# if there are more files than python3 memory allows
if
(
n_allowed_train_filesTFR
<
len
(
train_filesTFR
)):
# circular loading?
tmp_train_index
=
(
n_allowed_train_filesTFR
+
train_record_index_counter
)
%
len
(
train_filesTFR
)
# wait for old thread
if
epoch
!=
0
:
if
thr
.
is_alive
():
print_time
(
"Waiting until tfrecord gets loaded"
)
thr
.
join
()
# do replacement
## remove the first
corr2d_trains
.
pop
(
0
)
target_disparity_trains
.
pop
(
0
)
gt_ds_trains
.
pop
(
0
)
## append
corr2d_trains
.
append
(
thr_result
[
0
])
target_disparity_trains
.
append
(
thr_result
[
1
])
gt_ds_trains
.
append
(
thr_result
[
2
])
print_time
(
"Time to begin loading a new tfrecord file"
)
# new thread
thr_result
=
[]
thr
=
Thread
(
target
=
read_new_tfrecord_file
,
args
=
(
tfrecord_filename
,
thr_result
))
thr
=
Thread
(
target
=
read_new_tfrecord_file
,
args
=
(
train_filesTFR
[
tmp_train_index
],
thr_result
))
# start
thr
.
start
()
train_record_index_counter
+=
1
# if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000)
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_train
,
target_disparity_train_placeholder
:
target_disparity_train
,
gt_ds_train_placeholder
:
gt_ds_train
})
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_trains
[
train_file_index
],
target_disparity_train_placeholder
:
target_disparity_trains
[
train_file_index
],
gt_ds_train_placeholder
:
gt_ds_trains
[
train_file_index
]})
for
i
in
range
(
dataset_train_size
):
try
:
train_summary
,
_
,
G_loss_trained
,
output
,
disp_slice
,
d_gt_slice
,
out_diff
,
out_diff2
,
w_norm
,
out_wdiff2
,
out_cost1
,
corr2d325_out
=
sess
.
run
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment