test_nn_feed.py 10.8 KB
Newer Older
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
#!/usr/bin/env python3
'''
Open all tiffs in a folder, combine a single tiff from randomly selected
tiles from originals, then train a small fully connected TensorFlow network
on 5x5-tile patches taken from each tiff in turn.

Usage: python3 test_nn_feed.py [folder-with-tiffs]   (default folder: ".")
'''

__copyright__ = "Copyright 2018, Elphel, Inc."
__license__   = "GPL-3.0+"
__email__     = "oleg@elphel.com"

from PIL import Image

import os
import sys
import glob

import imagej_tiff as ijt
import pack_tile as pile

import numpy as np
import itertools

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
24 25
import time

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
26 27
import matplotlib.pyplot as plt

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
class bcolors:
    """ANSI escape sequences for colorizing terminal output.

    Adapted from
    http://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python

    Put an attribute in front of the text and ENDC after it, e.g.
    bcolors.OKGREEN + "done" + bcolors.ENDC.
    """

    # --- foreground colors -------------------------------------------------
    HEADER = '\033[95m'          # bright magenta
    OKBLUE = '\033[94m'          # bright blue
    OKGREEN = '\033[92m'         # bright green
    WARNING = '\033[38;5;214m'   # 256-color palette entry 214
    FAIL = '\033[91m'            # bright red

    # --- text attributes ---------------------------------------------------
    BOLD = '\033[1m'
    BOLDWHITE = '\033[1;37m'
    UNDERLINE = '\033[4m'

    # --- reset everything back to the terminal default ---------------------
    ENDC = '\033[0m'


def print_time():
  """Print the current time.time() value, prefixed with 'time: ', in bold white."""
  now = time.time()
  print("{}time: {}{}".format(bcolors.BOLDWHITE, now, bcolors.ENDC))

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
44 45 46
# USAGE: python3 test_nn_feed.py [folder-with-tiffs]
#        (the folder defaults to the current directory)

# Layer holding the per-tile values; it is removed from the image data
# layer list and read separately with getvalues().
VALUES_LAYER_NAME = 'other'

LAYERS_OF_INTEREST = ['diagm-pair', 'diago-pair', 'hor-pairs', 'vert-pairs']
RADIUS = 1

# True: plot loss/predictions interactively; False: print a text line per step
DEBUG_PLT_LOSS = True

# If False, tiles will not be packed or rescaled.
# NOTE(review): not referenced by the code visible in this file — confirm.
DEBUG_PACK_TILES = True

# Folder with the input tiffs: first command-line argument, otherwise "."
src = sys.argv[1] if len(sys.argv) > 1 else "."

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
59
# TensorFlow is imported here, between two timestamps, so the time spent
# importing it can be read off the console output.
print("Importing TensorCrawl")
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
60 61
print_time()

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
62 63
# NOTE(review): tensorflow.contrib (and tf.Session / tf.placeholder used below)
# only exist in TensorFlow 1.x; this script will not run on TF 2.x as written.
import tensorflow as tf
import tensorflow.contrib.slim as slim
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
64

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
65
# second timestamp: the difference to the first one is the import time
print("TensorCrawl imported")
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
66
print_time()
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
67

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
68
# When True, the mixed data set below is not built.
# NOTE(review): the training code further down still reads tlist, labels and
# packed_tiles, so IS_TEST = True cannot currently run end to end.
IS_TEST = False

# BEGIN IF IS_TEST
if not IS_TEST:

  # glob() returns files in arbitrary order; sort so tlist[0] (the reference
  # tiff) and the later per-epoch order are reproducible between runs.
  tlist = sorted(glob.glob(os.path.join(src, "*.tiff")))

  if not tlist:
    sys.exit("No *.tiff files found in "+src)

  print("Found "+str(len(tlist))+" preprocessed tiff files:")
  print("\n".join(tlist))
  print_time()

  # WARNING, assuming:
  #   - timestamps and part of names match
  #   - layer order and names are identical (and tile grids have equal size)

  # open the first one to get dimensions and other info
  tiff = ijt.imagej_tiff(tlist[0])

  # image data layers = every layer except the values layer
  labels = tiff.labels.copy()
  labels.remove(VALUES_LAYER_NAME)

  print("Image data layers:  "+str(labels))
  print("Layers of interest: "+str(LAYERS_OF_INTEREST))
  print("Values layer: "+str([VALUES_LAYER_NAME]))

  # Writable copies of the first tiff; tiles of the other tiffs are written
  # into them below.  Tile-shaped stack, e.g. (242, 324, 9, 9, 5) — presumably
  # varies with the input, TODO confirm.
  tiles  = np.copy(tiff.getstack(labels,shape_as_tiles=True))
  values = np.copy(tiff.getvalues(label=VALUES_LAYER_NAME))

  print("Mixed tiled input data shape: "+str(tiles.shape))

  # For every (y, x) tile position choose the tiff it is taken from.
  # randint's upper bound is exclusive, so this draws from [0, len(tlist)-1],
  # the same range as the deprecated random_integers(0, len(tlist)-1).
  indices = np.random.randint(0,len(tlist),size=(tiles.shape[0],tiles.shape[1]))

  # counts tiles taken from each tiff; everything starts out as tiff #0
  shuffle_counter = np.zeros(len(tlist),np.int32)
  shuffle_counter[0] = tiles.shape[0]*tiles.shape[1]

  for i in range(1,len(tlist)):
    tmp_tiff  = ijt.imagej_tiff(tlist[i])
    tmp_tiles = np.asarray(tmp_tiff.getstack(labels,shape_as_tiles=True))
    tmp_vals  = np.asarray(tmp_tiff.getvalues(label=VALUES_LAYER_NAME))

    # Boolean mask over the (y, x) tile grid: copies all selected tiles in
    # one vectorized step instead of a Python loop over every position.
    selected = (indices == i)
    tiles[selected]  = tmp_tiles[selected]
    values[selected] = tmp_vals[selected]
    shuffle_counter[i] = np.count_nonzero(selected)

  # whatever was not replaced by tiffs 1..N-1 is still from tiff #0
  shuffle_counter[0] -= shuffle_counter[1:].sum()

  print("Tiff files parts count in the mixed input = "+str(shuffle_counter))
  print_time()

  # pack every tile and append channel 0 of the values layer
  packed_tiles = pile.pack(tiles)
  packed_tiles = np.dstack((packed_tiles,values[:,:,0]))

  print("Packed (81x4 -> 1x(25*4+1)) tiled input shape: "+str(packed_tiles.shape))
  print("Values shape "+str(values.shape))
  print_time()

# END IF IS_TEST


#print("CHECKPOINTE")
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
156 157 158 159 160 161 162 163

#for i in range(tiles.shape[0]):
#  for j in range(tiles.shape[1]):
#    nn_input = pile.get_tile_with_neighbors(tiles,i,j,RADIUS)
#    print("tile: "+str(i)+", "+str(j)+": shape = "+str(nn_input.shape))
#print_time()

# Everything the run writes goes under ./result/: TensorBoard logs
# (train/, test/), per-epoch marker folders ('%04d') and the checkpoint.
result_dir = './result/'
checkpoint_dir = result_dir

# an epoch marker folder is created when epoch % save_freq == 0
save_freq = 500

def lrelu(x, alpha=0.2):
    """Leaky ReLU: x where x >= 0, alpha * x where x < 0.

    Computed as tf.maximum(x * alpha, x), which is a leaky ReLU only for
    0 <= alpha <= 1.  The default 0.2 keeps the original slope, and
    single-argument calls (slim's activation_fn=lrelu) are unchanged.
    """
    return tf.maximum(x*alpha, x)

def network(input):
  """Build the fully connected network used for training.

  Six slim.fully_connected layers, 2048 -> 1024 -> 512 -> 8 -> 4 -> 2 units,
  all with lrelu activation, in variable scopes g_fc1 .. g_fc6.  The 'g_'
  prefix is what the optimizer's var_list filter selects on.

  Args:
    input: 2-D float tensor; in this script the in_tile placeholder
      ([None, 101]).
  Returns:
    Tensor with 2 columns per row; the loss compares column 0 with gt[:,0].
  """
  layer_plan = (
    (2048, 'g_fc1'),
    (1024, 'g_fc2'),
    ( 512, 'g_fc3'),
    (   8, 'g_fc4'),
    (   4, 'g_fc5'),
    (   2, 'g_fc6'),
  )
  net = input
  for units, scope_name in layer_plan:
    net = slim.fully_connected(net, units, activation_fn=lrelu, scope=scope_name)
  return net
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
180 181


Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
182
sess = tf.Session()

# one row per tile: 101 packed input values (25*4 + 1)
in_tile = tf.placeholder(tf.float32,[None,101])
# ground truth per tile: column 0 is the regression target of the loss,
# column 1 determines the per-row loss weight below
gt      = tf.placeholder(tf.float32,[None,2])

out = network(in_tile)

# min cutoff: rows with gt[:,1] <= cf_cutoff get weight 0, the others are
# weighted by how far gt[:,1] exceeds the cutoff (the former tf.pow(..., 1)
# around this was a no-op and has been dropped)
cf_cutoff = 0.173303
cf_w = tf.maximum(gt[:,1]-cf_cutoff,0.0)
# NOTE(review): softmax-normalized weights are built but neither used by
# G_loss nor fetched anywhere in this script
cf_w_norm = tf.nn.softmax(cf_w)

# weighted mean squared error on column 0; out[:,1] does not enter the loss
G_loss = tf.losses.mean_squared_error(gt[:,0],out[:,0],cf_w)

tf.summary.scalar('loss', G_loss)
tf.summary.scalar('prediction', out[0,0])
# TF1 rejects spaces in summary names ('ground truth' was renamed to this,
# with an "illegal name" warning); use the legal name directly
tf.summary.scalar('ground_truth', gt[0,0])

t_vars=tf.trainable_variables()
lr=tf.placeholder(tf.float32)
# only the variables created in the network's 'g_*' scopes are optimized
G_opt=tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss,var_list=[var for var in t_vars if var.name.startswith('g_')])

saver=tf.train.Saver()

# all summaries above as one op, plus TensorBoard writers under result_dir
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(os.path.join(result_dir, 'train'), sess.graph)
test_writer = tf.summary.FileWriter(os.path.join(result_dir, 'test'))

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
231 232 233 234 235 236 237 238 239 240 241 242 243
sess.run(tf.global_variables_initializer())

# resume from the latest checkpoint in checkpoint_dir, if there is one
ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
  # report only once restore() has actually succeeded
  print('loaded '+ckpt.model_checkpoint_path)

# Resume epoch numbering from the highest all-digit folder name under
# result_dir (epoch marker folders are named '%04d').  Non-numeric entries
# (train/, test/, checkpoint files) are ignored instead of crashing int(),
# and epochs >= 10000 are parsed in full instead of by their last 4 chars.
lastepoch = 0
for folder in glob.glob(os.path.join(result_dir, '*')):
  folder_name = os.path.basename(folder)
  if folder_name.isdecimal() and os.path.isdir(folder):
    lastepoch = max(lastepoch, int(folder_name))

Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
244
# Per-patch loss slots, sized from the tile grid of packed_tiles built above;
# slots never written stay 0 and are skipped by the non-zero mean used in
# the training loop.
g_loss = np.zeros((packed_tiles.shape[0] * packed_tiles.shape[1], 1))

# histories for the interactive plots (appended to only when DEBUG_PLT_LOSS)
recorded_loss, recorded_mean_loss = [], []
recorded_gt_d, recorded_gt_c = [], []
recorded_pr_d, recorded_pr_c = [], []

# learning rate fed to the lr placeholder on every training step
LR = 1e-5

print("{}Last Epoch = {}{}".format(bcolors.HEADER, lastepoch, bcolors.ENDC))

if DEBUG_PLT_LOSS:
  # interactive mode so figure 1 can be redrawn while training continues
  plt.ion()
  plt.figure(1, figsize=(4,12))




# RUN
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
269 270 271 272 273 274 275
# epoch is one image: each epoch trains on the next tiff of tlist, patch by patch
for epoch in range(lastepoch,lastepoch+len(tlist)):

  print(bcolors.HEADER+"Epoch #"+str(epoch)+bcolors.ENDC)

  # an existing result_dir/NNNN folder marks an epoch that was already done
  # (same location the end-of-epoch code below creates)
  if os.path.isdir(result_dir + '%04d'%epoch):
    continue

  cnt=0

  tlist_index = epoch - lastepoch

  print(bcolors.OKGREEN+"Processing "+tlist[tlist_index]+bcolors.ENDC)

  tmp_tiff  = ijt.imagej_tiff(tlist[tlist_index])
  tmp_tiles = tmp_tiff.getstack(labels,shape_as_tiles=True)
  tmp_vals  = tmp_tiff.getvalues(label=VALUES_LAYER_NAME)

  # packed tiles of this tiff plus channel 0 of its values layer, i.e. the
  # 101 inputs per tile expected by the in_tile placeholder
  packed_tiles = pile.pack(tmp_tiles)
  packed_tiles = np.dstack((packed_tiles,tmp_vals[:,:,0]))

  # walk the image in non-overlapping 5x5-tile patches
  vsteps = packed_tiles.shape[0]//5
  hsteps = packed_tiles.shape[1]//5

  for ind in range(hsteps*vsteps):

    st=time.time()
    cnt+=1

    # centre tile of the current patch
    i = 2 + 5*(ind//hsteps)
    j = 2 + 5*(ind%hsteps)

    # centre tile has -256 in value channels 0 and 1: skip this patch
    if tmp_vals[i,j,0]==-256 and tmp_vals[i,j,1]==-256:
      print("Have to SKIP!")
      continue

    # 25 input rows (one per tile) and the matching 25 ground-truth rows
    # (value channels 1 and 2).  Ground truth comes from tmp_vals, the same
    # tiff as the input; reading it from the mixed `values` array paired
    # each input with ground truth from randomly chosen other tiffs.
    input_patch = np.reshape(packed_tiles[i-2:i+3,j-2:j+3],(-1,101))
    gt_patch    = np.array(tmp_vals[i-2:i+3,j-2:j+3,1:3]).reshape(-1,2)

    # -256 and NaN in the ground truth become 0 (a zeroed gt[:,1] gives the
    # row zero loss weight, being below cf_cutoff); NaN inputs become 0 too.
    # After this no NaN can remain, so no separate NaN-skip check is needed.
    gt_patch[np.isnan(gt_patch) | (gt_patch==-256)] = 0
    input_patch[np.isnan(input_patch)] = 0

    # plain run: the FULL_TRACE RunOptions/RunMetadata previously passed here
    # were never consumed and only slowed every step down
    _,G_current,output,summary = sess.run([G_opt,G_loss,out,merged],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR})

    g_loss[ind]=G_current
    # mean over all patch slots that have recorded a non-zero loss so far
    mean_loss = np.mean(g_loss[np.where(g_loss)])

    if DEBUG_PLT_LOSS:

      recorded_loss.append(G_current)
      recorded_mean_loss.append(mean_loss)

      recorded_pr_d.append(output[0,0])
      recorded_pr_c.append(output[0,1])

      recorded_gt_d.append(gt_patch[0,0])
      recorded_gt_c.append(gt_patch[0,1])

      plt.clf()

      plt.subplot(311)

      plt.plot(recorded_loss,  label='loss')
      plt.plot(recorded_mean_loss,  label='mean loss', color='red')
      plt.xlabel('Iteration')
      plt.ylabel('Loss')
      plt.title("Loss=%.5f, Mean Loss=%.5f"%(G_current,mean_loss), fontdict={'size': 20, 'color': 'red'})

      plt.subplot(312)

      plt.xlabel('Iteration')
      plt.ylabel('Disparities')
      plt.plot(recorded_gt_d,  label='gt_d',color='green')
      plt.plot(recorded_pr_d,  label='pr_d',color='red')
      plt.legend(loc='best',ncol=1)

      plt.subplot(313)

      plt.xlabel('Iteration')
      plt.ylabel('Confidences')
      plt.plot(recorded_gt_c,  label='gt_c',color='green')
      plt.plot(recorded_pr_c,  label='pr_c',color='red')
      plt.legend(loc='best',ncol=1)

      plt.pause(0.001)

    else:
      print("%d %d Loss=%.3f CurrentLoss=%.3f Time=%.3f"%(epoch,cnt,mean_loss,G_current,time.time()-st))
      #train_writer.add_summary(summary, cnt)

  saver.save(sess, checkpoint_dir + 'model.ckpt')

  # Create the epoch marker folder only after the epoch has been trained and
  # the checkpoint saved; creating it inside the patch loop marked an
  # interrupted epoch as done.
  if epoch%save_freq==0:
    os.makedirs(result_dir + '%04d'%epoch, exist_ok=True)

# close the TensorBoard writers once, after all epochs (not after each epoch)
train_writer.close()
test_writer.close()
Oleg Dzhimiev's avatar
Oleg Dzhimiev committed
411 412 413

# report the finishing time twice: via print_time() and once in green
print_time()
print("{}time: {}{}".format(bcolors.OKGREEN, time.time(), bcolors.ENDC))

# leave interactive mode so show() blocks and any loss plot stays on screen
plt.ioff()
plt.show()