Commit 66eda4b8 authored by Andrey Filippov's avatar Andrey Filippov

Merge branch 'gpu' of git@git.elphel.com:Elphel/imagej-elphel.git into gpu

parents c264a349 03e238e2
......@@ -12,6 +12,7 @@ import org.tensorflow.SavedModelBundle;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Shape;
import org.tensorflow.Output;
import org.tensorflow.Operation;
import java.util.ArrayList;
import java.util.Collection;
......@@ -51,7 +52,9 @@ public class TensorflowExamplePlugin
{
public final static String EXPORTDIR = "/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
public final static String PB_TAG = "model_pb";
// tf.saved_model.tag_constants.SERVING = "serve"
public final static String SERVING = "serve";
public static void run()
{
......@@ -115,39 +118,20 @@ public class TensorflowExamplePlugin
final Graph smpb;
// init for variable?
float [][] rv_stage1_out = new float[78408][32];
// from: infer_qcds_01.py
// from infer_qcds_01.py
float [][] img_corr2d = new float[78408][324];
float [][] img_target = new float[78408][ 1];
int [] img_ntile = new int[78408];
// init ntile
// init ntile for testing?
for(int i=0;i<img_ntile.length;i++){
img_ntile[i] = i;
}
/*
* for feed:
* "ph_corr2d": img_corr2d
* "ph_target_disparity": img_target
* "ph_ntile": img_ntile
*
* so it will look like:
*
* https://divis.io/2018/01/enterprise-tensorflow-code-examples/ ->
* https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*
* sess.runner()
* .feed("ph_corr2d",img_corr2d)
* .feed("ph_target_disparity",img_target)
* .feed("ph_ntile",img_ntile)
* .fetch("Disparity_net/stage1done:0")
* .run()
* .get(0)
*/
final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,PB_TAG);
final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,SERVING);
final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(5);
......@@ -155,16 +139,38 @@ public class TensorflowExamplePlugin
try {
// init variable via constant
Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
Output builder_init = bundle.graph().opBuilder("Const", "rv_stage1_out_init").setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0);
System.out.println("S0:");
// read Variable info test
Operation opr = bundle.graph().operation("rv_stage1_out");
System.out.println(opr.toString());
System.out.println("S1:");
// init variable via constant?
Tensor<Float> tsr = toTensor2DFloat(rv_stage1_out, tensorsToClose);
Output builder_init = bundle.graph()
.opBuilder("Const", "rv_stage1_out_init")
.setAttr ("dtype", tsr.dataType())
.setAttr ("value", tsr)
.build()
.output(0);
System.out.println(builder_init);
// variable
OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out");
builder2.addInput(builder_init);
OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out_extra_variable");
//.addInput(builder_init);
//builder2.
//bundle.graph().opBuilder("Assign", "Assign/" + builder2.op().name()).addInput(variable).addInput(value).build().output(0);
//Tensor<Float> tensorVal = tsr;
//Output oValue = bundle.graph().opBuilder("Const", "rv_stage1_out_2").setAttr("dtype", tensorVal.dataType()).setAttr("value", tensorVal).build().output(0);
//System.out.println(oValue);
//Output oValue = bundle.graph().opBuilder("Variable", "rv_stage1_out").setAttr("value", tensorVal).build().output(0);
//bundle.graph().opBuilder("Assign", "Assign/rv_stage1_out").setAttr("value", tsr).build();
//Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
//builder.setAttr("dtype", t.dataType()).setAttr("shape",t.shape()).build().output(0);
System.out.println("DONE");
// stage 1
bundle.session().runner()
......@@ -186,11 +192,12 @@ public class TensorflowExamplePlugin
float [] resultValues = (float[]) result.copyTo(new float[78408]);
System.out.println("DONE");
} catch (final IllegalStateException ise) {
System.out.println("Very Bad Error (VBE): "+ise);
closeTensors(tensorsToClose);
//} catch (final IllegalStateException ise) {
// System.out.println("Very Bad Error (VBE): "+ise);
// closeTensors(tensorsToClose);
} catch (final NumberFormatException nfe) {
//just skip unparsable lines ?!
} finally {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment