Commit 45a473f0 authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

testing

parent 9c8f781e
...@@ -65,118 +65,43 @@ import tensorflow.contrib.slim as slim ...@@ -65,118 +65,43 @@ import tensorflow.contrib.slim as slim
print("TensorCrawl imported") print("TensorCrawl imported")
print_time() print_time()
IS_TEST = False tlist = glob.glob(src+"/*.tiff")
print("Found "+str(len(tlist))+" preprocessed tiff files:")
# BEGIN IF IS_TEST print("\n".join(tlist))
if not IS_TEST: print_time()
tlist = glob.glob(src+"/*.tiff")
print("\n".join(tlist))
print("Found "+str(len(tlist))+" preprocessed tiff files:")
print_time()
pass
''' WARNING, assuming:
- timestamps and part of names match
- layer order and names are identical
'''
# open the first one to get dimensions and other info
tiff = ijt.imagej_tiff(tlist[0])
#del tlist[0]
# shape as tiles? make a copy or make writeable
# (242, 324, 9, 9, 5)
# get labels
labels = tiff.labels.copy()
labels.remove(VALUES_LAYER_NAME)
print("Image data layers: "+str(labels))
print("Layers of interest: "+str(LAYERS_OF_INTEREST))
print("Values layer: "+str([VALUES_LAYER_NAME]))
# create copies
tiles = np.copy(tiff.getstack(labels,shape_as_tiles=True))
values = np.copy(tiff.getvalues(label=VALUES_LAYER_NAME))
#gt = values[:,:,1:3]
print("Mixed tiled input data shape: "+str(tiles.shape))
#print_time()
# now generate a layer of indices to get other tiles
indices = np.random.random_integers(0,len(tlist)-1,size=(tiles.shape[0],tiles.shape[1]))
#print(indices.shape)
# counts tiles from a certain tiff
shuffle_counter = np.zeros(len(tlist),np.int32)
shuffle_counter[0] = tiles.shape[0]*tiles.shape[1]
for i in range(1,len(tlist)):
#print(tlist[i])
tmp_tiff = ijt.imagej_tiff(tlist[i])
tmp_tiles = tmp_tiff.getstack(labels,shape_as_tiles=True)
tmp_vals = tmp_tiff.getvalues(label=VALUES_LAYER_NAME)
#tmp_tiles =
#tiles[indices==i] = tmp_tiff[indices==i]
# straight and clear
# can do quicker?
for y,x in itertools.product(range(indices.shape[0]),range(indices.shape[1])):
if indices[y,x]==i:
tiles[y,x] = tmp_tiles[y,x]
values[y,x] = tmp_vals[y,x]
shuffle_counter[i] +=1
# check shuffle counter
for i in range(1,len(shuffle_counter)):
shuffle_counter[0] -= shuffle_counter[i]
print("Tiff files parts count in the mixed input = "+str(shuffle_counter))
print_time()
# test later
# might not need it because going to loop through anyway
packed_tiles = pile.pack(tiles)
packed_tiles = np.dstack((packed_tiles,values[:,:,0]))
print("Packed (81x4 -> 1x(25*4+1)) tiled input shape: "+str(packed_tiles.shape))
print("Values shape "+str(values.shape))
print_time()
# END IF IS_TEST tiff = ijt.imagej_tiff(tlist[0])
# get labels
labels = tiff.labels.copy()
labels.remove(VALUES_LAYER_NAME)
#print("CHECKPOINTE") print("Image data layers: "+str(labels))
print("Layers of interest: "+str(LAYERS_OF_INTEREST))
print("Values layer: "+str([VALUES_LAYER_NAME]))
#for i in range(tiles.shape[0]):
# for j in range(tiles.shape[1]):
# nn_input = pile.get_tile_with_neighbors(tiles,i,j,RADIUS)
# print("tile: "+str(i)+", "+str(j)+": shape = "+str(nn_input.shape))
#print_time()
result_dir = './result/' result_dir = './result/'
checkpoint_dir = './result/' checkpoint_dir = './result/'
save_freq = 500 save_freq = 500
def lrelu(x): def lrelu(x):
return tf.maximum(x*0.2,x) #return tf.maximum(x*0.2,x)
return tf.nn.relu(x)
def network(input): def network(input):
fc1 = slim.fully_connected(input,2048,activation_fn=lrelu,scope='g_fc1') fc1 = slim.fully_connected(input,512,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2') fc2 = slim.fully_connected(fc1, 2,activation_fn=lrelu,scope='g_fc2')
fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3') return fc2
fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5') #fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6') #fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
#fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
#fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
#fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
return fc6 #return fc6
sess = tf.Session() sess = tf.Session()
...@@ -241,9 +166,6 @@ lastepoch = 0 ...@@ -241,9 +166,6 @@ lastepoch = 0
for folder in allfolders: for folder in allfolders:
lastepoch = np.maximum(lastepoch, int(folder[-4:])) lastepoch = np.maximum(lastepoch, int(folder[-4:]))
g_loss = np.zeros((packed_tiles.shape[0]*packed_tiles.shape[1],1))
recorded_loss = [] recorded_loss = []
recorded_mean_loss = [] recorded_mean_loss = []
...@@ -253,7 +175,7 @@ recorded_gt_c = [] ...@@ -253,7 +175,7 @@ recorded_gt_c = []
recorded_pr_d = [] recorded_pr_d = []
recorded_pr_c = [] recorded_pr_c = []
LR = 1e-5 LR = 1e-3
print(bcolors.HEADER+"Last Epoch = "+str(lastepoch)+bcolors.ENDC) print(bcolors.HEADER+"Last Epoch = "+str(lastepoch)+bcolors.ENDC)
...@@ -274,12 +196,9 @@ for epoch in range(lastepoch,lastepoch+len(tlist)): ...@@ -274,12 +196,9 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
print(bcolors.HEADER+"Epoch #"+str(epoch)+bcolors.ENDC) print(bcolors.HEADER+"Epoch #"+str(epoch)+bcolors.ENDC)
#for epoch in range(lastepoch,4001):
if os.path.isdir("result/%04d"%epoch): if os.path.isdir("result/%04d"%epoch):
continue continue
cnt=0
tlist_index = epoch - lastepoch tlist_index = epoch - lastepoch
print(bcolors.OKGREEN+"Processing "+tlist[tlist_index]+bcolors.ENDC) print(bcolors.OKGREEN+"Processing "+tlist[tlist_index]+bcolors.ENDC)
...@@ -295,49 +214,35 @@ for epoch in range(lastepoch,lastepoch+len(tlist)): ...@@ -295,49 +214,35 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
#if epoch > 2000: #if epoch > 2000:
# LR = 1e-5 # LR = 1e-5
vsteps = packed_tiles.shape[0]//5 # so, here get the image, remove nans and run for 100x times
hsteps = packed_tiles.shape[1]//5 packed_tiles[np.isnan(packed_tiles)] = 0.0
tmp_vals[np.isnan(tmp_vals)] = 0.0
for ind in range(hsteps*vsteps): #packed_tiles = packed_tiles[::,::]
#for ind in np.random.permutation(packed_tiles.shape[0]*packed_tiles.shape[1]): values = tmp_vals
#print("Iteration "+str(cnt)) input_patch = np.reshape(packed_tiles,(-1,101))
gt_patch = np.reshape(values[:,:,1:3],(-1,2))
st=time.time()
cnt+=1
#i = int(ind/packed_tiles.shape[1]) g_loss = np.zeros(input_patch.shape[0])
#j = ind%packed_tiles.shape[1]
i = 2 + 5*(ind//hsteps)
j = 2 + 5*(ind%hsteps)
#input_patch = tiles[i,j] for i in range(100):
input_patch = np.empty((vsteps*hsteps,packed_tiles.shape[2]))
input_patch = np.reshape(packed_tiles[i-2:i+3,j-2:j+3],(-1,101))
gt_patch = np.empty((vsteps*hsteps,2)) print(bcolors.OKBLUE+"Iteration "+str(i)+bcolors.ENDC)
gt_patch = np.reshape(values[i-2:i+3,j-2:j+3,1:3],(-1,2))
#print(input_patch) st=time.time()
#print(gt_patch)
#gt_patch[gt_patch==-256] = np.nan
#gt_patch[np.isnan(gt_patch)] = 0
input_patch[np.isnan(input_patch)] = 0
skip_iteration = False skip_iteration = False
# if nan skip run! # if nan skip run!
if np.isnan(np.sum(gt_patch)): if np.isnan(np.sum(gt_patch)):
print("GT has NaNs") print("GT has NaNs")
skip_iteration = True #skip_iteration = True
if np.isnan(np.sum(input_patch)): if np.isnan(np.sum(input_patch)):
print("Patch has NaNs") print("Patch has NaNs")
skip_iteration = True #skip_iteration = True
if skip_iteration: if skip_iteration:
#print(bcolors.WARNING+"Found NaN, skipping iteration for tile "+str(i)+","+str(j)+bcolors.ENDC) #print(bcolors.WARNING+"Found NaN, skipping iteration for tile "+str(i)+","+str(j)+bcolors.ENDC)
...@@ -350,7 +255,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)): ...@@ -350,7 +255,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
_,G_current,output,summary = sess.run([G_opt,G_loss,out,merged],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR},options=run_options,run_metadata=run_metadata) _,G_current,output,summary = sess.run([G_opt,G_loss,out,merged],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR},options=run_options,run_metadata=run_metadata)
#_,G_current,output = sess.run([G_opt,G_loss,out],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR}) #_,G_current,output = sess.run([G_opt,G_loss,out],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR})
g_loss[ind]=G_current g_loss[i]=G_current
mean_loss = np.mean(g_loss[np.where(g_loss)]) mean_loss = np.mean(g_loss[np.where(g_loss)])
if DEBUG_PLT_LOSS: if DEBUG_PLT_LOSS:
...@@ -394,7 +299,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)): ...@@ -394,7 +299,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
plt.pause(0.001) plt.pause(0.001)
else: else:
print("%d %d Loss=%.3f CurrentLoss=%.3f Time=%.3f"%(epoch,cnt,mean_loss,G_current,time.time()-st)) print("%d %d Loss=%.3f CurrentLoss=%.3f Time=%.3f"%(epoch,i,mean_loss,G_current,time.time()-st))
#train_writer.add_run_metadata(run_metadata, 'step%d' % cnt) #train_writer.add_run_metadata(run_metadata, 'step%d' % cnt)
#test_writer.add_summary(summary,cnt) #test_writer.add_summary(summary,cnt)
#train_writer.add_summary(summary, cnt) #train_writer.add_summary(summary, cnt)
......
...@@ -68,14 +68,17 @@ def lrelu(x): ...@@ -68,14 +68,17 @@ def lrelu(x):
def network(input): def network(input):
fc1 = slim.fully_connected(input,2048,activation_fn=lrelu,scope='g_fc1') fc1 = slim.fully_connected(input,512,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2') fc2 = slim.fully_connected(fc1, 2,activation_fn=lrelu,scope='g_fc2')
fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3') return fc2
fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5') #fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6') #fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
#fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
return fc6 #fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
#fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
#return fc6
sess = tf.Session() sess = tf.Session()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment