Commit 7489545f authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

+ downloading model from community

parent 3972a933
...@@ -57,6 +57,11 @@ ...@@ -57,6 +57,11 @@
<artifactId>libtensorflow_jni_gpu</artifactId> <artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.10.0</version> <version>1.10.0</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.ant</groupId>
<artifactId>ant-compress</artifactId>
<version>1.5</version>
</dependency>
<dependency> <dependency>
<groupId>commons-configuration</groupId> <groupId>commons-configuration</groupId>
<artifactId>commons-configuration</artifactId> <artifactId>commons-configuration</artifactId>
......
...@@ -3,10 +3,17 @@ ...@@ -3,10 +3,17 @@
* SPDX-License-Identifier: GPL-3.0-or-later * SPDX-License-Identifier: GPL-3.0-or-later
*/ */
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.nio.IntBuffer; import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList; import java.util.ArrayList;
import org.apache.ant.compress.taskdefs.Unzip;
import org.tensorflow.SavedModelBundle; import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor; import org.tensorflow.Tensor;
...@@ -47,9 +54,12 @@ Sessions and Tensors ...@@ -47,9 +54,12 @@ Sessions and Tensors
public class TensorflowInferModel public class TensorflowInferModel
{ {
final static String TRAINED_MODEL_URL = "https://community.elphel.com/files/quad-stereo/ml/trained_model_v1.0.zip";
final static String TRAINED_MODEL = "trained_model"; // /home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir"; final static String TRAINED_MODEL = "trained_model"; // /home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
final static String SERVING = "serve"; final static String SERVING = "serve";
final int tilesX, tilesY, num_tiles, num_layers; final int tilesX, tilesY, num_tiles, num_layers;
final int corr_side, corr_side2; final int corr_side, corr_side2;
// final long [] shape_corr2d; // final long [] shape_corr2d;
...@@ -68,6 +78,40 @@ long[] shape = new long[] {batch, imageSize}; ...@@ -68,6 +78,40 @@ long[] shape = new long[] {batch, imageSize};
// public // public
final SavedModelBundle bundle; final SavedModelBundle bundle;
// utils: download url
private static Path download(String sourceURL, String targetDirectory) throws IOException
{
URL url = new URL(sourceURL);
String fileName = sourceURL.substring(sourceURL.lastIndexOf('/') + 1, sourceURL.length());
Path targetPath = new File(targetDirectory + File.separator + fileName).toPath();
Files.copy(url.openStream(), targetPath, StandardCopyOption.REPLACE_EXISTING);
return targetPath;
}
// utils: unpack zip to dir
private static boolean download_and_unpack(String sourceURL, String targetDirectory) {
Path zipped_model = null;
System.out.println("Downloading "+sourceURL+". Please, wait...");
try {
zipped_model = download(sourceURL, targetDirectory);
} catch (IOException e){
e.printStackTrace();
return false;
}
Unzip unzipper = new Unzip();
unzipper.setSrc(zipped_model.toFile());
unzipper.setDest(new File(targetDirectory));
unzipper.execute();
System.out.println(unzipper.getLocation());
return true;
}
public TensorflowInferModel(int tilesX, int tilesY, int corr_side, int num_layers) public TensorflowInferModel(int tilesX, int tilesY, int corr_side, int num_layers)
{ {
this.tilesX = tilesX; this.tilesX = tilesX;
...@@ -83,11 +127,26 @@ long[] shape = new long[] {batch, imageSize}; ...@@ -83,11 +127,26 @@ long[] shape = new long[] {batch, imageSize};
this.fb_tiles_stage2 = IntBuffer.allocate(num_tiles ); this.fb_tiles_stage2 = IntBuffer.allocate(num_tiles );
this.fb_predicted = FloatBuffer.allocate(num_tiles ); this.fb_predicted = FloatBuffer.allocate(num_tiles );
String abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile(); //String resourceDir = System.getProperty("user.dir")+"/src/main/resources";
// ./target/classes/
String resourceDir = getClass().getClassLoader().getResource("").getFile();
String abs_model_path = null;
try {
abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile();
} catch (java.lang.NullPointerException e) {
//e.printStackTrace();
download_and_unpack(TRAINED_MODEL_URL, resourceDir);
// re-read
abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile();
System.out.println("New downloaded path: "+abs_model_path);
}
System.out.println("TensorflowInferModel model path: "+abs_model_path); System.out.println("TensorflowInferModel model path: "+abs_model_path);
// this will load graph/data and open a session that does not need to be closed until the program is closed // this will load graph/data and open a session that does not need to be closed until the program is closed
//// bundle = null; //// bundle = null;
bundle = SavedModelBundle.load(abs_model_path,SERVING); bundle = SavedModelBundle.load(abs_model_path, SERVING);
// Operation opr = bundle.graph().operation("rv_stage1_out"); // Operation opr = bundle.graph().operation("rv_stage1_out");
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment