Commit f19ac162 authored by Andrey Filippov's avatar Andrey Filippov

CLAUDE: Add ONNX-Runtime Java inference for C5P DNN front-end (Phase 1, CPU)

PyTorch-trained, ONNX-exported all-conv FCN (per-pixel Vx,Vy,s) run in Java via
ONNX Runtime 1.20.0 (CPU EP). CuasDnnInfer loads the model with a location resolver
(local path / scp user@host:path / http(s) -> ~/.cache/c5p_dnn/, fetching the
model.onnx + external-data .data pair) and runs a float[N][H][W] patch to raw
det/vel/off output. Verified bit-exact vs PyTorch (max abs diff 2.9e-6) via a fixed
test vector. New config param curt_dnn_model (empty default) selects the model,
mirroring the tile_processor_gpu kernel-source default/override scheme. CPU first
(.224 has no cuDNN; 82k net is microseconds/patch); GPU (CUDA/TensorRT EP) and the
CuasDetectRT integration are the next phase.
Co-Authored-By: 's avatarClaude Opus 4.8 (1M context) <noreply@anthropic.com>
parent 77f8ce01
......@@ -76,6 +76,14 @@
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.10.0</version>
</dependency>
<!-- ONNX Runtime (CPU) for the C5P DNN front-end inference (PyTorch-trained, ONNX-exported). // By Claude on 06/13/2026
CPU first - the 82k net is microseconds/patch and .224 has no cuDNN; swap to
onnxruntime_gpu once cuDNN 9 is present (CUDA EP) or use the TensorRT EP. -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.20.0</version>
</dependency>
<dependency>
<groupId>org.apache.ant</groupId>
......
package com.elphel.imagej.cuas.rt;
// C5P DNN front-end inference via ONNX Runtime (CPU). // By Claude on 06/13/2026
// Loads a PyTorch-trained, ONNX-exported all-conv FCN that maps an N-frame patch stack to a
// per-pixel (Vx,Vy,s) output. Phase 1: CPU, fed from Java float arrays (the convenient testing
// path, like the old TensorFlow integration). Phase 2 (later): onnxruntime_gpu CUDA/TensorRT EP,
// device chosen via the existing GPU-detection logic; eventual zero-copy from JCuda CUdeviceptr.
//
// Model location (curt_dnn_model): resolved as a local path, scp "user@host:/path", or http(s)
// URL; remote specs are fetched to ~/.cache/c5p_dnn/ together with the external-data sibling
// (model.onnx + model.onnx.data, which torch 2.9's exporter splits).
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.util.Collections;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
public class CuasDnnInfer implements AutoCloseable {
private final OrtEnvironment env;
private final OrtSession session;
private final String inputName;
private final int nframes;
public CuasDnnInfer(String modelSpec, int nframes) throws Exception {
String local = resolveModel(modelSpec);
this.nframes = nframes;
this.env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// CPU EP by default. GPU later: opts.addCUDA(deviceId) once cuDNN 9 is present.
this.session = env.createSession(local, opts);
this.inputName = session.getInputNames().iterator().next(); // "frames"
System.out.println("CuasDnnInfer: loaded " + local + " input=" + inputName
+ " outputs=" + session.getOutputNames());
}
/** Run one patch. frames = [N][H][W] (newest first). Returns the raw output channels
* [outCh] for the center output pixel: [0]=det logit, [1..V*V]=vel logits, last 2 = offset. */
public float[] inferRaw(float[][][] frames) throws Exception {
int N = frames.length, H = frames[0].length, W = frames[0][0].length;
FloatBuffer fb = FloatBuffer.allocate(N * H * W);
for (int n = 0; n < N; n++)
for (int y = 0; y < H; y++)
fb.put(frames[n][y]);
fb.rewind();
try (OnnxTensor in = OnnxTensor.createTensor(env, fb, new long[] {1, N, H, W});
OrtSession.Result res = session.run(Collections.singletonMap(inputName, in))) {
float[][][][] out = (float[][][][]) res.get(0).getValue(); // [1, outCh, 1, 1]
int outCh = out[0].length;
float[] flat = new float[outCh];
for (int c = 0; c < outCh; c++) flat[c] = out[0][c][0][0];
return flat;
}
}
public static float sigmoid(float x) { return (float) (1.0 / (1.0 + Math.exp(-x))); }
@Override public void close() throws Exception { session.close(); }
// ---- model location resolver (mirrors the tile_processor_gpu default/override scheme) ----
static String resolveModel(String spec) throws Exception {
if (spec == null || spec.trim().isEmpty())
throw new IllegalArgumentException("CuasDnnInfer: empty model spec (curt_dnn_model)");
spec = spec.trim();
if (new File(spec).isFile()) return spec; // local path
String cacheDir = System.getProperty("user.home") + "/.cache/c5p_dnn/";
new File(cacheDir).mkdirs();
String name = spec.substring(spec.lastIndexOf('/') + 1);
String localOnnx = cacheDir + name;
if (spec.startsWith("http://") || spec.startsWith("https://")) {
httpFetch(spec, localOnnx);
try { httpFetch(spec + ".data", localOnnx + ".data"); } catch (Exception e) { /* no external data */ }
} else if (spec.matches("[^@/]+@[^:]+:.*")) { // user@host:/path (scp/rsync)
rsyncFetch(spec, localOnnx);
try { rsyncFetch(spec + ".data", localOnnx + ".data"); } catch (Exception e) { /* none */ }
} else {
throw new java.io.FileNotFoundException("CuasDnnInfer: model not found / unrecognized spec: " + spec);
}
return localOnnx;
}
static void rsyncFetch(String remote, String dest) throws Exception {
Process p = new ProcessBuilder("rsync", "-az", remote, dest).inheritIO().start();
if (p.waitFor() != 0) throw new RuntimeException("rsync failed for " + remote);
}
static void httpFetch(String url, String dest) throws Exception {
try (InputStream is = new URL(url).openStream()) {
Files.copy(is, new File(dest).toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
}
}
// ---- standalone verification: Java ORT output vs the saved PyTorch raw output ----
static float[] readFloatsLE(String path) throws Exception {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (DataInputStream in = new DataInputStream(new FileInputStream(path))) {
byte[] buf = new byte[8192]; int n;
while ((n = in.read(buf)) > 0) bos.write(buf, 0, n);
}
ByteBuffer bb = ByteBuffer.wrap(bos.toByteArray()).order(ByteOrder.LITTLE_ENDIAN);
float[] f = new float[bb.remaining() / 4];
for (int i = 0; i < f.length; i++) f[i] = bb.getFloat();
return f;
}
public static void main(String[] args) throws Exception {
// args: model(path/spec) testvec_in.bin [testvec_out.bin] [N H W]
String model = args[0], invec = args[1];
String outvec = args.length > 2 ? args[2] : null;
int N = args.length > 3 ? Integer.parseInt(args[3]) : 8;
int H = args.length > 4 ? Integer.parseInt(args[4]) : 24;
int W = args.length > 5 ? Integer.parseInt(args[5]) : 24;
float[] flat = readFloatsLE(invec);
if (flat.length != N * H * W) throw new RuntimeException("in size " + flat.length + " != " + (N * H * W));
float[][][] frames = new float[N][H][W];
int k = 0;
for (int n = 0; n < N; n++) for (int y = 0; y < H; y++) for (int x = 0; x < W; x++) frames[n][y][x] = flat[k++];
try (CuasDnnInfer infer = new CuasDnnInfer(model, N)) {
float[] out = infer.inferRaw(frames);
System.out.println("output len=" + out.length + " det_logit=" + out[0] + " s=" + sigmoid(out[0]));
if (outvec != null) {
float[] exp = readFloatsLE(outvec);
double maxabs = 0; int amax = 0;
for (int i = 0; i < Math.min(out.length, exp.length); i++) {
double d = Math.abs(out[i] - exp[i]); if (d > maxabs) { maxabs = d; amax = i; }
}
System.out.printf("max abs diff vs PyTorch = %.6g (at ch %d: java %.5f vs torch %.5f)%n",
maxabs, amax, out[amax], exp[amax]);
System.out.println(maxabs < 1e-3 ? "PASS - Java ORT matches PyTorch" : "MISMATCH - investigate");
}
}
}
}
......@@ -1177,6 +1177,7 @@ min_str_neib_fpn 0.35
public boolean curt_c5_white = false; // linear 4D unsharp/whitening of the conv output before all nonlinearities: sharpens (reduces moment of inertia / Gram spread) at the cost of noise gain; filename gets -W<sp>_<vel> // By Claude on 06/13/2026
public double curt_c5_white_sp = 0.5; // whitening spatial strength a (3-tap unsharp (1+2a,-a,-a) in x and y); 0 - no spatial sharpening // By Claude on 06/13/2026
public double curt_c5_white_vel = 0.0; // whitening velocity strength a (3-tap unsharp in vx and vy); 0 - no velocity sharpening // By Claude on 06/13/2026
public String curt_dnn_model = ""; // C5P DNN front-end model (ONNX): empty = disabled; local path, scp user@host:path, or http(s) URL fetched to cache; overrides bundled resource (mirrors tile_processor_gpu) // By Claude on 06/13/2026
public boolean curt_synth_src = true; // default set for the synthetic B-measurement experiment (set false for real-data runs); reads *-CUAS-SYNTHETIC-CUAS.tiff, output titles get -SYNTH // By Claude on 06/12/2026
public double curt_synth_scale = 5.0; // synthetic target peak, counts (synthetic file is peak-1 normalized; scaled at load) // By Claude on 06/12/2026
public boolean curt_synth_bg = true; // add the real *-CUAS-MERGED-CUAS.tiff scene under the synthetic targets (label-matched frames); false - clean targets only // By Claude on 06/12/2026
......@@ -3516,6 +3517,8 @@ min_str_neib_fpn 0.35
"Spatial unsharp strength a: s' = (1+2a)s - a(s_left+s_right) in x and y across the ROI. Larger = sharper spatial blob + more spatial noise. 0 - off."); // By Claude on 06/13/2026
gd.addNumericField("Whitening velocity strength", this.curt_c5_white_vel, 6,8,"", // By Claude on 06/13/2026
"Velocity unsharp strength a: same 3-tap in vx and vy within each pixel's velocity grid. Sharpens the velocity-cell spread. 0 - off (start spatial-only)."); // By Claude on 06/13/2026
gd.addStringField ("C5P DNN model (ONNX)", this.curt_dnn_model, 60, // By Claude on 06/13/2026
"Trained DNN front-end model location. Empty = disabled (use matched-filter/posterior path). Local path, scp user@host:/path (fetched to cache), or http(s) URL. Overrides any bundled resource - same default-vs-override scheme as the GPU kernel sources."); // By Claude on 06/13/2026
gd.addCheckbox ("Use synthetic input", this.curt_synth_src, // By Claude on 06/11/2026
"Read *-CUAS-SYNTHETIC-CUAS.tiff (generated test targets) instead of *-CUAS-MERGED-CUAS.tiff from the same model directory; all output titles get a -SYNTH mark."); // By Claude on 06/11/2026
gd.addNumericField("Synthetic target peak", this.curt_synth_scale, 6,8,"counts", // By Claude on 06/12/2026
......@@ -5060,6 +5063,7 @@ min_str_neib_fpn 0.35
this.curt_c5_white = gd.getNextBoolean(); // By Claude on 06/13/2026
this.curt_c5_white_sp = gd.getNextNumber(); // By Claude on 06/13/2026
this.curt_c5_white_vel = gd.getNextNumber(); // By Claude on 06/13/2026
this.curt_dnn_model = gd.getNextString().trim(); // By Claude on 06/13/2026
this.curt_synth_src = gd.getNextBoolean(); // By Claude on 06/11/2026
this.curt_synth_scale = gd.getNextNumber(); // By Claude on 06/12/2026
this.curt_synth_bg = gd.getNextBoolean(); // By Claude on 06/12/2026
......@@ -6424,6 +6428,7 @@ min_str_neib_fpn 0.35
properties.setProperty(prefix+"curt_c5_white", this.curt_c5_white+""); // boolean // By Claude on 06/13/2026
properties.setProperty(prefix+"curt_c5_white_sp", this.curt_c5_white_sp+""); // double // By Claude on 06/13/2026
properties.setProperty(prefix+"curt_c5_white_vel", this.curt_c5_white_vel+""); // double // By Claude on 06/13/2026
properties.setProperty(prefix+"curt_dnn_model", this.curt_dnn_model); // String // By Claude on 06/13/2026
properties.setProperty(prefix+"curt_synth_src", this.curt_synth_src+""); // boolean // By Claude on 06/11/2026
properties.setProperty(prefix+"curt_synth_scale", this.curt_synth_scale+""); // double // By Claude on 06/12/2026
properties.setProperty(prefix+"curt_synth_bg", this.curt_synth_bg+""); // boolean // By Claude on 06/12/2026
......@@ -6823,6 +6828,7 @@ min_str_neib_fpn 0.35
if (properties.getProperty(prefix+"curt_c5_white")!=null) this.curt_c5_white=Boolean.parseBoolean(properties.getProperty(prefix+"curt_c5_white")); // By Claude on 06/13/2026
if (properties.getProperty(prefix+"curt_c5_white_sp")!=null) this.curt_c5_white_sp=Double.parseDouble(properties.getProperty(prefix+"curt_c5_white_sp")); // By Claude on 06/13/2026
if (properties.getProperty(prefix+"curt_c5_white_vel")!=null) this.curt_c5_white_vel=Double.parseDouble(properties.getProperty(prefix+"curt_c5_white_vel")); // By Claude on 06/13/2026
if (properties.getProperty(prefix+"curt_dnn_model")!=null) this.curt_dnn_model=(String) properties.getProperty(prefix+"curt_dnn_model"); // By Claude on 06/13/2026
if (properties.getProperty(prefix+"curt_synth_src")!=null) this.curt_synth_src=Boolean.parseBoolean(properties.getProperty(prefix+"curt_synth_src")); // By Claude on 06/11/2026
if (properties.getProperty(prefix+"curt_synth_scale")!=null) this.curt_synth_scale=Double.parseDouble(properties.getProperty(prefix+"curt_synth_scale")); // By Claude on 06/12/2026
......@@ -9104,6 +9110,7 @@ min_str_neib_fpn 0.35
imp.curt_c5_white = this.curt_c5_white; // By Claude on 06/13/2026
imp.curt_c5_white_sp = this.curt_c5_white_sp; // By Claude on 06/13/2026
imp.curt_c5_white_vel = this.curt_c5_white_vel; // By Claude on 06/13/2026
imp.curt_dnn_model = this.curt_dnn_model; // By Claude on 06/13/2026
imp.curt_synth_src = this.curt_synth_src; // By Claude on 06/11/2026
imp.curt_synth_scale = this.curt_synth_scale; // By Claude on 06/12/2026
imp.curt_synth_bg = this.curt_synth_bg; // By Claude on 06/12/2026
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment