Commit 45a473f0 authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

testing

parent 9c8f781e
......@@ -65,118 +65,43 @@ import tensorflow.contrib.slim as slim
print("TensorCrawl imported")
print_time()
IS_TEST = False
# BEGIN IF IS_TEST
if not IS_TEST:
tlist = glob.glob(src+"/*.tiff")
print("\n".join(tlist))
print("Found "+str(len(tlist))+" preprocessed tiff files:")
print_time()
pass
''' WARNING, assuming:
- timestamps and part of names match
- layer order and names are identical
'''
# open the first one to get dimensions and other info
tiff = ijt.imagej_tiff(tlist[0])
#del tlist[0]
# shape as tiles? make a copy or make writeable
# (242, 324, 9, 9, 5)
# get labels
labels = tiff.labels.copy()
labels.remove(VALUES_LAYER_NAME)
print("Image data layers: "+str(labels))
print("Layers of interest: "+str(LAYERS_OF_INTEREST))
print("Values layer: "+str([VALUES_LAYER_NAME]))
# create copies
tiles = np.copy(tiff.getstack(labels,shape_as_tiles=True))
values = np.copy(tiff.getvalues(label=VALUES_LAYER_NAME))
#gt = values[:,:,1:3]
print("Mixed tiled input data shape: "+str(tiles.shape))
#print_time()
# now generate a layer of indices to get other tiles
indices = np.random.random_integers(0,len(tlist)-1,size=(tiles.shape[0],tiles.shape[1]))
#print(indices.shape)
# counts tiles from a certain tiff
shuffle_counter = np.zeros(len(tlist),np.int32)
shuffle_counter[0] = tiles.shape[0]*tiles.shape[1]
for i in range(1,len(tlist)):
#print(tlist[i])
tmp_tiff = ijt.imagej_tiff(tlist[i])
tmp_tiles = tmp_tiff.getstack(labels,shape_as_tiles=True)
tmp_vals = tmp_tiff.getvalues(label=VALUES_LAYER_NAME)
#tmp_tiles =
#tiles[indices==i] = tmp_tiff[indices==i]
# straight and clear
# can do quicker?
for y,x in itertools.product(range(indices.shape[0]),range(indices.shape[1])):
if indices[y,x]==i:
tiles[y,x] = tmp_tiles[y,x]
values[y,x] = tmp_vals[y,x]
shuffle_counter[i] +=1
# check shuffle counter
for i in range(1,len(shuffle_counter)):
shuffle_counter[0] -= shuffle_counter[i]
print("Tiff files parts count in the mixed input = "+str(shuffle_counter))
print_time()
# test later
# might not need it because going to loop through anyway
packed_tiles = pile.pack(tiles)
packed_tiles = np.dstack((packed_tiles,values[:,:,0]))
print("Packed (81x4 -> 1x(25*4+1)) tiled input shape: "+str(packed_tiles.shape))
print("Values shape "+str(values.shape))
print_time()
tlist = glob.glob(src+"/*.tiff")
print("Found "+str(len(tlist))+" preprocessed tiff files:")
print("\n".join(tlist))
print_time()
# END IF IS_TEST
tiff = ijt.imagej_tiff(tlist[0])
# get labels
labels = tiff.labels.copy()
labels.remove(VALUES_LAYER_NAME)
#print("CHECKPOINTE")
print("Image data layers: "+str(labels))
print("Layers of interest: "+str(LAYERS_OF_INTEREST))
print("Values layer: "+str([VALUES_LAYER_NAME]))
#for i in range(tiles.shape[0]):
# for j in range(tiles.shape[1]):
# nn_input = pile.get_tile_with_neighbors(tiles,i,j,RADIUS)
# print("tile: "+str(i)+", "+str(j)+": shape = "+str(nn_input.shape))
#print_time()
result_dir = './result/'
checkpoint_dir = './result/'
save_freq = 500
def lrelu(x):
return tf.maximum(x*0.2,x)
#return tf.maximum(x*0.2,x)
return tf.nn.relu(x)
def network(input):
fc1 = slim.fully_connected(input,2048,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
fc1 = slim.fully_connected(input,512,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 2,activation_fn=lrelu,scope='g_fc2')
return fc2
return fc6
#fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
#fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
#fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
#fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
#fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
#return fc6
sess = tf.Session()
......@@ -241,9 +166,6 @@ lastepoch = 0
for folder in allfolders:
lastepoch = np.maximum(lastepoch, int(folder[-4:]))
g_loss = np.zeros((packed_tiles.shape[0]*packed_tiles.shape[1],1))
recorded_loss = []
recorded_mean_loss = []
......@@ -253,7 +175,7 @@ recorded_gt_c = []
recorded_pr_d = []
recorded_pr_c = []
LR = 1e-5
LR = 1e-3
print(bcolors.HEADER+"Last Epoch = "+str(lastepoch)+bcolors.ENDC)
......@@ -274,12 +196,9 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
print(bcolors.HEADER+"Epoch #"+str(epoch)+bcolors.ENDC)
#for epoch in range(lastepoch,4001):
if os.path.isdir("result/%04d"%epoch):
continue
cnt=0
tlist_index = epoch - lastepoch
print(bcolors.OKGREEN+"Processing "+tlist[tlist_index]+bcolors.ENDC)
......@@ -295,49 +214,35 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
#if epoch > 2000:
# LR = 1e-5
vsteps = packed_tiles.shape[0]//5
hsteps = packed_tiles.shape[1]//5
for ind in range(hsteps*vsteps):
#for ind in np.random.permutation(packed_tiles.shape[0]*packed_tiles.shape[1]):
#print("Iteration "+str(cnt))
# so, here get the image, remove nans and run for 100x times
packed_tiles[np.isnan(packed_tiles)] = 0.0
tmp_vals[np.isnan(tmp_vals)] = 0.0
st=time.time()
cnt+=1
#i = int(ind/packed_tiles.shape[1])
#j = ind%packed_tiles.shape[1]
#packed_tiles = packed_tiles[::,::]
values = tmp_vals
i = 2 + 5*(ind//hsteps)
j = 2 + 5*(ind%hsteps)
input_patch = np.reshape(packed_tiles,(-1,101))
gt_patch = np.reshape(values[:,:,1:3],(-1,2))
#input_patch = tiles[i,j]
input_patch = np.empty((vsteps*hsteps,packed_tiles.shape[2]))
input_patch = np.reshape(packed_tiles[i-2:i+3,j-2:j+3],(-1,101))
g_loss = np.zeros(input_patch.shape[0])
gt_patch = np.empty((vsteps*hsteps,2))
gt_patch = np.reshape(values[i-2:i+3,j-2:j+3,1:3],(-1,2))
#print(input_patch)
#print(gt_patch)
for i in range(100):
#gt_patch[gt_patch==-256] = np.nan
#gt_patch[np.isnan(gt_patch)] = 0
print(bcolors.OKBLUE+"Iteration "+str(i)+bcolors.ENDC)
input_patch[np.isnan(input_patch)] = 0
st=time.time()
skip_iteration = False
# if nan skip run!
if np.isnan(np.sum(gt_patch)):
print("GT has NaNs")
skip_iteration = True
#skip_iteration = True
if np.isnan(np.sum(input_patch)):
print("Patch has NaNs")
skip_iteration = True
#skip_iteration = True
if skip_iteration:
#print(bcolors.WARNING+"Found NaN, skipping iteration for tile "+str(i)+","+str(j)+bcolors.ENDC)
......@@ -350,7 +255,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
_,G_current,output,summary = sess.run([G_opt,G_loss,out,merged],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR},options=run_options,run_metadata=run_metadata)
#_,G_current,output = sess.run([G_opt,G_loss,out],feed_dict={in_tile:input_patch,gt:gt_patch,lr:LR})
g_loss[ind]=G_current
g_loss[i]=G_current
mean_loss = np.mean(g_loss[np.where(g_loss)])
if DEBUG_PLT_LOSS:
......@@ -394,7 +299,7 @@ for epoch in range(lastepoch,lastepoch+len(tlist)):
plt.pause(0.001)
else:
print("%d %d Loss=%.3f CurrentLoss=%.3f Time=%.3f"%(epoch,cnt,mean_loss,G_current,time.time()-st))
print("%d %d Loss=%.3f CurrentLoss=%.3f Time=%.3f"%(epoch,i,mean_loss,G_current,time.time()-st))
#train_writer.add_run_metadata(run_metadata, 'step%d' % cnt)
#test_writer.add_summary(summary,cnt)
#train_writer.add_summary(summary, cnt)
......
......@@ -68,14 +68,17 @@ def lrelu(x):
def network(input):
fc1 = slim.fully_connected(input,2048,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
return fc6
fc1 = slim.fully_connected(input,512,activation_fn=lrelu,scope='g_fc1')
fc2 = slim.fully_connected(fc1, 2,activation_fn=lrelu,scope='g_fc2')
return fc2
#fc2 = slim.fully_connected(fc1, 1024,activation_fn=lrelu,scope='g_fc2')
#fc3 = slim.fully_connected(fc2, 512,activation_fn=lrelu,scope='g_fc3')
#fc4 = slim.fully_connected(fc3, 8,activation_fn=lrelu,scope='g_fc4')
#fc5 = slim.fully_connected(fc4, 4,activation_fn=lrelu,scope='g_fc5')
#fc6 = slim.fully_connected(fc5, 2,activation_fn=lrelu,scope='g_fc6')
#return fc6
sess = tf.Session()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment