Commit 66eda4b8 authored by Andrey Filippov's avatar Andrey Filippov

Merge branch 'gpu' of git@git.elphel.com:Elphel/imagej-elphel.git into gpu

parents c264a349 03e238e2
...@@ -12,6 +12,7 @@ import org.tensorflow.SavedModelBundle; ...@@ -12,6 +12,7 @@ import org.tensorflow.SavedModelBundle;
import org.tensorflow.OperationBuilder; import org.tensorflow.OperationBuilder;
import org.tensorflow.Shape; import org.tensorflow.Shape;
import org.tensorflow.Output; import org.tensorflow.Output;
import org.tensorflow.Operation;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
...@@ -51,7 +52,9 @@ public class TensorflowExamplePlugin ...@@ -51,7 +52,9 @@ public class TensorflowExamplePlugin
{ {
public final static String EXPORTDIR = "/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir"; public final static String EXPORTDIR = "/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
public final static String PB_TAG = "model_pb"; // tf.saved_model.tag_constants.SERVING = "serve"
public final static String SERVING = "serve";
public static void run() public static void run()
{ {
...@@ -115,57 +118,60 @@ public class TensorflowExamplePlugin ...@@ -115,57 +118,60 @@ public class TensorflowExamplePlugin
final Graph smpb; final Graph smpb;
// init for variable?
float [][] rv_stage1_out = new float[78408][32]; float [][] rv_stage1_out = new float[78408][32];
// from: infer_qcds_01.py // from infer_qcds_01.py
float [][] img_corr2d = new float[78408][324]; float [][] img_corr2d = new float[78408][324];
float [][] img_target = new float[78408][ 1]; float [][] img_target = new float[78408][ 1];
int [] img_ntile = new int[78408]; int [] img_ntile = new int[78408];
// init ntile // init ntile for testing?
for(int i=0;i<img_ntile.length;i++){ for(int i=0;i<img_ntile.length;i++){
img_ntile[i] = i; img_ntile[i] = i;
} }
/*
* for feed:
* "ph_corr2d": img_corr2d
* "ph_target_disparity": img_target
* "ph_ntile": img_ntile
*
* so it will look like:
*
* https://divis.io/2018/01/enterprise-tensorflow-code-examples/ ->
* https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*
* sess.runner()
* .feed("ph_corr2d",img_corr2d)
* .feed("ph_target_disparity",img_target)
* .feed("ph_ntile",img_ntile)
* .fetch("Disparity_net/stage1done:0")
* .run()
* .get(0)
*/
final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,PB_TAG); final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,SERVING);
final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(5); final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(5);
System.out.println("OK"); System.out.println("OK");
try { try {
// init variable via constant System.out.println("S0:");
Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose); // read Variable info test
Output builder_init = bundle.graph().opBuilder("Const", "rv_stage1_out_init").setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0); Operation opr = bundle.graph().operation("rv_stage1_out");
System.out.println(opr.toString());
System.out.println("S1:");
// init variable via constant?
Tensor<Float> tsr = toTensor2DFloat(rv_stage1_out, tensorsToClose);
Output builder_init = bundle.graph()
.opBuilder("Const", "rv_stage1_out_init")
.setAttr ("dtype", tsr.dataType())
.setAttr ("value", tsr)
.build()
.output(0);
System.out.println(builder_init);
// variable // variable
OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out"); OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out_extra_variable");
builder2.addInput(builder_init); //.addInput(builder_init);
//Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose); //builder2.
//builder.setAttr("dtype", t.dataType()).setAttr("shape",t.shape()).build().output(0); //bundle.graph().opBuilder("Assign", "Assign/" + builder2.op().name()).addInput(variable).addInput(value).build().output(0);
//Tensor<Float> tensorVal = tsr;
//Output oValue = bundle.graph().opBuilder("Const", "rv_stage1_out_2").setAttr("dtype", tensorVal.dataType()).setAttr("value", tensorVal).build().output(0);
//System.out.println(oValue);
//Output oValue = bundle.graph().opBuilder("Variable", "rv_stage1_out").setAttr("value", tensorVal).build().output(0);
//bundle.graph().opBuilder("Assign", "Assign/rv_stage1_out").setAttr("value", tsr).build();
System.out.println("DONE");
// stage 1 // stage 1
bundle.session().runner() bundle.session().runner()
.feed("ph_corr2d",toTensor2DFloat(img_corr2d, tensorsToClose)) .feed("ph_corr2d",toTensor2DFloat(img_corr2d, tensorsToClose))
...@@ -186,11 +192,12 @@ public class TensorflowExamplePlugin ...@@ -186,11 +192,12 @@ public class TensorflowExamplePlugin
float [] resultValues = (float[]) result.copyTo(new float[78408]); float [] resultValues = (float[]) result.copyTo(new float[78408]);
System.out.println("DONE"); System.out.println("DONE");
} catch (final IllegalStateException ise) { //} catch (final IllegalStateException ise) {
System.out.println("Very Bad Error (VBE): "+ise); // System.out.println("Very Bad Error (VBE): "+ise);
closeTensors(tensorsToClose); // closeTensors(tensorsToClose);
} catch (final NumberFormatException nfe) { } catch (final NumberFormatException nfe) {
//just skip unparsable lines ?! //just skip unparsable lines ?!
} finally { } finally {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment