Commit 9202dd30 authored by Andrey Filippov's avatar Andrey Filippov

CLAUDE: native LibTorch L1+L2 inference shim (libtpdnn.so) for JNA

Piece 2 of the native-JNA DNN path. tp_dnn.cpp is a C-ABI port of infer_server.py's
hot path so the Java client can run L1+L2 in-process instead of over TCP:
  tpdnn_init/upload/infer/free  (+ num_levels/level_frames)
faithfully reproducing build_pyramid, the 16x shift-and-stitch full-res recovery,
decode (ghostbuster + velocity centroid), and the L2 ConvGRU recurrence + track-age.
Loads the TorchScript models from imagej_elphel_dnn (export_torchscript /
export_l2_torchscript). Disables the TorchScript JIT fuser at init (nvrtc element-wise
fusion fails on Blackwell; production wants no runtime nvrtc).

Validated: native vs the running Python server (same CUDA) max|diff| offset5=0,
roi=0 — bit-for-bit. (Oracle dump_ref.py + driver tpdnn_test.cpp, scratch.)

Built standalone via build_dnn.sh (g++ + libtorch 2.7.1+cu128, ABI=1), separate
from the nvcc-built libtileproc.so; fetch_libtorch.sh pulls the pinned libtorch.
Context unification + zero-copy kernel<->tensor sharing is a later step.
Co-Authored-By: 's avatarClaude Opus 4.8 (1M context) <noreply@anthropic.com>
parent 7540202f
libtileproc.so
libtpdnn.so
tp_nvrtc_probe
*.o
#!/usr/bin/env bash
# Build libtpdnn.so — native LibTorch L1+L2 inference shim for JNA (no Python server).
# Separate from libtileproc.so (nvcc kernels): this links libtorch (g++). The two .so's unify
# their CUDA context later (zero-copy kernel<->tensor). By Claude on 2026-06-26.
#
# Requires libtorch 2.7.1+cu128 (matches the TorchScript export torch version). Set LIBTORCH to
# its root (default /home/elphel/git/libtorch). Run jna/fetch_libtorch.sh to obtain it.
set -e
cd "$(dirname "$0")"
LIBTORCH="${LIBTORCH:-/home/elphel/git/libtorch}"
[ -d "$LIBTORCH/include/torch" ] || { echo "libtorch not found at $LIBTORCH (set LIBTORCH= or run fetch_libtorch.sh)"; exit 1; }
g++ -std=gnu++17 -O3 -DNDEBUG -fPIC --shared \
-D_GLIBCXX_USE_CXX11_ABI=1 \
-I"$LIBTORCH/include" -I"$LIBTORCH/include/torch/csrc/api/include" \
tp_dnn.cpp \
-o libtpdnn.so \
-L"$LIBTORCH/lib" -Wl,-rpath,"$LIBTORCH/lib" \
-ltorch -ltorch_cpu -ltorch_cuda -lc10 -lc10_cuda
echo "built ./libtpdnn.so (LIBTORCH=$LIBTORCH)"
#!/bin/bash
# Fetch + extract the pinned libtorch (cu128 / CUDA 12.8, Blackwell sm_120) from mirror.elphel.com.
# Runtime dependency for native DNN inference (L1/L2 via TorchScript). NOT in git (~3.8 GB zip / ~GB extracted).
# Default extract location: /home/elphel/git/libtorch (native build uses -DCMAKE_PREFIX_PATH=<that>).
# By Claude on 06/27/2026.
set -euo pipefail
LT_ZIP="libtorch-cxx11-abi-shared-with-deps-2.7.1-cu128.zip"
LT_URL="https://mirror.elphel.com/libtorch/${LT_ZIP}"
PARENT="${1:-/home/elphel/git}" # libtorch extracts to $PARENT/libtorch
DEST="$PARENT/libtorch"
if [ -f "$DEST/build-version" ]; then
echo "libtorch already present: $DEST ($(cat "$DEST/build-version"))"; exit 0
fi
mkdir -p "$PARENT"; cd "$PARENT"
echo "Downloading $LT_URL ..."
# NOTE: mirror.elphel.com WAF returns 406 to curl's default UA -> use a browser UA.
curl -fSL -A "Mozilla/5.0 (X11; Linux x86_64)" "$LT_URL" -o "$LT_ZIP"
echo "Extracting -> $DEST ..."
unzip -q -o "$LT_ZIP" # extracts top-level ./libtorch/
rm -f "$LT_ZIP"
echo "libtorch ready: $DEST ($(cat "$DEST/build-version"))"
// tp_dnn.cpp — native LibTorch inference for the CUAS L1+L2 DNN path (no Python server).
//
// C-ABI shim for JNA. A faithful port of infer_server.py's hot path
// (build_pyramid + shift_stitch + decode + the L2 recurrence/age), so the Java client can run
// the same L1+L2 inference in-process instead of over TCP. Loads the TorchScript models exported
// by imagej_elphel_dnn (export_torchscript.py / export_l2_torchscript.py).
//
// Built as a standalone libtpdnn.so against libtorch (g++), separate from the nvcc-built
// libtileproc.so. Zero-copy kernel<->tensor sharing (one CUDA primary context) is a later step.
// By Claude on 06/26/2026.
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <vector>
#include <string>
#include <cstring>
#include <cmath>
using torch::Tensor;
namespace {
constexpr int VEL_DECIMATE = 4; // velocity-grid cells per px/level-frame (Java curt_vel_decimate)
constexpr float AGE_THR = 0.2f; // L2 track-age death threshold
constexpr float AGE_K = 0.5f; // ancestor gate: prev-neighbor det >= AGE_K*local-max to pass its age
constexpr int NOISE_REF_LEVEL = 3; // net calibrated to ~LEV3 absolute noise
constexpr int S_STRIDE = 4; // shift-and-stitch stride (full-res recovery): SxS forwards
struct Ctx {
torch::jit::Module l1, l2;
bool has_l2 = false;
int N = 9, P = 24, vr = 5, out_ch = 0, l2_ch = 24;
std::vector<Tensor> pyr; // temporal pyramid levels from upload, on CUDA
// L2 carry state (persists across infer calls; reset on l2_reset), == server's h_l2/age_l2/sprev_l2
Tensor h_l2, age_l2, sprev_l2;
bool l2_state_init = false;
};
// build_pyramid: log [T,H,W] -> temporal-averaged levels (== Java temporalAverageLReLU, linear).
std::vector<Tensor> build_pyramid(const Tensor& log, int nmax = 8) {
std::vector<Tensor> lv;
int T = log.size(0);
lv.push_back(0.5f * (log.slice(0, 1, T) + log.slice(0, 0, T - 1))); // [T-1,H,W]
while ((int)lv.size() < nmax) {
Tensor prev = lv.back();
int n = (int)prev.size(0) / 2 - 1;
if (n < 1) break;
Tensor idx = torch::arange(n, torch::TensorOptions().dtype(torch::kLong).device(prev.device()));
lv.push_back(0.5f * (prev.index_select(0, idx * 2 + 2) + prev.index_select(0, idx * 2)));
}
return lv;
}
// shift_stitch: x [b,N,H,W] -> full-res field [b,out_ch,H,W] via SxS strided forwards.
Tensor shift_stitch(torch::jit::Module& m, const Tensor& x, int P, int out_ch) {
int b = x.size(0), H = x.size(2), W = x.size(3);
int half = P / 2;
Tensor xp = torch::constant_pad_nd(x, {half, half, half, half}, 0); // [b,N,H+P,W+P]
Tensor full = torch::zeros({b, out_ch, H, W}, x.options());
for (int sy = 0; sy < S_STRIDE; sy++)
for (int sx = 0; sx < S_STRIDE; sx++) {
Tensor xin = xp.slice(2, sy).slice(3, sx);
Tensor y = m.forward({xin}).toTensor(); // [b,C,oH,oW]
int oh = std::min((int)y.size(2), (H - sy + S_STRIDE - 1) / S_STRIDE);
int ow = std::min((int)y.size(3), (W - sx + S_STRIDE - 1) / S_STRIDE);
full.slice(2, sy, c10::nullopt, S_STRIDE).slice(3, sx, c10::nullopt, S_STRIDE)
.copy_(y.slice(2, 0, oh).slice(3, 0, ow));
}
return full;
}
// decode: field [b,C,H,W] -> offset5 [b,5,H,W] (dx,dy,s,vx,vy) + roi [b,rh,rw,nvel] (softmax*s).
std::pair<Tensor, Tensor> decode(const Tensor& field, int vr,
int x0, int y0, int rw, int rh, double rmax_cells) {
int vdim = 2 * vr + 1, nvel = vdim * vdim;
auto fopt = field.options();
Tensor s = torch::sigmoid(field.select(1, 0)); // [b,H,W]
Tensor p = field.slice(1, 1, 1 + nvel).softmax(1); // [b,nvel,H,W]
Tensor k = torch::arange(nvel, torch::TensorOptions().dtype(torch::kLong).device(field.device()));
Tensor cx = (k.remainder(vdim) - vr).to(fopt.dtype()); // [nvel] vx cell coord
Tensor cy = (k.div(vdim, "floor") - vr).to(fopt.dtype()); // [nvel] vy cell coord
if (rmax_cells > 0) { // ghostbuster
Tensor corner = (cx * cx + cy * cy) > (float)(rmax_cells * rmax_cells); // [nvel] bool
Tensor amax = p.argmax(1); // [b,H,W]
Tensor ghost = corner.to(torch::kLong).index_select(0, amax.flatten())
.view(amax.sizes()).to(torch::kBool); // [b,H,W]
p = p * torch::logical_not(corner).to(fopt.dtype()).view({1, nvel, 1, 1});
Tensor keep = torch::logical_not(ghost).to(fopt.dtype()); // [b,H,W]
p = p * keep.unsqueeze(1);
s = s * keep;
}
Tensor psum = p.sum(1).clamp_min(1e-12); // [b,H,W]
Tensor vx = (p * cx.view({1, nvel, 1, 1})).sum(1) / psum; // [b,H,W]
Tensor vy = (p * cy.view({1, nvel, 1, 1})).sum(1) / psum;
Tensor dx = field.select(1, 1 + nvel); // [b,H,W]
Tensor dy = field.select(1, 1 + nvel + 1);
Tensor offset5 = torch::stack({dx, dy, s, vx, vy}, 1); // [b,5,H,W]
Tensor roi = (p * s.unsqueeze(1)).slice(2, y0, y0 + rh).slice(3, x0, x0 + rw); // [b,nvel,rh,rw]
roi = roi.permute({0, 2, 3, 1}).contiguous(); // [b,rh,rw,nvel]
return {offset5, roi};
}
} // namespace
extern "C" {
// Load TorchScript L1 (+optional L2) onto CUDA, disable the JIT fuser (nvrtc element-wise fusion
// fails on Blackwell; production wants no runtime nvrtc). l2_path empty/NULL -> L2 off.
void* tpdnn_init(const char* l1_path, const char* l2_path, int N, int P, int vr, int l2_ch_hidden) {
torch::jit::setGraphExecutorOptimize(false);
torch::jit::setTensorExprFuserEnabled(false);
Ctx* c = new Ctx();
c->N = N; c->P = P; c->vr = vr; c->l2_ch = l2_ch_hidden;
c->out_ch = 1 + (2 * vr + 1) * (2 * vr + 1) + 2;
c->l1 = torch::jit::load(l1_path); c->l1.to(torch::kCUDA); c->l1.eval();
if (l2_path && l2_path[0]) {
c->l2 = torch::jit::load(l2_path); c->l2.to(torch::kCUDA); c->l2.eval();
c->has_l2 = true;
}
return (void*)c;
}
// build_pyramid from host log [T,H,W] (row-major float32).
void tpdnn_upload(void* ctx, const float* log, int T, int H, int W) {
Ctx* c = (Ctx*)ctx;
Tensor host = torch::from_blob((void*)log, {T, H, W}, torch::kFloat32).clone();
c->pyr = build_pyramid(host.to(torch::kCUDA));
}
int tpdnn_num_levels(void* ctx) { return (int)((Ctx*)ctx)->pyr.size(); }
int tpdnn_level_frames(void* ctx, int lev) { return (int)((Ctx*)ctx)->pyr[lev].size(0); }
// One INFER over `count` scenes of `level` (newest_j = start + j*stride). Writes offset5 into
// out_o5 [count*nch*H*W] and roi into out_roi [count*rh*rw*nvel] (native float order). Returns
// nch (5 L1-only, 6 with L2:+age). H,W == upload H,W; nvel == (2vr+1)^2 (caller sizes buffers).
int tpdnn_infer(void* ctx, int level, int start, int count, int stride,
int rx, int ry, int rw, int rh, double rmax,
int l2_enable, int l2_reset, double noise_scale,
float* out_o5, float* out_roi) {
Ctx* c = (Ctx*)ctx;
torch::NoGradGuard ng;
bool use_l2 = (l2_enable != 0) && c->has_l2;
if (noise_scale <= 0.0) noise_scale = std::pow(2.0, (level - NOISE_REF_LEVEL) / 2.0);
Tensor lev = c->pyr[level] * (float)noise_scale; // [Tl,H,W]
int H = lev.size(1), W = lev.size(2);
int N = c->N;
int vdim = 2 * c->vr + 1, nvel = vdim * vdim;
// newest-first window stack [count,N,H,W] (channel 0 = newest), matching the Java order.
std::vector<Tensor> wl;
wl.reserve(count);
for (int j = 0; j < count; j++) {
int newest = start + j * stride;
wl.push_back(lev.slice(0, newest - N + 1, newest + 1).flip(0)); // [N,H,W]
}
Tensor wins = torch::stack(wl, 0); // [count,N,H,W]
Tensor field = shift_stitch(c->l1, wins, c->P, c->out_ch); // [count,C,H,W]
auto dec = decode(field, c->vr, rx, ry, rw, rh, rmax); // ghostbusted
Tensor o5 = dec.first, rf = dec.second;
int nch = 5;
if (use_l2) {
Tensor ong = decode(field, c->vr, rx, ry, rw, rh, 0.0).first; // no ghostbuster (L2 gets full field)
Tensor l2in = torch::stack({ong.select(1, 2),
ong.select(1, 3) / VEL_DECIMATE,
ong.select(1, 4) / VEL_DECIMATE}, 1); // [count,3,H,W]
l2in = torch::nan_to_num(l2in, 0.0, 0.0, 0.0);
int Hf = l2in.size(2), Wf = l2in.size(3);
if (!c->l2_state_init || c->h_l2.size(2) != Hf || c->h_l2.size(3) != Wf || l2_reset) {
auto o = field.options();
c->h_l2 = torch::zeros({1, c->l2_ch, Hf, Wf}, o);
c->age_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->sprev_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->l2_state_init = true;
}
std::vector<Tensor> dets, vxs, vys, ages;
dets.reserve(count); vxs.reserve(count); vys.reserve(count); ages.reserve(count);
for (int j = 0; j < count; j++) {
auto out = c->l2.forward({l2in.slice(0, j, j + 1), c->h_l2}).toTuple();
c->h_l2 = out->elements()[0].toTensor(); // h_new
Tensor dlog = out->elements()[1].toTensor(); // [1,1,H,W] raw det logit
Tensor vel = out->elements()[2].toTensor(); // [1,2,H,W] bounded px/level-frame
Tensor s = torch::sigmoid(dlog.slice(1, 0, 1)); // [1,1,H,W]
// AGE (track-before-detect persistence) — exactly the server's 5x5 ancestor-gated update.
Tensor maxS = torch::max_pool2d(c->sprev_l2, 5, 1, 2);
Tensor elig = (c->sprev_l2 >= AGE_K * maxS).logical_and(c->sprev_l2 > AGE_THR);
Tensor prev = torch::where(elig, c->age_l2, torch::zeros_like(c->age_l2));
c->age_l2 = torch::where(s > AGE_THR,
torch::max_pool2d(prev, 5, 1, 2) + 1.0f,
torch::zeros_like(c->age_l2));
c->sprev_l2 = s;
dets.push_back(s.select(1, 0)); // [1,H,W]
ages.push_back(c->age_l2.select(1, 0));
vxs.push_back(vel.select(1, 0));
vys.push_back(vel.select(1, 1));
}
Tensor l2vx = torch::cat(vxs, 0) * (float)VEL_DECIMATE; // [count,H,W]
Tensor l2vy = torch::cat(vys, 0) * (float)VEL_DECIMATE;
o5 = torch::stack({o5.select(1, 0), o5.select(1, 1), // keep L1 dx,dy
torch::cat(dets, 0), l2vx, l2vy, torch::cat(ages, 0)}, 1); // [count,6,H,W]
nch = 6;
}
Tensor o5c = o5.contiguous().to(torch::kCPU); // [count,nch,H,W]
Tensor rfc = rf.contiguous().to(torch::kCPU); // [count,rh,rw,nvel]
std::memcpy(out_o5, o5c.data_ptr<float>(), (size_t)count * nch * H * W * sizeof(float));
std::memcpy(out_roi, rfc.data_ptr<float>(), (size_t)count * rh * rw * nvel * sizeof(float));
return nch;
}
void tpdnn_free(void* ctx) { delete (Ctx*)ctx; }
} // extern "C"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment