Commit a91aab50 authored by Andrey Filippov's avatar Andrey Filippov

Tested tensorflow model on the data in inside this program

parent 6e5aa3ce
...@@ -387,6 +387,7 @@ private Panel panel1, ...@@ -387,6 +387,7 @@ private Panel panel1,
public PixelMapping.InterSensor.DisparityTiles DISPARITY_TILES=null; public PixelMapping.InterSensor.DisparityTiles DISPARITY_TILES=null;
public ImagePlus DBG_IMP = null; public ImagePlus DBG_IMP = null;
public ImagePlus CORRELATE_IMP = null; public ImagePlus CORRELATE_IMP = null;
public TensorflowInferModel TENSORFLOW_INFER_MODEL = null;
public class SyncCommand{ public class SyncCommand{
public boolean isRunning= false; public boolean isRunning= false;
...@@ -4754,18 +4755,21 @@ private Panel panel1, ...@@ -4754,18 +4755,21 @@ private Panel panel1,
return; return;
/* ======================================================================== */ /* ======================================================================== */
} else if (label.equals("TF TEST")) { } else if (label.equals("TF TEST")) {
DEBUG_LEVEL=MASTER_DEBUG_LEVEL;
// link 1 (general): https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary boolean keep_empty = true;
// link 2 (example of an TF & IJ plugin): https://github.com/google/microscopeimagequality/blob/main/microscopeimagequality/data/imagej/src/main/java/MicroscopeImageFocusQualityClassifier.java if (TENSORFLOW_INFER_MODEL == null) {
TENSORFLOW_INFER_MODEL = new TensorflowInferModel(
TensorflowExamplePlugin t = new TensorflowExamplePlugin(); 324, // int tilesX,
t.run(); 242, // int tilesY,
9, // int corr_side,
4 //int num_layers
);
}
TENSORFLOW_INFER_MODEL.test_tensorflow(keep_empty);
return; return;
//JTabbedTest //JTabbedTest
// End of buttons code // End of buttons code
} }
DEBUG_LEVEL=MASTER_DEBUG_LEVEL;
} }
public String getSaveCongigPath() { public String getSaveCongigPath() {
......
...@@ -3,20 +3,17 @@ ...@@ -3,20 +3,17 @@
* SPDX-License-Identifier: GPL-3.0-or-later * SPDX-License-Identifier: GPL-3.0-or-later
*/ */
import org.tensorflow.Graph; import java.nio.FloatBuffer;
import org.tensorflow.Session; import java.nio.IntBuffer;
import org.tensorflow.Tensor; import java.util.ArrayList;
import org.tensorflow.Tensors;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle; import org.tensorflow.SavedModelBundle;
import org.tensorflow.OperationBuilder; import org.tensorflow.Tensor;
import org.tensorflow.Shape;
import org.tensorflow.Output;
import org.tensorflow.Operation;
import java.util.ArrayList; import ij.IJ;
import java.util.Collection; import ij.ImagePlus;
import java.util.List; import ij.ImageStack;
import ij.WindowManager;
/** /**
...@@ -48,223 +45,297 @@ Sessions and Tensors ...@@ -48,223 +45,297 @@ Sessions and Tensors
*/ */
public class TensorflowExamplePlugin public class TensorflowInferModel
{ {
public final static String EXPORTDIR = "/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir"; final static String TRAINED_MODEL = "trained_model"; // /home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
// tf.saved_model.tag_constants.SERVING = "serve" final static String SERVING = "serve";
public final static String SERVING = "serve"; final int tilesX, tilesY, num_tiles, num_layers;
final int corr_side, corr_side2;
// final long [] shape_corr2d;
public static void run() // final long [] shape_target_disparity;
{ // final long [] shape_tiles_stage1;
System.out.println("TensorflowExamplePlugin run"); // final long [] shape_tiles_stage2;
try { final FloatBuffer fb_corr2d;
main(); final FloatBuffer fb_target_disparity;
} catch (Exception e) { final IntBuffer fb_tiles_stage1;
// TODO Auto-generated catch block final IntBuffer fb_tiles_stage2;
e.printStackTrace(); final FloatBuffer fb_predicted;
}
} /*
long[] shape = new long[] {batch, imageSize};
/**
* From https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*/
/**
* wraps a single float in a tensor
* @param f the float to wrap
* @return a tensor containing the float
*/ */
private static Tensor<Float> toTensor(final float f, final Collection<Tensor<?>> tensorsToClose) { // public
final Tensor<Float> t = Tensors.create(f); final SavedModelBundle bundle;
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static Tensor<Float> toTensor2DFloat(final float [][] f, final Collection<Tensor<?>> tensorsToClose) {
final Tensor<Float> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static Tensor<Integer> toTensor1DInt(final int [] f, final Collection<Tensor<?>> tensorsToClose) {
final Tensor<Integer> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static void closeTensors(final Collection<Tensor<?>> ts) {
for (final Tensor<?> t : ts) {
try {
t.close();
} catch (final Exception e) {
System.err.println("Error closing Tensor.");
e.printStackTrace();
}
}
ts.clear();
}
public TensorflowInferModel(int tilesX, int tilesY, int corr_side, int num_layers)
public static void main() throws Exception { {
this.tilesX = tilesX;
final Graph smpb; this.tilesY = tilesY;
this.num_tiles = tilesX*tilesY;
// init for variable? this.num_layers = num_layers;
float [][] rv_stage1_out = new float[78408][32]; this.corr_side = corr_side;
this.corr_side2 = corr_side * corr_side;
// from infer_qcds_01.py // allocate buffers to be used for tensors
float [][] img_corr2d = new float[78408][324]; this.fb_corr2d = FloatBuffer.allocate(num_tiles * corr_side2*num_layers);
float [][] img_target = new float[78408][ 1]; this.fb_target_disparity = FloatBuffer.allocate(num_tiles );
int [] img_ntile = new int[78408]; this.fb_tiles_stage1 = IntBuffer.allocate(num_tiles );
this.fb_tiles_stage2 = IntBuffer.allocate(num_tiles );
// init ntile for testing? this.fb_predicted = FloatBuffer.allocate(num_tiles );
for(int i=0;i<img_ntile.length;i++){
img_ntile[i] = i; String abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile();
System.out.println("TensorflowInferModel model path: "+abs_model_path);
// this will load graph/data and open a session that does not need to be closed until the program is closed
//// bundle = null;
bundle = SavedModelBundle.load(abs_model_path,SERVING);
// Operation opr = bundle.graph().operation("rv_stage1_out");
} }
final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,SERVING);
final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(5);
System.out.println("OK");
try { public void run_stage1(
FloatBuffer fb_corr2d,
System.out.println("S0:"); FloatBuffer fb_target_disparity,
// read Variable info test IntBuffer fb_tiles_stage1)
Operation opr = bundle.graph().operation("rv_stage1_out"); {
System.out.println(opr.toString()); int ntiles = fb_tiles_stage1.limit(); // actual number of entries
long [] shape_corr2d = new long [] {ntiles, corr_side2*num_layers};
System.out.println("S1:"); long [] shape_target_disparity = new long [] {ntiles, 1};
long [] shape_tiles_stage1 = new long [] {ntiles};
//opr = bundle.graph().operation("rv_stageY_out");
//System.out.println(opr.toString()); final Tensor<Float> t_corr2d = Tensor.create(shape_corr2d, fb_corr2d);
final Tensor<Float> target_disparity = Tensor.create(shape_target_disparity, fb_target_disparity);
// init variable via constant? final Tensor<Integer> t_tiles_stage1 = Tensor.create(shape_tiles_stage1,fb_tiles_stage1);
//Tensor<Float> tsr = toTensor2DFloat(rv_stage1_out, tensorsToClose);
/* // this run does not output any data, but maybe it should still be captured and disposed of?
Output builder_init = bundle.graph() final Tensor<?> t_result_stage1 = bundle.session().runner()
.opBuilder("Const", "rv_stage1_out_init") .feed("ph_corr2d", t_corr2d)
.setAttr ("dtype", tsr.dataType()) .feed("ph_target_disparity", target_disparity)
.setAttr ("value", tsr) .feed("ph_ntile", t_tiles_stage1)
.build()
.output(0);
*/
//System.out.println(builder_init);
// variable
//OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out_extra_variable");
//.addInput(builder_init);
//builder2.
//bundle.graph().opBuilder("Assign", "Assign/" + builder2.op().name()).addInput(variable).addInput(value).build().output(0);
//Tensor<Float> tensorVal = tsr;
//Output oValue = bundle.graph().opBuilder("Const", "rv_stage1_out_2").setAttr("dtype", tensorVal.dataType()).setAttr("value", tensorVal).build().output(0);
//System.out.println(oValue);
//Output oValue = bundle.graph().opBuilder("Variable", "rv_stage1_out").setAttr("value", tensorVal).build().output(0);
//bundle.graph().opBuilder("Assign", "Assign/rv_stage1_out").setAttr("value", tsr).build();
System.out.println("Stage 0.1");
//bundle.session().runner().fetch("rv_stageY_out").run();
System.out.println("Stage 0.2");
bundle.session().runner().fetch("rv_stage1_out").run();
System.out.println("Stage 1");
// stage 1
bundle.session().runner()
.feed("ph_corr2d",toTensor2DFloat(img_corr2d, tensorsToClose))
.feed("ph_target_disparity",toTensor2DFloat(img_target, tensorsToClose))
.feed("ph_ntile",toTensor1DInt(img_ntile, tensorsToClose))
.fetch("Disparity_net/stage1done:0") .fetch("Disparity_net/stage1done:0")
.run() .run()
.get(0); .get(0);
System.out.println("Stage 1 DONE"); t_result_stage1. close();
t_corr2d. close();
System.out.println("Stage 2"); target_disparity. close();
t_tiles_stage1. close();
}
// stage 2 public void run_stage2(
final Tensor<?> result = bundle.session().runner() IntBuffer fb_tiles_stage2,
.feed("ph_ntile_out",toTensor1DInt(img_ntile, tensorsToClose)) FloatBuffer fb_disparity) {
int ntiles = fb_tiles_stage2.limit(); // actual number of entries
long [] shape_tiles_stage2 = new long [] {ntiles};
final Tensor<Integer> t_tiles_stage2 = Tensor.create(shape_tiles_stage2, fb_tiles_stage2);
final Tensor<?> t_result_disparity = bundle.session().runner()
.feed("ph_ntile_out",t_tiles_stage2)
.fetch("Disparity_net/stage2_out_sparse:0") .fetch("Disparity_net/stage2_out_sparse:0")
.run() .run()
.get(0); .get(0);
fb_disparity.rewind();
t_result_disparity.writeTo(fb_disparity);
System.out.println("Stage 2 DONE: "+result.shape()); t_result_disparity. close();
t_tiles_stage2. close();
}
tensorsToClose.add(result);
System.out.println("Copy result to variable"); //**************************
float [][] resultValues = (float[][]) result.copyTo(new float[78408][1]); // helper class to prepare TileProcessor task
class TfInTile{
float [] corr2d;
float target_disparity;
float gt_disparity;
float gt_strength;
int tile;
public void setFromCorr2d(
float [][] data,
int tileX,
int tileY,
int corr_side,
int tilesX
) {
int corr_side2 = corr_side*corr_side;
int layers = data.length -1;
this.corr2d = new float [layers * corr_side2];
float [] other = new float [corr_side2];
int width = tilesX * corr_side;
int tl = tileY * corr_side * width + tileX * corr_side; // index to the
for (int nl = 0; nl <= layers; nl++) {
if (nl < layers) {
for (int row = 0; row < corr_side; row++) {
System.arraycopy(data[nl], tl+width*row, corr2d, nl*corr_side2 + row*corr_side, corr_side);
}
} else {
for (int row = 0; row < corr_side; row++) {
System.arraycopy(data[layers], tl+width*row, other, row*corr_side, corr_side);
}
this.target_disparity = other[ImageDtt.ML_OTHER_TARGET];
this.gt_disparity = other[ImageDtt.ML_OTHER_GTRUTH];
this.gt_strength = other[ImageDtt.ML_OTHER_GTRUTH_STRENGTH];
}
}
if (Float.isNaN(this.target_disparity)) {
corr2d = new float[corr2d.length]; // zero them all
}
this.tile = tileY*tilesX + tileX;
System.out.println("DONE"); }
//System.arraycopy(sym_conv, 0, tile_in, 0, n2*n2)
//} catch (final IllegalStateException ise) { }
// System.out.println("Very Bad Error (VBE): "+ise); public int test_tensorflow(
// closeTensors(tensorsToClose); boolean keep_empty) {
} catch (final NumberFormatException nfe) {
//just skip unparsable lines ?! int dbgX = 162;
} finally { int dbgY = 121;
closeTensors(tensorsToClose); // int dbgT = tilesX ^ dbgY + dbgX;
int corr_side = 9;
String [] slices = {"hor-pairs","vert-pairs","diagm-pair","vert-pairs","other"};
ImagePlus imp_src = WindowManager.getCurrentImage();
if (imp_src==null){
IJ.showMessage("Error","2D Correlation image stack required");
return -1;
}
ImageStack corr_stack = imp_src.getStack();
String [] labels = corr_stack.getSliceLabels();
// for (int ii = 0; ii < labels.length; ii++) {
// System.out.println(ii+": "+labels[ii]);
// }
int tilesX= corr_stack.getWidth() / corr_side;
int tilesY= corr_stack.getHeight() / corr_side;
float [][] corr_data = new float [slices.length][];
for (int nslice = 0; nslice < slices.length; nslice++ ) {
int ns = -1;
for (int i = 0; i < labels.length; i++) {
if (slices[nslice].equals(labels[i])) {
ns = i;
break;
}
}
if (ns < 0) {
System.out.println("Slice "+slices[nslice]+" is not found in the image");
return -1;
} else {
corr_data[nslice] = (float[]) corr_stack.getPixels(ns + 1) ;
}
}
ArrayList<TfInTile> tf_tiles = new ArrayList<TfInTile>();
for (int tileY = 0; tileY < tilesY; tileY++) {
for (int tileX = 0; tileX < tilesX; tileX++) {
if ((tileY==dbgY) && (tileX==dbgX)) {
System.out.println("tileY = "+tileY+", tileX = "+tileX+" tile = "+(tileY*tilesX + tileX));
}
TfInTile tf_tile = new TfInTile();
tf_tile.setFromCorr2d(
corr_data, // float [][] data,
tileX, // int tileX,
tileY, // int tileY,
corr_side, // int corr_side,
tilesX); // int tilesX
if (keep_empty || !Float.isNaN(tf_tile.target_disparity)) {
// if (Float.isNaN(tf_tile.target_disparity)) {
// tf_tile.target_disparity = 0.0f;
// tf_tile.gt_strength = 0.0f;
// }
tf_tiles.add(tf_tile);
}
}
} }
//try (){ // sets the limit to the capacity and the position to zero.
fb_corr2d.clear();
fb_target_disparity.clear();
fb_tiles_stage1.clear();
fb_tiles_stage2.clear();
for (int i = 0; i < tf_tiles.size(); i++) {
TfInTile tf_tile = tf_tiles.get(i);
fb_corr2d.put (tf_tile.corr2d);
float td = tf_tile.target_disparity;
if (Float.isNaN(td)) td = 0.0f;
fb_target_disparity.put(i, td);
fb_tiles_stage1.put(tf_tile.tile);
fb_tiles_stage2.put(tf_tile.tile);
}
// fb_target_disparity.limit(tf_tiles.size()); // put float absolute does not movew the pointer
fb_target_disparity.position(tf_tiles.size()); // put float absolute does not movew the pointer
fb_tiles_stage1. position(tf_tiles.size()); // put float absolute does not movew the pointer
fb_tiles_stage2. position(tf_tiles.size()); // put float absolute does not movew the pointer
//sets the limit to the current position and then sets the position to zero.
fb_corr2d. flip();
fb_target_disparity.flip();
fb_tiles_stage1. flip();
fb_tiles_stage2. flip();
String [] titles = {"predicted", "target", "gt_disparity", "gt_strength", "nn_out", "nn_error","abs_err","abs_heur","clean_nn"};
fb_predicted.rewind(); // not needed for absolute get();
float [][] result = new float [titles.length][tilesX*tilesY];
/*
for (int i = 0; i < tf_tiles.size(); i++) {
TfInTile tf_tile = tf_tiles.get(i);
// result[0][i] = tf_tile.target_disparity + fb_predicted.get(i);
result[1][i] = tf_tile.target_disparity;
result[2][i] = tf_tile.gt_disparity;
result[3][i] = tf_tile.gt_strength;
}
//smpb = b.graph(); (new showDoubleFloatArrays()).showArrays(
result,
tilesX,
tilesY,
true,
"NN_disparity-pre",
titles);
*/
//Session sess = b.session();
//System.out.println(b.metaGraphDef());
//final List<String> labels = tensorFlowService.loadLabels(source, //if (corr_side > 0) return 0; // always
// MODEL_NAME, "imagenet_comp_graph_label_strings.txt"); run_stage1(
//System.out.println("Loaded graph and " + labels.size() + " labels"); fb_corr2d, // FloatBuffer fb_corr2d,
fb_target_disparity, // FloatBuffer fb_target_disparity,
fb_tiles_stage1); // IntBuffer fb_tiles_stage1);
run_stage2(
fb_tiles_stage2, // IntBuffer fb_tiles_stage2,
fb_predicted); // FloatBuffer fb_disparity)
//output = sess.runner().feed(o, t).fetch().run().get(0).copyTo()
/* if (!keep_empty) {
try ( for (int i =0; i < result[0].length; i++) {
final Session s = new Session(g); result[0][i]=Float.NaN;
@SuppressWarnings("unchecked") result[1][i]=Float.NaN;
final Tensor<Float> result = (Tensor<Float>) s.runner().feed("input", image)// result[2][i]=Float.NaN;
.fetch("output").run().get(0)
){
...
} }
*/
//}
try (Graph g = new Graph()) {
final String value = "Hello from " + TensorFlow.version();
// Construct the computation graph with a single operation, a constant
// named "MyConst" with a value "value".
try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
// The Java API doesn't yet include convenience functions for adding operations.
g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
} }
// Execute the "MyConst" operation in a Session. fb_predicted.rewind(); // not needed for absolute get();
try (Session s = new Session(g); for (int i = 0; i < tf_tiles.size(); i++) {
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks. TfInTile tf_tile = tf_tiles.get(i);
Tensor output = s.runner().fetch("MyConst").run().get(0)) { result[0][i] = tf_tile.target_disparity + fb_predicted.get(i);
System.out.println(new String(output.bytesValue(), "UTF-8")); result[1][i] = tf_tile.target_disparity;
s.close(); result[2][i] = tf_tile.gt_disparity;
output.close(); result[3][i] = tf_tile.gt_strength;
} result[4][i] = fb_predicted.get(i);
result[5][i] = (result[3][i] > 0)?(result[0][i]-result[2][i]):Float.NaN;
result[6][i] = (result[3][i] > 0)?Math.abs(result[5][i]):Float.NaN;
result[7][i] = (result[3][i] > 0)?Math.abs(result[1][i] - result[2][i]):Float.NaN;
result[8][i] = (result[3][i] > 0)?result[0][i]:Float.NaN;
} }
(new showDoubleFloatArrays()).showArrays(
result,
tilesX,
tilesY,
true,
"NN_disparity",
titles);
return 0;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment