Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
python3-imagej-tiff
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Elphel
python3-imagej-tiff
Commits
07c7d46a
Commit
07c7d46a
authored
Aug 10, 2018
by
Oleg Dzhimiev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
display weights
parent
a8911582
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
160 additions
and
79 deletions
+160
-79
nn_ds_inmem4_tmp.py
nn_ds_inmem4_tmp.py
+160
-79
No files found.
nn_ds_inmem4_tmp.py
View file @
07c7d46a
#!/usr/bin/env python3
#!/usr/bin/env python3
from
numpy
import
float64
from
numpy
import
float64
from
_stat
import
S_IEXEC
__copyright__
=
"Copyright 2018, Elphel, Inc."
__copyright__
=
"Copyright 2018, Elphel, Inc."
__license__
=
"GPL-3.0+"
__license__
=
"GPL-3.0+"
...
@@ -27,23 +28,23 @@ DEBUG_LEVEL= 1
...
@@ -27,23 +28,23 @@ DEBUG_LEVEL= 1
DISP_BATCH_BINS
=
20
# Number of batch disparity bins
DISP_BATCH_BINS
=
20
# Number of batch disparity bins
STR_BATCH_BINS
=
10
# Number of batch strength bins
STR_BATCH_BINS
=
10
# Number of batch strength bins
FILES_PER_SCENE
=
5
# number of random offset files for the scene to select from (0 - use all available)
FILES_PER_SCENE
=
5
# number of random offset files for the scene to select from (0 - use all available)
#MIN_BATCH_CHOICES = 10 # minimal number of tiles in a file for each bin to select from
#MIN_BATCH_CHOICES = 10 # minimal number of tiles in a file for each bin to select from
#MAX_BATCH_FILES = 10 #maximal number of files to use in a batch
#MAX_BATCH_FILES = 10 #maximal number of files to use in a batch
MAX_EPOCH
=
500
MAX_EPOCH
=
500
#LR = 1e-4 # learning rate
#LR = 1e-4 # learning rate
LR
=
1e-3
# learning rate
LR
=
1e-3
# learning rate
USE_CONFIDENCE
=
False
USE_CONFIDENCE
=
False
ABSOLUTE_DISPARITY
=
Fals
e
# True # False
ABSOLUTE_DISPARITY
=
Tru
e
# True # False
DEBUG_PLT_LOSS
=
True
DEBUG_PLT_LOSS
=
True
FEATURES_PER_TILE
=
324
FEATURES_PER_TILE
=
324
EPOCHS_TO_RUN
=
10000
#0
EPOCHS_TO_RUN
=
10000
#0
RUN_TOT_AVG
=
100
# last batches to average. Epoch is 307 training batches
RUN_TOT_AVG
=
100
# last batches to average. Epoch is 307 training batches
BATCH_SIZE
=
1000
# Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
BATCH_SIZE
=
1000
# Each batch of tiles has balanced D/S tiles, shuffled batches but not inside batches
SHUFFLE_EPOCH
=
True
SHUFFLE_EPOCH
=
True
NET_ARCH
=
3
# overwrite with argv?
NET_ARCH
=
0
# overwrite with argv?
#DEBUG_PACK_TILES = True
#DEBUG_PACK_TILES = True
SUFFIX
=
str
(
NET_ARCH
)
+
([
"R"
,
"A"
][
ABSOLUTE_DISPARITY
])
SUFFIX
=
str
(
NET_ARCH
)
+
([
"R"
,
"A"
][
ABSOLUTE_DISPARITY
])
MAX_TRAIN_FILES_TFR
=
4
MAX_TRAIN_FILES_TFR
=
6
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
#http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python
class
bcolors
:
class
bcolors
:
HEADER
=
'
\033
[95m'
HEADER
=
'
\033
[95m'
...
@@ -82,7 +83,7 @@ def readTFRewcordsEpoch(train_filename):
...
@@ -82,7 +83,7 @@ def readTFRewcordsEpoch(train_filename):
corr2d
=
np
.
array
(
corr2d_list
)
corr2d
=
np
.
array
(
corr2d_list
)
target_disparity
=
np
.
array
(
target_disparity_list
)
target_disparity
=
np
.
array
(
target_disparity_list
)
gt_ds
=
np
.
array
(
gt_ds_list
)
gt_ds
=
np
.
array
(
gt_ds_list
)
return
corr2d
,
target_disparity
,
gt_ds
return
corr2d
,
target_disparity
,
gt_ds
#from http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
#from http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
def
read_and_decode
(
filename_queue
):
def
read_and_decode
(
filename_queue
):
...
@@ -102,7 +103,7 @@ def read_and_decode(filename_queue):
...
@@ -102,7 +103,7 @@ def read_and_decode(filename_queue):
gt_ds
=
tf
.
cast
(
features
[
'gt_ds'
],
tf
.
float32
)
# tf.decode_raw(features['gt_ds'], tf.float32)
gt_ds
=
tf
.
cast
(
features
[
'gt_ds'
],
tf
.
float32
)
# tf.decode_raw(features['gt_ds'], tf.float32)
in_features
=
tf
.
concat
([
corr2d
,
target_disparity
],
0
)
in_features
=
tf
.
concat
([
corr2d
,
target_disparity
],
0
)
# still some nan-s in correlation data?
# still some nan-s in correlation data?
# in_features_clean = tf.where(tf.is_nan(in_features), tf.zeros_like(in_features), in_features)
# in_features_clean = tf.where(tf.is_nan(in_features), tf.zeros_like(in_features), in_features)
# corr2d_out, target_disparity_out, gt_ds_out = tf.train.shuffle_batch( [in_features_clean, target_disparity, gt_ds],
# corr2d_out, target_disparity_out, gt_ds_out = tf.train.shuffle_batch( [in_features_clean, target_disparity, gt_ds],
corr2d_out
,
target_disparity_out
,
gt_ds_out
=
tf
.
train
.
shuffle_batch
(
[
in_features
,
target_disparity
,
gt_ds
],
corr2d_out
,
target_disparity_out
,
gt_ds_out
=
tf
.
train
.
shuffle_batch
(
[
in_features
,
target_disparity
,
gt_ds
],
batch_size
=
1000
,
# 2,
batch_size
=
1000
,
# 2,
...
@@ -124,14 +125,14 @@ except IndexError:
...
@@ -124,14 +125,14 @@ except IndexError:
# if the path is a directory
# if the path is a directory
if
os
.
path
.
isdir
(
train_filenameTFR
):
if
os
.
path
.
isdir
(
train_filenameTFR
):
train_filesTFR
=
glob
.
glob
(
train_filenameTFR
+
"/*train-*.tfrecords"
)
train_filesTFR
=
glob
.
glob
(
train_filenameTFR
+
"/*train-*.tfrecords"
)
train_filenameTFR
=
train_filesTFR
[
0
]
train_filenameTFR
=
train_filesTFR
[
0
]
else
:
else
:
train_filesTFR
=
[
train_filenameTFR
]
train_filesTFR
=
[
train_filenameTFR
]
train_filesTFR
.
sort
()
train_filesTFR
.
sort
()
print
(
"Train tfrecords: "
+
str
(
train_filesTFR
))
print
(
"Train tfrecords: "
+
str
(
train_filesTFR
))
# tfrecords' paths for testing
# tfrecords' paths for testing
try
:
try
:
test_filenameTFR
=
sys
.
argv
[
2
]
test_filenameTFR
=
sys
.
argv
[
2
]
except
IndexError
:
except
IndexError
:
...
@@ -140,13 +141,13 @@ except IndexError:
...
@@ -140,13 +141,13 @@ except IndexError:
# if the path is a directory
# if the path is a directory
if
os
.
path
.
isdir
(
test_filenameTFR
):
if
os
.
path
.
isdir
(
test_filenameTFR
):
test_filesTFR
=
glob
.
glob
(
test_filenameTFR
+
"/test_*.tfrecords"
)
test_filesTFR
=
glob
.
glob
(
test_filenameTFR
+
"/test_*.tfrecords"
)
test_filenameTFR
=
test_filesTFR
[
0
]
test_filenameTFR
=
test_filesTFR
[
0
]
else
:
else
:
test_filesTFR
=
[
test_filenameTFR
]
test_filesTFR
=
[
test_filenameTFR
]
test_filesTFR
.
sort
()
test_filesTFR
.
sort
()
print
(
"Test tfrecords: "
+
str
(
test_filesTFR
))
print
(
"Test tfrecords: "
+
str
(
test_filesTFR
))
# Now we are left with 2 lists - train and test list
# Now we are left with 2 lists - train and test list
n_allowed_train_filesTFR
=
min
(
MAX_TRAIN_FILES_TFR
,
len
(
train_filesTFR
))
n_allowed_train_filesTFR
=
min
(
MAX_TRAIN_FILES_TFR
,
len
(
train_filesTFR
))
...
@@ -165,11 +166,11 @@ gt_ds_trains = [None]*n_allowed_train_filesTFR
...
@@ -165,11 +166,11 @@ gt_ds_trains = [None]*n_allowed_train_filesTFR
for
i
in
range
(
n_allowed_train_filesTFR
):
for
i
in
range
(
n_allowed_train_filesTFR
):
corr2d_trains
[
i
],
target_disparity_trains
[
i
],
gt_ds_trains
[
i
]
=
readTFRewcordsEpoch
(
train_filesTFR
[
i
])
corr2d_trains
[
i
],
target_disparity_trains
[
i
],
gt_ds_trains
[
i
]
=
readTFRewcordsEpoch
(
train_filesTFR
[
i
])
print_time
(
"Parsed "
+
train_filesTFR
[
i
])
print_time
(
"Parsed "
+
train_filesTFR
[
i
])
corr2d_train
=
corr2d_trains
[
0
]
corr2d_train
=
corr2d_trains
[
0
]
target_disparity_train
=
target_disparity_trains
[
0
]
target_disparity_train
=
target_disparity_trains
[
0
]
gt_ds_train
=
gt_ds_trains
[
0
]
gt_ds_train
=
gt_ds_trains
[
0
]
print_time
(
" Done"
)
print_time
(
" Done"
)
corr2d_train_placeholder
=
tf
.
placeholder
(
corr2d_train
.
dtype
,
(
None
,
324
))
# corr2d_train.shape)
corr2d_train_placeholder
=
tf
.
placeholder
(
corr2d_train
.
dtype
,
(
None
,
324
))
# corr2d_train.shape)
...
@@ -211,6 +212,9 @@ def lrelu(x):
...
@@ -211,6 +212,9 @@ def lrelu(x):
# return tf.nn.relu(x)
# return tf.nn.relu(x)
def
network_fc_simple
(
input
,
arch
=
0
):
def
network_fc_simple
(
input
,
arch
=
0
):
global
image_summary_op1
layouts
=
{
0
:[
0
,
0
,
0
,
32
,
20
,
16
],
layouts
=
{
0
:[
0
,
0
,
0
,
32
,
20
,
16
],
1
:[
0
,
0
,
0
,
256
,
128
,
64
],
1
:[
0
,
0
,
0
,
256
,
128
,
64
],
2
:[
0
,
128
,
32
,
32
,
32
,
16
],
2
:[
0
,
128
,
32
,
32
,
32
,
16
],
...
@@ -224,11 +228,88 @@ def network_fc_simple(input, arch = 0):
...
@@ -224,11 +228,88 @@ def network_fc_simple(input, arch = 0):
inp
=
fc
[
-
1
]
inp
=
fc
[
-
1
]
else
:
else
:
inp
=
input
inp
=
input
fc
.
append
(
slim
.
fully_connected
(
inp
,
num_outs
,
activation_fn
=
lrelu
,
scope
=
'g_fc'
+
str
(
i
)))
fc
.
append
(
slim
.
fully_connected
(
inp
,
num_outs
,
activation_fn
=
lrelu
,
scope
=
'g_fc'
+
str
(
i
)))
with
tf
.
variable_scope
(
'g_fc'
+
str
(
i
)
+
'/fully_connected'
,
reuse
=
tf
.
AUTO_REUSE
):
#with tf.variable_scope('g_fc'+str(i)+'/fully_connected',reuse=tf.AUTO_REUSE):
with
tf
.
variable_scope
(
'g_fc'
+
str
(
i
),
reuse
=
tf
.
AUTO_REUSE
):
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
inp
.
shape
[
1
],
num_outs
])
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
inp
.
shape
[
1
],
num_outs
])
b
=
tf
.
get_variable
(
'weights'
,
shape
=
[
inp
.
shape
[
1
],
num_outs
])
#image = tf.get_variable('w_images',shape=[1, inp.shape[1],num_outs,1])
if
(
i
==
3
):
# red border
grid
=
tf
.
constant
([
255
,
100
,
100
],
dtype
=
tf
.
float32
,
name
=
"GRID"
)
# (325,32)
wimg_1
=
w
# (32,325)
wimg_2
=
tf
.
transpose
(
wimg_1
,[
1
,
0
])
# (32,324)
wimg_3
=
wimg_2
[:,:
-
1
]
# res?
#wimg_res = tf.get_variable('wimg_res',shape=[32*(9+1),(9+1)*4, 3])
# long list
tmp1
=
[]
for
mi
in
range
(
32
):
tmp2
=
[]
for
mj
in
range
(
4
):
s_i
=
mj
*
81
e_i
=
(
mj
+
1
)
*
81
tile
=
tf
.
reshape
(
wimg_3
[
mi
,
s_i
:
e_i
],
shape
=
(
9
,
9
))
tiles
=
tf
.
stack
([
tile
]
*
3
,
axis
=
2
)
#gtiles1 = tf.concat([tiles, tf.reshape(9*[grid],shape=(1,9,3))],axis=0)
gtiles1
=
tf
.
concat
([
tiles
,
tf
.
expand_dims
(
9
*
[
grid
],
0
)],
axis
=
0
)
gtiles2
=
tf
.
concat
([
gtiles1
,
tf
.
expand_dims
(
10
*
[
grid
],
1
)],
axis
=
1
)
tmp2
.
append
(
gtiles2
)
ts
=
tf
.
concat
(
tmp2
,
axis
=
2
)
tmp1
.
append
(
ts
)
image_summary_op2
=
tf
.
concat
(
tmp1
,
axis
=
0
)
#image_summary_op1 = tf.assign(wimg_res,tf.zeros(shape=[32*(9+1),(9+1)*4, 3],dtype=tf.float32))
#wimgo1 = tf.zeros(shape=[32*(9+1),(9+1)*4, 3],dtype=tf.float32)
#tf.summary.image("wimg_res1",tf.reshape(wimg_res,[1,32*(9+1),(9+1)*4, 3]))
#tf.summary.image("wimgo1",tf.reshape(wimgo1,[1,32*(9+1),(9+1)*4, 3]))
#tf.summary.image("wimgo2",tf.reshape(wimgo2,[1,32*(9+1),(9+1)*4, 3]))
tf
.
summary
.
image
(
"SWEIGTS"
,
tf
.
reshape
(
gtiles2
,[
1
,
10
,
10
,
3
]))
tf
.
summary
.
image
(
"WEIGTS"
,
tf
.
reshape
(
image_summary_op2
,[
1
,
320
,
40
,
3
]))
# borders
#for mi in range(0,wimg_res.shape[0],10):
# for mj in range(wimg_res.shape[1]):
# wimg_res[mi,mj].assign([255,255,255])
#wimg_res[9::(9+1),:].assign([255,0,0])
#wimg_res[:,9::(9+1)].assign([255,0,0])
#for mi in range(0,wimg_res.shape[0],10):
# print(mi)
#wimg_res = tf.stack([wing_res,])
#wimg_1 = tf.reshape(w,[1,inp.shape[1],num_outs,1])
#wimg_1t = tf.transpose(wimg_1,[0,2,1,3])
# w = w[a,b]
# wt = w[b,a]
# for i in range(b):
# tmp =
#tf.summary.image("wimg_1",wimg_1)
#tf.summary.image("wimg_1t",wimg_1t)
#tf.summary.image("wimg_res1",tf.reshape(wimg_res,[1,32*(9+1),(9+1)*4, 3]))
b
=
tf
.
get_variable
(
'biases'
,
shape
=
[
num_outs
])
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
"""
"""
...
@@ -241,32 +322,34 @@ def network_fc_simple(input, arch = 0):
...
@@ -241,32 +322,34 @@ def network_fc_simple(input, arch = 0):
### fc3 = slim.fully_connected(input, 32, activation_fn=lrelu,scope='g_fc3')
### fc3 = slim.fully_connected(input, 32, activation_fn=lrelu,scope='g_fc3')
### fc4 = slim.fully_connected(fc3, 20, activation_fn=lrelu,scope='g_fc4')
### fc4 = slim.fully_connected(fc3, 20, activation_fn=lrelu,scope='g_fc4')
### fc5 = slim.fully_connected(fc4, 16, activation_fn=lrelu,scope='g_fc5')
### fc5 = slim.fully_connected(fc4, 16, activation_fn=lrelu,scope='g_fc5')
if
USE_CONFIDENCE
:
if
USE_CONFIDENCE
:
fc_out
=
slim
.
fully_connected
(
fc
[
-
1
],
2
,
activation_fn
=
lrelu
,
scope
=
'g_fc_out'
)
fc_out
=
slim
.
fully_connected
(
fc
[
-
1
],
2
,
activation_fn
=
lrelu
,
scope
=
'g_fc_out'
)
with
tf
.
variable_scope
(
'g_fc_out'
,
reuse
=
tf
.
AUTO_REUSE
):
with
tf
.
variable_scope
(
'g_fc_out'
,
reuse
=
tf
.
AUTO_REUSE
):
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
fc
[
-
1
]
.
shape
[
1
],
2
])
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
fc
[
-
1
]
.
shape
[
1
],
2
])
b
=
tf
.
get_variable
(
'biases'
,
shape
=
[
fc
[
-
1
]
.
shape
[
1
],
2
])
tf
.
summary
.
image
(
"wimage"
,
tf
.
reshape
(
w
,[
1
,
fc
[
-
1
]
.
shape
[
1
],
2
,
1
]))
b
=
tf
.
get_variable
(
'biases'
,
shape
=
[
2
])
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
else
:
else
:
fc_out
=
slim
.
fully_connected
(
fc
[
-
1
],
1
,
activation_fn
=
None
,
scope
=
'g_fc_out'
)
fc_out
=
slim
.
fully_connected
(
fc
[
-
1
],
1
,
activation_fn
=
None
,
scope
=
'g_fc_out'
)
with
tf
.
variable_scope
(
'g_fc_out'
,
reuse
=
tf
.
AUTO_REUSE
):
with
tf
.
variable_scope
(
'g_fc_out'
,
reuse
=
tf
.
AUTO_REUSE
):
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
fc
[
-
1
]
.
shape
[
1
],
1
])
w
=
tf
.
get_variable
(
'weights'
,
shape
=
[
fc
[
-
1
]
.
shape
[
1
],
1
])
tf
.
summary
.
image
(
"wimage"
,
tf
.
reshape
(
w
,[
1
,
fc
[
-
1
]
.
shape
[
1
],
1
,
1
]))
b
=
tf
.
get_variable
(
'biases'
,
shape
=
[
1
])
b
=
tf
.
get_variable
(
'biases'
,
shape
=
[
1
])
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"weights"
,
w
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
tf
.
summary
.
histogram
(
"biases"
,
b
)
#If using residual disparity, split last layer into 2 or remove activation and add rectifier to confidence only
#If using residual disparity, split last layer into 2 or remove activation and add rectifier to confidence only
return
fc_out
return
fc_out
def
batchLoss
(
out_batch
,
# [batch_size,(1..2)] tf_result
def
batchLoss
(
out_batch
,
# [batch_size,(1..2)] tf_result
target_disparity_batch
,
# [batch_size] tf placeholder
target_disparity_batch
,
# [batch_size] tf placeholder
gt_ds_batch
,
# [batch_size,2] tf placeholder
gt_ds_batch
,
# [batch_size,2] tf placeholder
absolute_disparity
=
True
,
#when false there should be no activation on disparity output !
absolute_disparity
=
True
,
#when false there should be no activation on disparity output !
use_confidence
=
True
,
use_confidence
=
True
,
lambda_conf_avg
=
0.01
,
lambda_conf_avg
=
0.01
,
lambda_conf_pwr
=
0.1
,
lambda_conf_pwr
=
0.1
,
conf_pwr
=
2.0
,
conf_pwr
=
2.0
,
...
@@ -276,7 +359,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -276,7 +359,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
disp_wmin
=
1.0
,
# minimal disparity to apply weight boosting for small disparities
disp_wmin
=
1.0
,
# minimal disparity to apply weight boosting for small disparities
disp_wmax
=
8.0
,
# maximal disparity to apply weight boosting for small disparities
disp_wmax
=
8.0
,
# maximal disparity to apply weight boosting for small disparities
use_out
=
False
):
# use calculated disparity for disparity weight boosting (False - use target disparity)
use_out
=
False
):
# use calculated disparity for disparity weight boosting (False - use target disparity)
with
tf
.
name_scope
(
"BatchLoss"
):
with
tf
.
name_scope
(
"BatchLoss"
):
"""
"""
Here confidence should be after relU. Disparity - may be also if absolute, but no activation if output is residual disparity
Here confidence should be after relU. Disparity - may be also if absolute, but no activation if output is residual disparity
...
@@ -295,7 +378,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -295,7 +378,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
else
:
else
:
# w_slice = tf.slice(gt_ds_batch,[0,1],[-1,1], name = "w_gt_slice")
# w_slice = tf.slice(gt_ds_batch,[0,1],[-1,1], name = "w_gt_slice")
w_slice
=
tf
.
reshape
(
gt_ds_batch
[:,
1
],[
-
1
],
name
=
"w_gt_slice"
)
w_slice
=
tf
.
reshape
(
gt_ds_batch
[:,
1
],[
-
1
],
name
=
"w_gt_slice"
)
w_sub
=
tf
.
subtract
(
w_slice
,
tf_gt_conf_offset
,
name
=
"w_sub"
)
w_sub
=
tf
.
subtract
(
w_slice
,
tf_gt_conf_offset
,
name
=
"w_sub"
)
# w_clip = tf.clip_by_value(w_sub, tf_0f,tf_maxw, name = "w_clip")
# w_clip = tf.clip_by_value(w_sub, tf_0f,tf_maxw, name = "w_clip")
w_clip
=
tf
.
maximum
(
w_sub
,
tf_0f
,
name
=
"w_clip"
)
w_clip
=
tf
.
maximum
(
w_sub
,
tf_0f
,
name
=
"w_clip"
)
...
@@ -303,7 +386,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -303,7 +386,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
w
=
w_clip
w
=
w_clip
else
:
else
:
w
=
tf
.
pow
(
w_clip
,
tf_gt_conf_pwr
,
name
=
"w_pow"
)
w
=
tf
.
pow
(
w_clip
,
tf_gt_conf_pwr
,
name
=
"w_pow"
)
if
use_confidence
:
if
use_confidence
:
tf_num_tilesf
=
tf
.
cast
(
tf_num_tiles
,
dtype
=
tf
.
float32
,
name
=
"tf_num_tilesf"
)
tf_num_tilesf
=
tf
.
cast
(
tf_num_tiles
,
dtype
=
tf
.
float32
,
name
=
"tf_num_tilesf"
)
# conf_slice = tf.slice(out_batch,[0,1],[-1,1], name = "conf_slice")
# conf_slice = tf.slice(out_batch,[0,1],[-1,1], name = "conf_slice")
...
@@ -313,7 +396,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -313,7 +396,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
conf_avg1
=
tf
.
subtract
(
conf_avg
,
tf_1f
,
name
=
"conf_avg1"
)
conf_avg1
=
tf
.
subtract
(
conf_avg
,
tf_1f
,
name
=
"conf_avg1"
)
conf_avg2
=
tf
.
square
(
conf_avg1
,
name
=
"conf_avg2"
)
conf_avg2
=
tf
.
square
(
conf_avg1
,
name
=
"conf_avg2"
)
cost2
=
tf
.
multiply
(
conf_avg2
,
tf_lambda_conf_avg
,
name
=
"cost2"
)
cost2
=
tf
.
multiply
(
conf_avg2
,
tf_lambda_conf_avg
,
name
=
"cost2"
)
iconf_avg
=
tf
.
divide
(
tf_1f
,
conf_avg
,
name
=
"iconf_avg"
)
iconf_avg
=
tf
.
divide
(
tf_1f
,
conf_avg
,
name
=
"iconf_avg"
)
nconf
=
tf
.
multiply
(
conf_slice
,
iconf_avg
,
name
=
"nconf"
)
#normalized confidence
nconf
=
tf
.
multiply
(
conf_slice
,
iconf_avg
,
name
=
"nconf"
)
#normalized confidence
nconf_pwr
=
tf
.
pow
(
nconf
,
conf_pwr
,
name
=
"nconf_pwr"
)
nconf_pwr
=
tf
.
pow
(
nconf
,
conf_pwr
,
name
=
"nconf_pwr"
)
...
@@ -324,17 +407,17 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -324,17 +407,17 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
else
:
else
:
w_all
=
w
w_all
=
w
# cost2 = 0.0
# cost2 = 0.0
# cost3 = 0.0
# cost3 = 0.0
# normalize weights
# normalize weights
w_sum
=
tf
.
reduce_sum
(
w_all
,
name
=
"w_sum"
)
w_sum
=
tf
.
reduce_sum
(
w_all
,
name
=
"w_sum"
)
iw_sum
=
tf
.
divide
(
tf_1f
,
w_sum
,
name
=
"iw_sum"
)
iw_sum
=
tf
.
divide
(
tf_1f
,
w_sum
,
name
=
"iw_sum"
)
w_norm
=
tf
.
multiply
(
w_all
,
iw_sum
,
name
=
"w_norm"
)
w_norm
=
tf
.
multiply
(
w_all
,
iw_sum
,
name
=
"w_norm"
)
# disp_slice = tf.slice(out_batch,[0,0],[-1,1], name = "disp_slice")
# disp_slice = tf.slice(out_batch,[0,0],[-1,1], name = "disp_slice")
# d_gt_slice = tf.slice(gt_ds_batch,[0,0],[-1,1], name = "d_gt_slice")
# d_gt_slice = tf.slice(gt_ds_batch,[0,0],[-1,1], name = "d_gt_slice")
disp_slice
=
tf
.
reshape
(
out_batch
[:,
0
],[
-
1
],
name
=
"disp_slice"
)
disp_slice
=
tf
.
reshape
(
out_batch
[:,
0
],[
-
1
],
name
=
"disp_slice"
)
d_gt_slice
=
tf
.
reshape
(
gt_ds_batch
[:,
0
],[
-
1
],
name
=
"d_gt_slice"
)
d_gt_slice
=
tf
.
reshape
(
gt_ds_batch
[:,
0
],[
-
1
],
name
=
"d_gt_slice"
)
"""
"""
if absolute_disparity:
if absolute_disparity:
out_diff = tf.subtract(disp_slice, d_gt_slice, name = "out_diff")
out_diff = tf.subtract(disp_slice, d_gt_slice, name = "out_diff")
...
@@ -342,7 +425,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -342,7 +425,7 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
td_flat = tf.reshape(target_disparity_batch,[-1], name = "td_flat")
td_flat = tf.reshape(target_disparity_batch,[-1], name = "td_flat")
residual_disp = tf.subtract(d_gt_slice, td_flat, name = "residual_disp")
residual_disp = tf.subtract(d_gt_slice, td_flat, name = "residual_disp")
out_diff = tf.subtract(disp_slice, residual_disp, name = "out_diff")
out_diff = tf.subtract(disp_slice, residual_disp, name = "out_diff")
"""
"""
td_flat
=
tf
.
reshape
(
target_disparity_batch
,[
-
1
],
name
=
"td_flat"
)
td_flat
=
tf
.
reshape
(
target_disparity_batch
,[
-
1
],
name
=
"td_flat"
)
if
absolute_disparity
:
if
absolute_disparity
:
adisp
=
disp_slice
adisp
=
disp_slice
...
@@ -350,16 +433,16 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -350,16 +433,16 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
# td_flat = tf.reshape(target_disparity_batch,[-1], name = "td_flat")
# td_flat = tf.reshape(target_disparity_batch,[-1], name = "td_flat")
adisp
=
tf
.
add
(
disp_slice
,
td_flat
,
name
=
"adisp"
)
adisp
=
tf
.
add
(
disp_slice
,
td_flat
,
name
=
"adisp"
)
out_diff
=
tf
.
subtract
(
adisp
,
d_gt_slice
,
name
=
"out_diff"
)
out_diff
=
tf
.
subtract
(
adisp
,
d_gt_slice
,
name
=
"out_diff"
)
out_diff2
=
tf
.
square
(
out_diff
,
name
=
"out_diff2"
)
out_diff2
=
tf
.
square
(
out_diff
,
name
=
"out_diff2"
)
out_wdiff2
=
tf
.
multiply
(
out_diff2
,
w_norm
,
name
=
"out_wdiff2"
)
out_wdiff2
=
tf
.
multiply
(
out_diff2
,
w_norm
,
name
=
"out_wdiff2"
)
cost1
=
tf
.
reduce_sum
(
out_wdiff2
,
name
=
"cost1"
)
cost1
=
tf
.
reduce_sum
(
out_wdiff2
,
name
=
"cost1"
)
out_diff2_offset
=
tf
.
subtract
(
out_diff2
,
error2_offset
,
name
=
"out_diff2_offset"
)
out_diff2_offset
=
tf
.
subtract
(
out_diff2
,
error2_offset
,
name
=
"out_diff2_offset"
)
out_diff2_biased
=
tf
.
maximum
(
out_diff2_offset
,
0.0
,
name
=
"out_diff2_biased"
)
out_diff2_biased
=
tf
.
maximum
(
out_diff2_offset
,
0.0
,
name
=
"out_diff2_biased"
)
# calculate disparity-based weight boost
# calculate disparity-based weight boost
if
use_out
:
if
use_out
:
dispw
=
tf
.
clip_by_value
(
adisp
,
disp_wmin
,
disp_wmax
,
name
=
"dispw"
)
dispw
=
tf
.
clip_by_value
(
adisp
,
disp_wmin
,
disp_wmax
,
name
=
"dispw"
)
...
@@ -370,19 +453,19 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
...
@@ -370,19 +453,19 @@ def batchLoss(out_batch, # [batch_size,(1..2)] tf_result
dispw_sum
=
tf
.
reduce_sum
(
dispw_comp
,
name
=
"dispw_sum"
)
dispw_sum
=
tf
.
reduce_sum
(
dispw_comp
,
name
=
"dispw_sum"
)
idispw_sum
=
tf
.
divide
(
tf_1f
,
dispw_sum
,
name
=
"idispw_sum"
)
idispw_sum
=
tf
.
divide
(
tf_1f
,
dispw_sum
,
name
=
"idispw_sum"
)
dispw_norm
=
tf
.
multiply
(
dispw_comp
,
idispw_sum
,
name
=
"dispw_norm"
)
dispw_norm
=
tf
.
multiply
(
dispw_comp
,
idispw_sum
,
name
=
"dispw_norm"
)
out_diff2_wbiased
=
tf
.
multiply
(
out_diff2_biased
,
dispw_norm
,
name
=
"out_diff2_wbiased"
)
out_diff2_wbiased
=
tf
.
multiply
(
out_diff2_biased
,
dispw_norm
,
name
=
"out_diff2_wbiased"
)
# out_diff2_wbiased = tf.multiply(out_diff2_biased, w_norm, name = "out_diff2_wbiased")
# out_diff2_wbiased = tf.multiply(out_diff2_biased, w_norm, name = "out_diff2_wbiased")
cost1b
=
tf
.
reduce_sum
(
out_diff2_wbiased
,
name
=
"cost1b"
)
cost1b
=
tf
.
reduce_sum
(
out_diff2_wbiased
,
name
=
"cost1b"
)
if
use_confidence
:
if
use_confidence
:
cost12
=
tf
.
add
(
cost1b
,
cost2
,
name
=
"cost12"
)
cost12
=
tf
.
add
(
cost1b
,
cost2
,
name
=
"cost12"
)
cost123
=
tf
.
add
(
cost12
,
cost3
,
name
=
"cost123"
)
cost123
=
tf
.
add
(
cost12
,
cost3
,
name
=
"cost123"
)
return
cost123
,
disp_slice
,
d_gt_slice
,
out_diff
,
out_diff2
,
w_norm
,
out_wdiff2
,
cost1
return
cost123
,
disp_slice
,
d_gt_slice
,
out_diff
,
out_diff2
,
w_norm
,
out_wdiff2
,
cost1
else
:
else
:
return
cost1b
,
disp_slice
,
d_gt_slice
,
out_diff
,
out_diff2
,
w_norm
,
out_wdiff2
,
cost1
return
cost1b
,
disp_slice
,
d_gt_slice
,
out_diff
,
out_diff2
,
w_norm
,
out_wdiff2
,
cost1
#corr2d325 = tf.concat([corr2d,target_disparity],0)
#corr2d325 = tf.concat([corr2d,target_disparity],0)
#corr2d325 = tf.concat([next_element_train['corr2d'],tf.reshape(next_element_train['target_disparity'],(-1,1))],1)
#corr2d325 = tf.concat([next_element_train['corr2d'],tf.reshape(next_element_train['target_disparity'],(-1,1))],1)
...
@@ -397,7 +480,7 @@ G_loss, _disp_slice, _d_gt_slice, _out_diff, _out_diff2, _w_norm, _out_wdiff2, _
...
@@ -397,7 +480,7 @@ G_loss, _disp_slice, _d_gt_slice, _out_diff, _out_diff2, _w_norm, _out_wdiff2, _
target_disparity_batch
=
next_element_train
[
'target_disparity'
],
# target_disparity, ### target_d, # [batch_size] tf placeholder
target_disparity_batch
=
next_element_train
[
'target_disparity'
],
# target_disparity, ### target_d, # [batch_size] tf placeholder
gt_ds_batch
=
next_element_train
[
'gt_ds'
],
# gt_ds, ### gt, # [batch_size,2] tf placeholder
gt_ds_batch
=
next_element_train
[
'gt_ds'
],
# gt_ds, ### gt, # [batch_size,2] tf placeholder
absolute_disparity
=
ABSOLUTE_DISPARITY
,
absolute_disparity
=
ABSOLUTE_DISPARITY
,
use_confidence
=
USE_CONFIDENCE
,
# True,
use_confidence
=
USE_CONFIDENCE
,
# True,
lambda_conf_avg
=
0.01
,
lambda_conf_avg
=
0.01
,
lambda_conf_pwr
=
0.1
,
lambda_conf_pwr
=
0.1
,
conf_pwr
=
2.0
,
conf_pwr
=
2.0
,
...
@@ -407,7 +490,7 @@ G_loss, _disp_slice, _d_gt_slice, _out_diff, _out_diff2, _w_norm, _out_wdiff2, _
...
@@ -407,7 +490,7 @@ G_loss, _disp_slice, _d_gt_slice, _out_diff, _out_diff2, _w_norm, _out_wdiff2, _
disp_wmin
=
1.0
,
# minimal disparity to apply weight boosting for small disparities
disp_wmin
=
1.0
,
# minimal disparity to apply weight boosting for small disparities
disp_wmax
=
8.0
,
# maximal disparity to apply weight boosting for small disparities
disp_wmax
=
8.0
,
# maximal disparity to apply weight boosting for small disparities
use_out
=
False
)
# use calculated disparity for disparity weight boosting (False - use target disparity)
use_out
=
False
)
# use calculated disparity for disparity weight boosting (False - use target disparity)
tf_ph_G_loss
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
None
,
name
=
'G_loss_avg'
)
tf_ph_G_loss
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
None
,
name
=
'G_loss_avg'
)
tf_ph_sq_diff
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
None
,
name
=
'sq_diff_avg'
)
tf_ph_sq_diff
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
None
,
name
=
'sq_diff_avg'
)
with
tf
.
name_scope
(
'sample'
):
with
tf
.
name_scope
(
'sample'
):
...
@@ -446,14 +529,14 @@ def read_new_tfrecord_file(filename,result):
...
@@ -446,14 +529,14 @@ def read_new_tfrecord_file(filename,result):
result
.
append
(
c
)
result
.
append
(
c
)
print
(
"Loaded new tfrecord file: "
+
str
(
filename
))
print
(
"Loaded new tfrecord file: "
+
str
(
filename
))
train_record_index_counter
=
0
train_record_index_counter
=
0
train_file_index
=
0
train_file_index
=
0
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
merged
=
tf
.
summary
.
merge_all
()
merged
=
tf
.
summary
.
merge_all
()
train_writer
=
tf
.
summary
.
FileWriter
(
TRAIN_PATH
,
sess
.
graph
)
train_writer
=
tf
.
summary
.
FileWriter
(
TRAIN_PATH
,
sess
.
graph
)
test_writer
=
tf
.
summary
.
FileWriter
(
TEST_PATH
,
sess
.
graph
)
test_writer
=
tf
.
summary
.
FileWriter
(
TEST_PATH
,
sess
.
graph
)
...
@@ -461,29 +544,29 @@ with tf.Session() as sess:
...
@@ -461,29 +544,29 @@ with tf.Session() as sess:
loss_test_hist
=
np
.
empty
(
dataset_test_size
,
dtype
=
np
.
float32
)
loss_test_hist
=
np
.
empty
(
dataset_test_size
,
dtype
=
np
.
float32
)
loss2_train_hist
=
np
.
empty
(
dataset_train_size
,
dtype
=
np
.
float32
)
loss2_train_hist
=
np
.
empty
(
dataset_train_size
,
dtype
=
np
.
float32
)
loss2_test_hist
=
np
.
empty
(
dataset_test_size
,
dtype
=
np
.
float32
)
loss2_test_hist
=
np
.
empty
(
dataset_test_size
,
dtype
=
np
.
float32
)
train_avg
=
0.0
train_avg
=
0.0
train2_avg
=
0.0
train2_avg
=
0.0
test_avg
=
0.0
test_avg
=
0.0
test2_avg
=
0.0
test2_avg
=
0.0
for
epoch
in
range
(
EPOCHS_TO_RUN
):
for
epoch
in
range
(
EPOCHS_TO_RUN
):
train_file_index
=
epoch
%
n_allowed_train_filesTFR
train_file_index
=
epoch
%
n_allowed_train_filesTFR
print
(
"train_file_index: "
+
str
(
train_file_index
))
print
(
"train_file_index: "
+
str
(
train_file_index
))
if
epoch
%
10
==
0
:
if
epoch
%
10
==
0
:
# if there are more files than python3 memory allows
# if there are more files than python3 memory allows
if
(
n_allowed_train_filesTFR
<
len
(
train_filesTFR
)):
if
(
n_allowed_train_filesTFR
<
len
(
train_filesTFR
)):
# circular loading?
# circular loading?
tmp_train_index
=
(
n_allowed_train_filesTFR
+
train_record_index_counter
)
%
len
(
train_filesTFR
)
tmp_train_index
=
(
n_allowed_train_filesTFR
+
train_record_index_counter
)
%
len
(
train_filesTFR
)
# wait for old thread
# wait for old thread
if
epoch
!=
0
:
if
epoch
!=
0
:
if
thr
.
is_alive
():
if
thr
.
is_alive
():
print_time
(
"Waiting until tfrecord gets loaded"
)
print_time
(
"Waiting until tfrecord gets loaded"
)
thr
.
join
()
thr
.
join
()
# do replacement
# do replacement
## remove the first
## remove the first
corr2d_trains
.
pop
(
0
)
corr2d_trains
.
pop
(
0
)
target_disparity_trains
.
pop
(
0
)
target_disparity_trains
.
pop
(
0
)
gt_ds_trains
.
pop
(
0
)
gt_ds_trains
.
pop
(
0
)
...
@@ -491,19 +574,21 @@ with tf.Session() as sess:
...
@@ -491,19 +574,21 @@ with tf.Session() as sess:
corr2d_trains
.
append
(
thr_result
[
0
])
corr2d_trains
.
append
(
thr_result
[
0
])
target_disparity_trains
.
append
(
thr_result
[
1
])
target_disparity_trains
.
append
(
thr_result
[
1
])
gt_ds_trains
.
append
(
thr_result
[
2
])
gt_ds_trains
.
append
(
thr_result
[
2
])
print_time
(
"Time to begin loading a new tfrecord file"
)
print_time
(
"Time to begin loading a new tfrecord file"
)
# new thread
# new thread
thr_result
=
[]
thr_result
=
[]
thr
=
Thread
(
target
=
read_new_tfrecord_file
,
args
=
(
train_filesTFR
[
tmp_train_index
],
thr_result
))
thr
=
Thread
(
target
=
read_new_tfrecord_file
,
args
=
(
train_filesTFR
[
tmp_train_index
],
thr_result
))
# start
# start
thr
.
start
()
thr
.
start
()
train_record_index_counter
+=
1
train_record_index_counter
+=
1
# if SHUFFLE_EPOCH:
# if SHUFFLE_EPOCH:
# dataset_train = dataset_train.shuffle(buffer_size=10000)
# dataset_train = dataset_train.shuffle(buffer_size=10000)
# RUN TRAIN SESSION
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_trains
[
train_file_index
],
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_trains
[
train_file_index
],
target_disparity_train_placeholder
:
target_disparity_trains
[
train_file_index
],
target_disparity_train_placeholder
:
target_disparity_trains
[
train_file_index
],
gt_ds_train_placeholder
:
gt_ds_trains
[
train_file_index
]})
gt_ds_train_placeholder
:
gt_ds_trains
[
train_file_index
]})
...
@@ -524,24 +609,22 @@ with tf.Session() as sess:
...
@@ -524,24 +609,22 @@ with tf.Session() as sess:
corr2d325
,
corr2d325
,
],
],
feed_dict
=
{
lr
:
LR
,
tf_ph_G_loss
:
train_avg
,
tf_ph_sq_diff
:
train2_avg
})
# pfrevious value of *_avg
feed_dict
=
{
lr
:
LR
,
tf_ph_G_loss
:
train_avg
,
tf_ph_sq_diff
:
train2_avg
})
# pfrevious value of *_avg
# save all for now as a test
# save all for now as a test
#train_writer.add_summary(summary, i)
#train_writer.add_summary(summary, i)
#train_writer.add_summary(train_summary, i)
#train_writer.add_summary(train_summary, i)
loss_train_hist
[
i
]
=
G_loss_trained
loss_train_hist
[
i
]
=
G_loss_trained
loss2_train_hist
[
i
]
=
out_cost1
loss2_train_hist
[
i
]
=
out_cost1
except
tf
.
errors
.
OutOfRangeError
:
except
tf
.
errors
.
OutOfRangeError
:
print
(
"train done at step
%
d"
%
(
i
))
print
(
"train done at step
%
d"
%
(
i
))
break
break
train_avg
=
np
.
average
(
loss_train_hist
)
.
astype
(
np
.
float32
)
train_avg
=
np
.
average
(
loss_train_hist
)
.
astype
(
np
.
float32
)
train2_avg
=
np
.
average
(
loss2_train_hist
)
.
astype
(
np
.
float32
)
train2_avg
=
np
.
average
(
loss2_train_hist
)
.
astype
(
np
.
float32
)
#_,_=sess.run([tf_ph_G_loss,tf_ph_sq_diff],feed_dict={tf_ph_G_loss:train_avg, tf_ph_sq_diff:train2_avg})
#tf_ph_G_loss = tf.placeholder(tf.float32,shape=None,name='G_loss_avg')
# RUN TEST SESSION
#tf_ph_sq_diff = tf.placeholder(tf.float32,shape=None,name='sq_diff_avg')
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_test
,
sess
.
run
(
iterator_train
.
initializer
,
feed_dict
=
{
corr2d_train_placeholder
:
corr2d_test
,
target_disparity_train_placeholder
:
target_disparity_test
,
target_disparity_train_placeholder
:
target_disparity_test
,
gt_ds_train_placeholder
:
gt_ds_test
})
gt_ds_train_placeholder
:
gt_ds_test
})
...
@@ -566,18 +649,16 @@ with tf.Session() as sess:
...
@@ -566,18 +649,16 @@ with tf.Session() as sess:
except
tf
.
errors
.
OutOfRangeError
:
except
tf
.
errors
.
OutOfRangeError
:
print
(
"test done at step
%
d"
%
(
i
))
print
(
"test done at step
%
d"
%
(
i
))
break
break
# print_time("%d:%d -> %f"%(epoch,i,G_current))
test_avg
=
np
.
average
(
loss_test_hist
)
.
astype
(
np
.
float32
)
test_avg
=
np
.
average
(
loss_test_hist
)
.
astype
(
np
.
float32
)
test2_avg
=
np
.
average
(
loss2_test_hist
)
.
astype
(
np
.
float32
)
test2_avg
=
np
.
average
(
loss2_test_hist
)
.
astype
(
np
.
float32
)
# _,_=sess.run([tf_ph_G_loss,tf_ph_sq_diff],feed_dict={tf_ph_G_loss:test_avg, tf_ph_sq_diff:test2_avg})
# they include image summaries as well
train_writer
.
add_summary
(
train_summary
,
epoch
)
train_writer
.
add_summary
(
train_summary
,
epoch
)
test_writer
.
add_summary
(
test_summary
,
epoch
)
test_writer
.
add_summary
(
test_summary
,
epoch
)
print_time
(
"
%
d:
%
d ->
%
f
%
f (
%
f
%
f)"
%
(
epoch
,
i
,
train_avg
,
test_avg
,
train2_avg
,
test2_avg
))
print_time
(
"
%
d:
%
d ->
%
f
%
f (
%
f
%
f)"
%
(
epoch
,
i
,
train_avg
,
test_avg
,
train2_avg
,
test2_avg
))
# Close writers
# Close writers
train_writer
.
close
()
train_writer
.
close
()
test_writer
.
close
()
test_writer
.
close
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment