#!/usr/bin/env python3
__copyright__ = "Copyright 2018, Elphel, Inc."
__license__   = "GPL-3.0+"
__email__     = "andrey@elphel.com"

#from numpy import float64
#import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
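# NOTE: tf.contrib.slim exists only in TensorFlow 1.x; this module assumes
# TF 1.x graph mode throughout.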

NN_LAYOUTS = {0:[0,   0,   0,   32,  20,  16],
              1:[0,   0,   0,  256, 128,  64],
              2:[0, 128,  32,   32,  32,  16],
              3:[0,   0,  40,   32,  20,  16],
              4:[0,   0,   0,    0,  16,  16],
              5:[0,   0,  64,   32,  32,  16],
              6:[0,   0,  32,   16,  16,  16],
              7:[0,   0,  64,   16,  16,  16],
              8:[0,   0,   0,   64,  20,  16],
              9:[0,   0, 256,   64,  32,  16],
             10:[0, 256, 128,   64,  32,  16],
             11:[0,   0,   0,    0,  64,  32],
             12:[0,   0, 256,  128,  64,  32],
             13:[0,   0,   0,  256, 128,  32],
              }
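
# Each NN_LAYOUTS entry lists the output width of up to six fully-connected
# stages; a 0 means that stage is skipped.  A minimal illustration (key 0 is
# taken from the table above):
#
#   layout = NN_LAYOUTS[0]            # [0, 0, 0, 32, 20, 16]
#   widths = [n for n in layout if n] # -> [32, 20, 16]: three FC layers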

def lrelu(x):
    return tf.maximum(0.2 * x, x) # leaky ReLU with negative slope 0.2
#    return tf.nn.relu(x)

def sym_inputs8(inp, cluster_radius = 2):
    """
    Get an input vector [?:4*9*9+1] (the last element being target_disparity)
    and reorder it for horizontal flip, vertical flip and transpose
    (8 variants; mode bits: +1 - horizontal, +2 - vertical, +4 - transpose).
    Return a list of 8 tensors of the same length, each reordered.
    """
    tile_side = 2 * cluster_radius + 1
    with tf.name_scope("sym_inputs8"):
        td =           inp[:,-1:] # target disparity column, shape [?, 1]
        inp_corr =     tf.reshape(inp[:,:-1],[-1,4,tile_side,tile_side], name = "inp_corr")
        # Build the 7 non-identity variants; channel order and signs are
        # adjusted so each flip/transpose remains a valid correlation layout.
        inp_corr_h =   tf.stack([-inp_corr  [:,0,:,-1::-1], inp_corr  [:,1,:,-1::-1], -inp_corr  [:,3,:,-1::-1], -inp_corr  [:,2,:,-1::-1]], axis=1, name = "inp_corr_h")
        inp_corr_v =   tf.stack([ inp_corr  [:,0,-1::-1,:],-inp_corr  [:,1,-1::-1,:],  inp_corr  [:,3,-1::-1,:],  inp_corr  [:,2,-1::-1,:]], axis=1, name = "inp_corr_v")
        inp_corr_hv =  tf.stack([ inp_corr_h[:,0,-1::-1,:],-inp_corr_h[:,1,-1::-1,:],  inp_corr_h[:,3,-1::-1,:],  inp_corr_h[:,2,-1::-1,:]], axis=1, name = "inp_corr_hv")
        inp_corr_t =   tf.stack([tf.transpose(inp_corr   [:,1], perm=[0,2,1]),
                                 tf.transpose(inp_corr   [:,0], perm=[0,2,1]),
                                 tf.transpose(inp_corr   [:,2], perm=[0,2,1]),
                                -tf.transpose(inp_corr   [:,3], perm=[0,2,1])], axis=1, name = "inp_corr_t")
        inp_corr_ht =  tf.stack([tf.transpose(inp_corr_h [:,1], perm=[0,2,1]),
                                 tf.transpose(inp_corr_h [:,0], perm=[0,2,1]),
                                 tf.transpose(inp_corr_h [:,2], perm=[0,2,1]),
                                -tf.transpose(inp_corr_h [:,3], perm=[0,2,1])], axis=1, name = "inp_corr_ht")
        inp_corr_vt =  tf.stack([tf.transpose(inp_corr_v [:,1], perm=[0,2,1]),
                                 tf.transpose(inp_corr_v [:,0], perm=[0,2,1]),
                                 tf.transpose(inp_corr_v [:,2], perm=[0,2,1]),
                                -tf.transpose(inp_corr_v [:,3], perm=[0,2,1])], axis=1, name = "inp_corr_vt")
        inp_corr_hvt = tf.stack([tf.transpose(inp_corr_hv[:,1], perm=[0,2,1]),
                                 tf.transpose(inp_corr_hv[:,0], perm=[0,2,1]),
                                 tf.transpose(inp_corr_hv[:,2], perm=[0,2,1]),
                                -tf.transpose(inp_corr_hv[:,3], perm=[0,2,1])], axis=1, name = "inp_corr_hvt")
        cl = 4 * tile_side * tile_side # flattened correlation length per sample
        return [tf.concat([tf.reshape(inp_corr,    [-1,cl]),td], axis=1,name = "out_corr"),
                tf.concat([tf.reshape(inp_corr_h,  [-1,cl]),td], axis=1,name = "out_corr_h"),
                tf.concat([tf.reshape(inp_corr_v,  [-1,cl]),td], axis=1,name = "out_corr_v"),
                tf.concat([tf.reshape(inp_corr_hv, [-1,cl]),td], axis=1,name = "out_corr_hv"),
                tf.concat([tf.reshape(inp_corr_t,  [-1,cl]),td], axis=1,name = "out_corr_t"),
                tf.concat([tf.reshape(inp_corr_ht, [-1,cl]),td], axis=1,name = "out_corr_ht"),
                tf.concat([tf.reshape(inp_corr_vt, [-1,cl]),td], axis=1,name = "out_corr_vt"),
                tf.concat([tf.reshape(inp_corr_hvt,[-1,cl]),td], axis=1,name = "out_corr_hvt")]

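# A minimal usage sketch for sym_inputs8 (assumes TF 1.x placeholders and the
# default cluster_radius = 2, so each sample is 4*9*9+1 = 325 values):
#
#   inp = tf.placeholder(tf.float32, [None, 4 * 9 * 9 + 1])
#   variants = sym_inputs8(inp) # list of 8 tensors, each [?, 325]
#   # variants[0] is the identity ordering; the other 7 are the flipped /
#   # transposed reorderings described in the docstring.
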
def network_sub(input_tensor,
                input_global,  # add to all layers (except the first) if not None
                layout,
                reuse,
                sym8 = False,
                cluster_radius = 2):
    fc = []
    inp_weights = []
    for i, num_outs in enumerate (layout):
        if num_outs:
            if fc:
                if input_global is None:
                    inp = fc[-1]
                else:
                    inp = tf.concat([fc[-1], input_global], axis = 1)
                fc.append(slim.fully_connected(inp,    num_outs, activation_fn=lrelu, scope='g_fc_sub'+str(i), reuse = reuse))
            else:
                inp = input_tensor
                if sym8:
                    inp8 = sym_inputs8(inp, cluster_radius)
                    num_non_sum = num_outs %  len(inp8) # if number of first layer outputs is not multiple of 8
                    num_sym8 =    num_outs // len(inp8) # number of symmetrical groups
                    fc_sym = []
                    for j in range (len(inp8)): # ==8
                        reuse_this = reuse | (j > 0)
                        scp = 'g_fc_sub'+str(i)
                        fc_sym.append(slim.fully_connected(inp8[j],    num_sym8, activation_fn=lrelu, scope= scp,     reuse = reuse_this))
                        if not reuse_this:
                            with tf.variable_scope(scp, reuse = True):
                                inp_weights.append(tf.get_variable('weights')) # collect first-layer weights
                    if num_non_sum > 0:
                        reuse_this = reuse
                        scp = 'g_fc_sub'+str(i)+"r"
                        fc_sym.append(slim.fully_connected(inp,     num_non_sum, activation_fn=lrelu, scope=scp, reuse = reuse_this))    
                        if not reuse_this:
                            with tf.variable_scope(scp, reuse = True):
                                inp_weights.append(tf.get_variable('weights')) # collect first-layer weights
                    fc.append(tf.concat(fc_sym, 1, name='sym_input_layer'))
                else:
                    scp = 'g_fc_sub'+str(i)
                    fc.append(slim.fully_connected(inp,    num_outs, activation_fn=lrelu, scope= scp, reuse = reuse))
                    if not reuse:
                        with tf.variable_scope(scp, reuse = True):
                            inp_weights.append(tf.get_variable('weights')) # collect first-layer weights
    return fc[-1], inp_weights

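# A hedged sketch of a single "leg" built with network_sub (the placeholder
# and the layout choice are illustrative, not fixed by this module):
#
#   tile = tf.placeholder(tf.float32, [None, 4 * 9 * 9 + 1])
#   out, w = network_sub(tile, input_global = None,
#                        layout = NN_LAYOUTS[0], reuse = False, sym8 = True)
#   # out: [?, 16] leg output; w: first-layer weight variables (with sym8 the
#   # 'g_fc_sub3' scope is shared across the 8 symmetric input variants).
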
def network_inter(input_tensor,
                  input_global,  # add to all layers (except the first) if not None
                  layout,
                  reuse=False,
                  use_confidence=False):
    fc = []
    for i, num_outs in enumerate (layout):
        if num_outs:
            if fc:
                if input_global is None:
                    inp = fc[-1]
                else:
                    inp = tf.concat([fc[-1], input_global], axis = 1)
            else:
                inp = input_tensor
            fc.append(slim.fully_connected(inp,    num_outs, activation_fn=lrelu, scope='g_fc_inter'+str(i), reuse = reuse))
    if use_confidence:
        fc_out  = slim.fully_connected(fc[-1],     2, activation_fn=lrelu, scope='g_fc_inter_out', reuse = reuse)
    else:
        fc_out  = slim.fully_connected(fc[-1],     1, activation_fn=None, scope='g_fc_inter_out', reuse = reuse)
        # If using residual disparity: either split the last layer into 2 outputs,
        # or remove its activation and add a rectifier to the confidence output only.
    return fc_out

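# A hedged sketch of the inter-tile head (widths are illustrative; 25 legs of
# 16 outputs each would give a 400-wide input):
#
#   inter = tf.placeholder(tf.float32, [None, 25 * 16])
#   disp = network_inter(inter, input_global = None, layout = NN_LAYOUTS[4])
#   # disp: [?, 1] disparity, or [?, 2] when use_confidence = True
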
def networks_siam(input_tensor, # shape [?, <tiles>, 325]: now [?,9,325] -> [?,25,325]
                  input_global, # add to all layers (except the first) if not None
                  layout1,
                  layout2,
                  inter_convergence,
                  sym8 =        False,
                  only_tile =   None, # debugging only: feed data only from the center sub-network
                  partials =    None,
                  use_confidence=False,
                  cluster_radius = 2):
    center_index = (input_tensor.shape[1] - 1) // 2 
    with tf.name_scope("Siam_net"):
        inp_weights = []
        num_legs =  input_tensor.shape[1] # == 25
        if partials is None:
            partials = [[True] * num_legs]
        inter_lists = [[] for _ in partials]
        reuse = False
        for i in range (num_legs):
            if ((only_tile is None) or (i == only_tile)) and any([p[i] for p in partials]) :
                if input_global is None:
                    ig = None
                else:
                    ig = input_global[:,i,:]
                ns, ns_weights = network_sub(input_tensor[:,i,:],
                                             ig,
                                             layout = layout1,
                                             reuse = reuse,
                                             sym8 = sym8,
                                             cluster_radius = cluster_radius)
                for n, partial in enumerate(partials):
                    if partial[i]:
                        inter_lists[n].append(ns)
                    else:
                        inter_lists[n].append(tf.zeros_like(ns))
                inp_weights += ns_weights
                reuse = True
        outs = []         
        for n, _ in enumerate(partials):
            if input_global is None:
                ig = None
            else:
                ig = input_global[:,center_index,:]

            outs.append(network_inter (input_tensor = tf.concat(inter_lists[n],
                                                 axis=1,
                                                 name='inter_tensor'+str(n)),
                                       input_global =   [None, ig][inter_convergence], # optionally feed the center-tile convergence values
                                       layout =         layout2,
                                       reuse =          (n > 0),
                                       use_confidence = use_confidence))
        return  outs,  inp_weights
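

if __name__ == "__main__":
    # Minimal smoke test - a sketch, not part of the original module.  Assumes
    # TF 1.x; shapes (25 tiles x 325 values each) follow the comments above,
    # and the NN_LAYOUTS picks are illustrative.
    sample = tf.placeholder(tf.float32, [None, 25, 4 * 9 * 9 + 1], name = "sample_in")
    outs, weights = networks_siam(sample,
                                  input_global      = None,
                                  layout1           = NN_LAYOUTS[0],
                                  layout2           = NN_LAYOUTS[4],
                                  inter_convergence = False,
                                  sym8              = True)
    print([o.shape.as_list() for o in outs]) # expected: [[None, 1]]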