Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
imagej-elphel
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
3
Issues
3
List
Board
Labels
Milestones
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
Elphel
imagej-elphel
Commits
b7a74a73
Commit
b7a74a73
authored
Sep 18, 2018
by
Oleg Dzhimiev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
java tf plugin
parent
658a3f2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
142 additions
and
5 deletions
+142
-5
TensorflowExamplePlugin.java
src/main/java/TensorflowExamplePlugin.java
+142
-5
No files found.
src/main/java/TensorflowExamplePlugin.java
View file @
b7a74a73
...
@@ -6,8 +6,16 @@
...
@@ -6,8 +6,16 @@
import
org.tensorflow.Graph
;
import
org.tensorflow.Graph
;
import
org.tensorflow.Session
;
import
org.tensorflow.Session
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.Tensors
;
import
org.tensorflow.TensorFlow
;
import
org.tensorflow.TensorFlow
;
import
org.tensorflow.SavedModelBundle
;
import
org.tensorflow.SavedModelBundle
;
import
org.tensorflow.OperationBuilder
;
import
org.tensorflow.Shape
;
import
org.tensorflow.Output
;
import
java.util.ArrayList
;
import
java.util.Collection
;
import
java.util.List
;
/**
/**
...
@@ -56,23 +64,152 @@ public class TensorflowExamplePlugin
...
@@ -56,23 +64,152 @@ public class TensorflowExamplePlugin
}
}
}
}
/**
* From https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*/
/**
* wraps a single float in a tensor
* @param f the float to wrap
* @return a tensor containing the float
*/
private
static
Tensor
<
Float
>
toTensor
(
final
float
f
,
final
Collection
<
Tensor
<?>>
tensorsToClose
)
{
final
Tensor
<
Float
>
t
=
Tensors
.
create
(
f
);
if
(
tensorsToClose
!=
null
)
{
tensorsToClose
.
add
(
t
);
}
return
t
;
}
private
static
Tensor
<
Float
>
toTensor2DFloat
(
final
float
[][]
f
,
final
Collection
<
Tensor
<?>>
tensorsToClose
)
{
final
Tensor
<
Float
>
t
=
Tensors
.
create
(
f
);
if
(
tensorsToClose
!=
null
)
{
tensorsToClose
.
add
(
t
);
}
return
t
;
}
private
static
Tensor
<
Integer
>
toTensor1DInt
(
final
int
[]
f
,
final
Collection
<
Tensor
<?>>
tensorsToClose
)
{
final
Tensor
<
Integer
>
t
=
Tensors
.
create
(
f
);
if
(
tensorsToClose
!=
null
)
{
tensorsToClose
.
add
(
t
);
}
return
t
;
}
private
static
void
closeTensors
(
final
Collection
<
Tensor
<?>>
ts
)
{
for
(
final
Tensor
<?>
t
:
ts
)
{
try
{
t
.
close
();
}
catch
(
final
Exception
e
)
{
System
.
err
.
println
(
"Error closing Tensor."
);
e
.
printStackTrace
();
}
}
ts
.
clear
();
}
public
static
void
main
()
throws
Exception
{
public
static
void
main
()
throws
Exception
{
final
Graph
smpb
;
final
Graph
smpb
;
float
[][]
rv_stage1_out
=
new
float
[
78408
][
32
];
// from: infer_qcds_01.py
float
[][]
img_corr2d
=
new
float
[
78408
][
324
];
float
[][]
img_corr2d
=
new
float
[
78408
][
324
];
float
[][]
img_target
=
new
float
[
78408
][
1
];
float
[][]
img_target
=
new
float
[
78408
][
1
];
int
[]
img_ntile
=
new
int
[
78408
];
int
[]
img_ntile
=
new
int
[
78408
];
// init ntile
// init ntile
for
(
int
i
=
0
;
i
<
img_ntile
.
length
;
i
++){
for
(
int
i
=
0
;
i
<
img_ntile
.
length
;
i
++){
img_ntile
[
i
]
=
i
;
img_ntile
[
i
]
=
i
;
}
}
try
(
SavedModelBundle
b
=
SavedModelBundle
.
load
(
EXPORTDIR
,
PB_TAG
)){
/*
System
.
out
.
println
(
"OK"
);
* for feed:
smpb
=
b
.
graph
();
* "ph_corr2d": img_corr2d
* "ph_target_disparity": img_target
* "ph_ntile": img_ntile
*
* so it will look like:
*
* https://divis.io/2018/01/enterprise-tensorflow-code-examples/ ->
* https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*
* sess.runner()
* .feed("ph_corr2d",img_corr2d)
* .feed("ph_target_disparity",img_target)
* .feed("ph_ntile",img_ntile)
* .fetch("Disparity_net/stage1done:0")
* .run()
* .get(0)
*/
final
SavedModelBundle
bundle
=
SavedModelBundle
.
load
(
EXPORTDIR
,
PB_TAG
);
final
List
<
Tensor
<?>>
tensorsToClose
=
new
ArrayList
<
Tensor
<?>>(
5
);
System
.
out
.
println
(
"OK"
);
try
{
// init variable via constant
Tensor
<
Float
>
t
=
toTensor2DFloat
(
rv_stage1_out
,
tensorsToClose
);
Output
builder_init
=
bundle
.
graph
().
opBuilder
(
"Const"
,
"rv_stage1_out_init"
).
setAttr
(
"dtype"
,
t
.
dataType
()).
setAttr
(
"value"
,
t
).
build
().
output
(
0
);
// variable
OperationBuilder
builder2
=
bundle
.
graph
().
opBuilder
(
"Variable"
,
"rv_stage1_out"
);
builder2
.
addInput
(
builder_init
);
//Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
//builder.setAttr("dtype", t.dataType()).setAttr("shape",t.shape()).build().output(0);
// stage 1
bundle
.
session
().
runner
()
.
feed
(
"ph_corr2d"
,
toTensor2DFloat
(
img_corr2d
,
tensorsToClose
))
.
feed
(
"ph_target_disparity"
,
toTensor2DFloat
(
img_target
,
tensorsToClose
))
.
feed
(
"ph_ntile"
,
toTensor1DInt
(
img_ntile
,
tensorsToClose
))
.
fetch
(
"Disparity_net/stage1done:0"
)
.
run
()
.
get
(
0
);
// stage 2
final
Tensor
<?>
result
=
bundle
.
session
().
runner
()
.
feed
(
"ph_ntile"
,
toTensor1DInt
(
img_ntile
,
tensorsToClose
))
.
fetch
(
"Disparity_net/stage2_out_sparse:0"
)
.
run
()
.
get
(
0
);
tensorsToClose
.
add
(
result
);
float
[]
resultValues
=
(
float
[])
result
.
copyTo
(
new
float
[
78408
]);
System
.
out
.
println
(
"DONE"
);
}
catch
(
final
IllegalStateException
ise
)
{
System
.
out
.
println
(
"Very Bad Error (VBE): "
+
ise
);
closeTensors
(
tensorsToClose
);
}
catch
(
final
NumberFormatException
nfe
)
{
//just skip unparsable lines ?!
}
finally
{
closeTensors
(
tensorsToClose
);
}
//try (){
//smpb = b.graph();
//Session sess = b.session();
//System.out.println(b.metaGraphDef());
//final List<String> labels = tensorFlowService.loadLabels(source,
// MODEL_NAME, "imagenet_comp_graph_label_strings.txt");
//System.out.println("Loaded graph and " + labels.size() + " labels");
//output = sess.runner().feed(o, t).fetch().run().get(0).copyTo()
/*
/*
try (
try (
final Session s = new Session(g);
final Session s = new Session(g);
...
@@ -84,7 +221,7 @@ public class TensorflowExamplePlugin
...
@@ -84,7 +221,7 @@ public class TensorflowExamplePlugin
}
}
*/
*/
}
//
}
try
(
Graph
g
=
new
Graph
())
{
try
(
Graph
g
=
new
Graph
())
{
final
String
value
=
"Hello from "
+
TensorFlow
.
version
();
final
String
value
=
"Hello from "
+
TensorFlow
.
version
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment