Commit b7a74a73 authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

java tf plugin

parent 658a3f2c
......@@ -6,8 +6,16 @@
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Shape;
import org.tensorflow.Output;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
......@@ -56,23 +64,152 @@ public class TensorflowExamplePlugin
}
}
/**
* From https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*/
/**
* wraps a single float in a tensor
* @param f the float to wrap
* @return a tensor containing the float
*/
private static Tensor<Float> toTensor(final float f, final Collection<Tensor<?>> tensorsToClose) {
final Tensor<Float> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static Tensor<Float> toTensor2DFloat(final float [][] f, final Collection<Tensor<?>> tensorsToClose) {
final Tensor<Float> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static Tensor<Integer> toTensor1DInt(final int [] f, final Collection<Tensor<?>> tensorsToClose) {
final Tensor<Integer> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static void closeTensors(final Collection<Tensor<?>> ts) {
for (final Tensor<?> t : ts) {
try {
t.close();
} catch (final Exception e) {
System.err.println("Error closing Tensor.");
e.printStackTrace();
}
}
ts.clear();
}
public static void main() throws Exception {
final Graph smpb;
float [][] rv_stage1_out = new float[78408][32];
// from: infer_qcds_01.py
float [][] img_corr2d = new float[78408][324];
float [][] img_target = new float[78408][ 1];
int [] img_ntile = new int[78408];
// init ntile
for(int i=0;i<img_ntile.length;i++){
img_ntile[i] = i;
}
try (SavedModelBundle b = SavedModelBundle.load(EXPORTDIR,PB_TAG)){
System.out.println("OK");
smpb = b.graph();
/*
* for feed:
* "ph_corr2d": img_corr2d
* "ph_target_disparity": img_target
* "ph_ntile": img_ntile
*
* so it will look like:
*
* https://divis.io/2018/01/enterprise-tensorflow-code-examples/ ->
* https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*
* sess.runner()
* .feed("ph_corr2d",img_corr2d)
* .feed("ph_target_disparity",img_target)
* .feed("ph_ntile",img_ntile)
* .fetch("Disparity_net/stage1done:0")
* .run()
* .get(0)
*/
final SavedModelBundle bundle = SavedModelBundle.load(EXPORTDIR,PB_TAG);
final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(5);
System.out.println("OK");
try {
// init variable via constant
Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
Output builder_init = bundle.graph().opBuilder("Const", "rv_stage1_out_init").setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0);
// variable
OperationBuilder builder2 = bundle.graph().opBuilder("Variable", "rv_stage1_out");
builder2.addInput(builder_init);
//Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
//builder.setAttr("dtype", t.dataType()).setAttr("shape",t.shape()).build().output(0);
// stage 1
bundle.session().runner()
.feed("ph_corr2d",toTensor2DFloat(img_corr2d, tensorsToClose))
.feed("ph_target_disparity",toTensor2DFloat(img_target, tensorsToClose))
.feed("ph_ntile",toTensor1DInt(img_ntile, tensorsToClose))
.fetch("Disparity_net/stage1done:0")
.run()
.get(0);
// stage 2
final Tensor<?> result = bundle.session().runner()
.feed("ph_ntile",toTensor1DInt(img_ntile, tensorsToClose))
.fetch("Disparity_net/stage2_out_sparse:0")
.run()
.get(0);
tensorsToClose.add(result);
float [] resultValues = (float[]) result.copyTo(new float[78408]);
System.out.println("DONE");
} catch (final IllegalStateException ise) {
System.out.println("Very Bad Error (VBE): "+ise);
closeTensors(tensorsToClose);
} catch (final NumberFormatException nfe) {
//just skip unparsable lines ?!
} finally {
closeTensors(tensorsToClose);
}
//try (){
//smpb = b.graph();
//Session sess = b.session();
//System.out.println(b.metaGraphDef());
//final List<String> labels = tensorFlowService.loadLabels(source,
// MODEL_NAME, "imagenet_comp_graph_label_strings.txt");
//System.out.println("Loaded graph and " + labels.size() + " labels");
//output = sess.runner().feed(o, t).fetch().run().get(0).copyTo()
/*
try (
final Session s = new Session(g);
......@@ -84,7 +221,7 @@ public class TensorflowExamplePlugin
}
*/
}
//}
try (Graph g = new Graph()) {
final String value = "Hello from " + TensorFlow.version();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment