Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
imagej-elphel
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
3
Issues
3
List
Board
Labels
Milestones
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
Elphel
imagej-elphel
Commits
7489545f
Commit
7489545f
authored
Sep 24, 2018
by
Oleg Dzhimiev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
+ downloading model from community
parent
3972a933
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
3 deletions
+67
-3
pom.xml
pom.xml
+5
-0
TensorflowInferModel.java
src/main/java/TensorflowInferModel.java
+62
-3
No files found.
pom.xml
View file @
7489545f
...
...
@@ -57,6 +57,11 @@
<artifactId>
libtensorflow_jni_gpu
</artifactId>
<version>
1.10.0
</version>
</dependency>
<dependency>
<groupId>
org.apache.ant
</groupId>
<artifactId>
ant-compress
</artifactId>
<version>
1.5
</version>
</dependency>
<dependency>
<groupId>
commons-configuration
</groupId>
<artifactId>
commons-configuration
</artifactId>
...
...
src/main/java/TensorflowInferModel.java
View file @
7489545f
...
...
@@ -3,10 +3,17 @@
* SPDX-License-Identifier: GPL-3.0-or-later
*/
import
java.io.File
;
import
java.io.IOException
;
import
java.net.URL
;
import
java.nio.FloatBuffer
;
import
java.nio.IntBuffer
;
import
java.nio.file.Files
;
import
java.nio.file.Path
;
import
java.nio.file.StandardCopyOption
;
import
java.util.ArrayList
;
import
org.apache.ant.compress.taskdefs.Unzip
;
import
org.tensorflow.SavedModelBundle
;
import
org.tensorflow.Tensor
;
...
...
@@ -47,9 +54,12 @@ Sessions and Tensors
public
class
TensorflowInferModel
{
final
static
String
TRAINED_MODEL_URL
=
"https://community.elphel.com/files/quad-stereo/ml/trained_model_v1.0.zip"
;
final
static
String
TRAINED_MODEL
=
"trained_model"
;
// /home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
final
static
String
SERVING
=
"serve"
;
final
int
tilesX
,
tilesY
,
num_tiles
,
num_layers
;
final
int
corr_side
,
corr_side2
;
// final long [] shape_corr2d;
...
...
@@ -68,6 +78,40 @@ long[] shape = new long[] {batch, imageSize};
// public
final
SavedModelBundle
bundle
;
// utils: download url
private
static
Path
download
(
String
sourceURL
,
String
targetDirectory
)
throws
IOException
{
URL
url
=
new
URL
(
sourceURL
);
String
fileName
=
sourceURL
.
substring
(
sourceURL
.
lastIndexOf
(
'/'
)
+
1
,
sourceURL
.
length
());
Path
targetPath
=
new
File
(
targetDirectory
+
File
.
separator
+
fileName
).
toPath
();
Files
.
copy
(
url
.
openStream
(),
targetPath
,
StandardCopyOption
.
REPLACE_EXISTING
);
return
targetPath
;
}
// utils: unpack zip to dir
private
static
boolean
download_and_unpack
(
String
sourceURL
,
String
targetDirectory
)
{
Path
zipped_model
=
null
;
System
.
out
.
println
(
"Downloading "
+
sourceURL
+
". Please, wait..."
);
try
{
zipped_model
=
download
(
sourceURL
,
targetDirectory
);
}
catch
(
IOException
e
){
e
.
printStackTrace
();
return
false
;
}
Unzip
unzipper
=
new
Unzip
();
unzipper
.
setSrc
(
zipped_model
.
toFile
());
unzipper
.
setDest
(
new
File
(
targetDirectory
));
unzipper
.
execute
();
System
.
out
.
println
(
unzipper
.
getLocation
());
return
true
;
}
public
TensorflowInferModel
(
int
tilesX
,
int
tilesY
,
int
corr_side
,
int
num_layers
)
{
this
.
tilesX
=
tilesX
;
...
...
@@ -83,11 +127,26 @@ long[] shape = new long[] {batch, imageSize};
this
.
fb_tiles_stage2
=
IntBuffer
.
allocate
(
num_tiles
);
this
.
fb_predicted
=
FloatBuffer
.
allocate
(
num_tiles
);
String
abs_model_path
=
getClass
().
getClassLoader
().
getResource
(
TRAINED_MODEL
).
getFile
();
//String resourceDir = System.getProperty("user.dir")+"/src/main/resources";
// ./target/classes/
String
resourceDir
=
getClass
().
getClassLoader
().
getResource
(
""
).
getFile
();
String
abs_model_path
=
null
;
try
{
abs_model_path
=
getClass
().
getClassLoader
().
getResource
(
TRAINED_MODEL
).
getFile
();
}
catch
(
java
.
lang
.
NullPointerException
e
)
{
//e.printStackTrace();
download_and_unpack
(
TRAINED_MODEL_URL
,
resourceDir
);
// re-read
abs_model_path
=
getClass
().
getClassLoader
().
getResource
(
TRAINED_MODEL
).
getFile
();
System
.
out
.
println
(
"New downloaded path: "
+
abs_model_path
);
}
System
.
out
.
println
(
"TensorflowInferModel model path: "
+
abs_model_path
);
// this will load graph/data and open a session that does not need to be closed until the program is closed
//// bundle = null;
bundle
=
SavedModelBundle
.
load
(
abs_model_path
,
SERVING
);
bundle
=
SavedModelBundle
.
load
(
abs_model_path
,
SERVING
);
// Operation opr = bundle.graph().operation("rv_stage1_out");
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment