#!/usr/bin/env python3
# numpy_visualize_weights.py

import numpy as np
import matplotlib.pyplot as plt
import math

# input: np.array(a,b) - 1 channel
# output: np.array(a,b,3) - 3 color channels
def coldmap(img,zero_span=0.2):
  """Map a single-channel image to a blue/green/red "cold map".

  Per pixel, with peak = np.nanmax(img):
    * value < 0  -> blue, intensity -value
    * value > 0  -> red, intensity value
    * value == 0 -> green, intensity peak
    * abs(value) < zero_span/2 -> green, intensity peak/2; this rule is
      applied last, so it also wins over the three rules above
  Pixels matching none of the rules (e.g. NaN) keep their value in all
  three channels.

  img: np.array(a,b), 1 channel; it is not modified
  zero_span: width of the band around zero that is painted green
  returns: np.array(a,b,3), channel order red, green, blue
  """
  # three stacked copies of img; the channel views below write into this copy
  rgb = np.dstack([img,img,img])

  peak = np.nanmax(img)

  # img itself is never written, so every mask can be taken up front
  negative = img<0
  positive = img>0
  zero = img==0
  near_zero = abs(img)<zero_span/2

  red = rgb[...,0]
  green = rgb[...,1]
  blue = rgb[...,2]

  # red already holds img, so positive pixels only need the near-zero wipe
  red[negative|zero|near_zero] = 0

  # green is lit only for zeros and their vicinity; the vicinity goes last
  green[negative|positive] = 0
  green[zero] = peak
  green[near_zero] = peak/2

  # blue carries the magnitude of negative pixels, unless they are near zero
  blue[negative] = -blue[negative]
  blue[positive|zero|near_zero] = 0

  return rgb

# The input has to be pre-transposed by the caller:
# its layout is just supposed to match `shape`.
def tiles(img,shape,tiles_per_line=1,borders=True):
  """Arrange a batch of flattened tile groups into one mosaic image.

  Every img[i] is one group of shape[0] x shape[1] tiles, each tile being
  shape[2] x shape[3] pixels. Along axis 1 the pixels are ordered
  (group row, group column, tile row, tile column). Groups are placed left
  to right, tiles_per_line per mosaic row. A last row that is not full is
  padded with cells filled with np.nanmin(img), so no group is dropped.

  Args:
    img: array of shape (n, pixels, channels) with
      pixels >= shape[0]*shape[1]*shape[2]*shape[3]. Surplus pixels are
      not drawn, but still take part in the min/max used for the colors.
    shape: (group_h, group_w, tile_h, tile_w), every entry >= 1.
    tiles_per_line: number of groups per mosaic row, >= 1.
    borders: if true, every tile gets a 1-pixel border on its right and
      bottom side: np.nanmax(img) on the right/bottom edge of a group,
      np.nanmin(img) between the tiles inside a group.

  Returns:
    A new array with the dtype of img and shape
    (rows*group_h*cell_h, tiles_per_line*group_w*cell_w, channels), where
    cell_h/cell_w are the tile size plus 1 if borders else plus 0, and
    rows = ceil(n/tiles_per_line).

  Raises:
    ValueError: if img is empty, a shape entry or tiles_per_line is < 1,
      or img has fewer pixels per group than shape requires.
  """
  img = np.asarray(img)

  # both reductions reject an empty img with a ValueError
  img_min = np.nanmin(img)
  img_max = np.nanmax(img)

  group_h, group_w = shape[0], shape[1]
  tile_h, tile_w = shape[2], shape[3]
  if min(group_h,group_w,tile_h,tile_w)<1:
    raise ValueError("shape entries must be >= 1, got "+str(tuple(shape)))

  tpl = int(tiles_per_line)
  if tpl<1:
    raise ValueError("tiles_per_line must be >= 1, got "+str(tiles_per_line))

  count = img.shape[0]
  channels = img.shape[2]
  used = group_h*group_w*tile_h*tile_w
  if img.shape[1]<used:
    raise ValueError("img has "+str(img.shape[1])+" pixels per group, "
                     "shape needs "+str(used))

  # every tile sits in a cell that leaves room for its right/bottom border
  pad = 1 if borders else 0
  cell_h = tile_h+pad
  cell_w = tile_w+pad

  # ceiling division: a partially filled last row still gets drawn
  rows = -(-count//tpl)

  # axes: group index, group row, cell row, group column, cell column, channel
  # the img_min fill doubles as inner border color and as padding
  canvas = np.full((rows*tpl,group_h,cell_h,group_w,cell_w,channels),
                   img_min,dtype=img.dtype)

  pixels = img[:,:used].reshape(count,group_h,group_w,tile_h,tile_w,channels)
  canvas[:count,:,:tile_h,:,:tile_w] = pixels.transpose(0,1,3,2,4,5)

  if borders:
    # outer border: bottom line of the last tile row of every group ...
    canvas[:count,-1,-1,:,:] = img_max
    # ... and right line of the last tile column of every group
    canvas[:count,:,:,-1,-1] = img_max

  height = group_h*cell_h
  width = group_w*cell_w

  # merge the cell axes into group images, then lay the groups out on a grid
  grid = canvas.reshape(rows,tpl,height,width,channels)
  out = grid.transpose(0,2,1,3,4).reshape(rows*height,tpl*width,channels)
  return out

if __name__=="__main__":

  # demo: 32 random groups of 3x3 tiles, 4x4 pixels each (3*3*4*4 = 144),
  # colored with coldmap() and laid out 8 groups per line
  image = np.random.rand(32,144)
  rgb_img_0 = tiles(coldmap(image),(3,3,4,4),tiles_per_line=8,borders=True)

  fig = plt.figure()
  fig.suptitle("Test")
  plt.imshow(rgb_img_0)
  plt.show()