Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
imagej-elphel
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
3
Issues
3
List
Board
Labels
Milestones
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
Elphel
imagej-elphel
Commits
66eda4b8
Commit
66eda4b8
authored
Sep 19, 2018
by
Andrey Filippov
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'gpu' of git@git.elphel.com:Elphel/imagej-elphel.git into gpu
parents
c264a349
03e238e2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
35 deletions
+42
-35
TensorflowExamplePlugin.java
src/main/java/TensorflowExamplePlugin.java
+42
-35
No files found.
src/main/java/TensorflowExamplePlugin.java
View file @
66eda4b8
...
@@ -12,6 +12,7 @@ import org.tensorflow.SavedModelBundle;
...
@@ -12,6 +12,7 @@ import org.tensorflow.SavedModelBundle;
import
org.tensorflow.OperationBuilder
;
import
org.tensorflow.OperationBuilder
;
import
org.tensorflow.Shape
;
import
org.tensorflow.Shape
;
import
org.tensorflow.Output
;
import
org.tensorflow.Output
;
import
org.tensorflow.Operation
;
import
java.util.ArrayList
;
import
java.util.ArrayList
;
import
java.util.Collection
;
import
java.util.Collection
;
...
@@ -51,7 +52,9 @@ public class TensorflowExamplePlugin
...
@@ -51,7 +52,9 @@ public class TensorflowExamplePlugin
{
{
public
final
static
String
EXPORTDIR
=
"/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir"
;
public
final
static
String
EXPORTDIR
=
"/home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir"
;
public
final
static
String
PB_TAG
=
"model_pb"
;
// tf.saved_model.tag_constants.SERVING = "serve"
public
final
static
String
SERVING
=
"serve"
;
public
static
void
run
()
public
static
void
run
()
{
{
...
@@ -115,57 +118,60 @@ public class TensorflowExamplePlugin
...
@@ -115,57 +118,60 @@ public class TensorflowExamplePlugin
final
Graph
smpb
;
final
Graph
smpb
;
// init for variable?
float
[][]
rv_stage1_out
=
new
float
[
78408
][
32
];
float
[][]
rv_stage1_out
=
new
float
[
78408
][
32
];
// from
:
infer_qcds_01.py
// from infer_qcds_01.py
float
[][]
img_corr2d
=
new
float
[
78408
][
324
];
float
[][]
img_corr2d
=
new
float
[
78408
][
324
];
float
[][]
img_target
=
new
float
[
78408
][
1
];
float
[][]
img_target
=
new
float
[
78408
][
1
];
int
[]
img_ntile
=
new
int
[
78408
];
int
[]
img_ntile
=
new
int
[
78408
];
// init ntile
// init ntile
for testing?
for
(
int
i
=
0
;
i
<
img_ntile
.
length
;
i
++){
for
(
int
i
=
0
;
i
<
img_ntile
.
length
;
i
++){
img_ntile
[
i
]
=
i
;
img_ntile
[
i
]
=
i
;
}
}
/*
* for feed:
* "ph_corr2d": img_corr2d
* "ph_target_disparity": img_target
* "ph_ntile": img_ntile
*
* so it will look like:
*
* https://divis.io/2018/01/enterprise-tensorflow-code-examples/ ->
* https://github.com/DIVSIO/tensorflow_java_cli_example/blob/master/src/main/java/divisio/example/tensorflow/cli/RunRegression.java
*
* sess.runner()
* .feed("ph_corr2d",img_corr2d)
* .feed("ph_target_disparity",img_target)
* .feed("ph_ntile",img_ntile)
* .fetch("Disparity_net/stage1done:0")
* .run()
* .get(0)
*/
final
SavedModelBundle
bundle
=
SavedModelBundle
.
load
(
EXPORTDIR
,
PB_TA
G
);
final
SavedModelBundle
bundle
=
SavedModelBundle
.
load
(
EXPORTDIR
,
SERVIN
G
);
final
List
<
Tensor
<?>>
tensorsToClose
=
new
ArrayList
<
Tensor
<?>>(
5
);
final
List
<
Tensor
<?>>
tensorsToClose
=
new
ArrayList
<
Tensor
<?>>(
5
);
System
.
out
.
println
(
"OK"
);
System
.
out
.
println
(
"OK"
);
try
{
try
{
// init variable via constant
System
.
out
.
println
(
"S0:"
);
Tensor
<
Float
>
t
=
toTensor2DFloat
(
rv_stage1_out
,
tensorsToClose
);
// read Variable info test
Output
builder_init
=
bundle
.
graph
().
opBuilder
(
"Const"
,
"rv_stage1_out_init"
).
setAttr
(
"dtype"
,
t
.
dataType
()).
setAttr
(
"value"
,
t
).
build
().
output
(
0
);
Operation
opr
=
bundle
.
graph
().
operation
(
"rv_stage1_out"
);
System
.
out
.
println
(
opr
.
toString
());
System
.
out
.
println
(
"S1:"
);
// init variable via constant?
Tensor
<
Float
>
tsr
=
toTensor2DFloat
(
rv_stage1_out
,
tensorsToClose
);
Output
builder_init
=
bundle
.
graph
()
.
opBuilder
(
"Const"
,
"rv_stage1_out_init"
)
.
setAttr
(
"dtype"
,
tsr
.
dataType
())
.
setAttr
(
"value"
,
tsr
)
.
build
()
.
output
(
0
);
System
.
out
.
println
(
builder_init
);
// variable
// variable
OperationBuilder
builder2
=
bundle
.
graph
().
opBuilder
(
"Variable"
,
"rv_stage1_out"
);
OperationBuilder
builder2
=
bundle
.
graph
().
opBuilder
(
"Variable"
,
"rv_stage1_out
_extra_variable
"
);
builder2
.
addInput
(
builder_init
);
//
.addInput(builder_init);
//
Tensor<Float> t = toTensor2DFloat(rv_stage1_out, tensorsToClose);
//
builder2.
//bu
ilder.setAttr("dtype", t.dataType()).setAttr("shape",t.shape()
).build().output(0);
//bu
ndle.graph().opBuilder("Assign", "Assign/" + builder2.op().name()).addInput(variable).addInput(value
).build().output(0);
//Tensor<Float> tensorVal = tsr;
//Output oValue = bundle.graph().opBuilder("Const", "rv_stage1_out_2").setAttr("dtype", tensorVal.dataType()).setAttr("value", tensorVal).build().output(0);
//System.out.println(oValue);
//Output oValue = bundle.graph().opBuilder("Variable", "rv_stage1_out").setAttr("value", tensorVal).build().output(0);
//bundle.graph().opBuilder("Assign", "Assign/rv_stage1_out").setAttr("value", tsr).build();
System
.
out
.
println
(
"DONE"
);
// stage 1
// stage 1
bundle
.
session
().
runner
()
bundle
.
session
().
runner
()
.
feed
(
"ph_corr2d"
,
toTensor2DFloat
(
img_corr2d
,
tensorsToClose
))
.
feed
(
"ph_corr2d"
,
toTensor2DFloat
(
img_corr2d
,
tensorsToClose
))
...
@@ -186,11 +192,12 @@ public class TensorflowExamplePlugin
...
@@ -186,11 +192,12 @@ public class TensorflowExamplePlugin
float
[]
resultValues
=
(
float
[])
result
.
copyTo
(
new
float
[
78408
]);
float
[]
resultValues
=
(
float
[])
result
.
copyTo
(
new
float
[
78408
]);
System
.
out
.
println
(
"DONE"
);
System
.
out
.
println
(
"DONE"
);
}
catch
(
final
IllegalStateException
ise
)
{
//
} catch (final IllegalStateException ise) {
System
.
out
.
println
(
"Very Bad Error (VBE): "
+
ise
);
//
System.out.println("Very Bad Error (VBE): "+ise);
closeTensors
(
tensorsToClose
);
//
closeTensors(tensorsToClose);
}
catch
(
final
NumberFormatException
nfe
)
{
}
catch
(
final
NumberFormatException
nfe
)
{
//just skip unparsable lines ?!
//just skip unparsable lines ?!
}
finally
{
}
finally
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment