Commit 7a351e9e authored by Andrey Filippov

CLAUDE: export L2 (Layer2Net step) to TorchScript for native LibTorch

Piece 1 of the native-JNA DNN path (no Python server). Adds export_l2_torchscript.py:
wraps a trained Layer2Net's cell+head into a single-step L2Step module
forward(x,h)->(h_new,det,vel) — exactly infer_server's per-scene recurrence
(h=cell(x,h); det,vel=decode(h)) — so the C++ side just carries h and calls it
per scene. Size-agnostic (circular pad + 1x1 head), runs on the full field.

Validated: scripted==eager exact (0.0); C++ LibTorch (libtorch_probe/l2_probe)
loads it on Blackwell CUDA and replays the recurrence with hidden-state match
9.5e-7. Required disabling the TorchScript JIT fuser (nvrtc element-wise fusion
fails on Blackwell -arch; production wants no runtime nvrtc) — folds into the
native lib startup in piece 2.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
parent cd7b0801
# export_l2_torchscript.py - part of imagej_elphel_dnn — Elphel PyTorch DNN companions to imagej-elphel
#
# Copyright (C) 2026 Elphel, Inc.
#
# -----------------------------------------------------------------------------
# imagej_elphel_dnn is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------
"""Export a trained Layer-2 (Layer2Net ConvGRU-on-torus) checkpoint to self-contained TorchScript
for native LibTorch inference (no Python at runtime).
Production L2 (infer_server.py) does NOT call Layer2Net.forward over a fixed 32x32 torus; it runs
the cell convolutionally over the FULL field and carries the hidden state across scenes:
    h = m2.cell(l2in[j], h)   # [1,Ch,H,W] recurrent update, one scene at a time
    det, vel = m2.decode(h)   # [1,1,H,W], [1,2,H,W]
So we export a single-step wrapper, L2Step, whose forward(x, h) -> (h_new, det, vel) is exactly that
one recurrent step. The C++ side then just keeps `h` and calls forward() per scene — the recurrence
loop (and the age/track post-processing) lives in C++, matching the server's Python loop.
Build-once on a dev box with the torch version that MATCHES the C++ libtorch (2.7.1 -> libtorch
2.7.1+cu128). The .ts.pt is device-agnostic; C++ loads it onto CUDA.
Usage: python export_l2_torchscript.py runs/mexhat_gaps_boost40/model.pt [out.ts.pt]
Also writes <out>.test_*.bin reference vectors (a T-step recurrence) for the C++ probe to verify
the native output against PyTorch. By Claude on 06/26/2026.
"""
import sys, struct, torch
import torch.nn as nn
from layer2 import Layer2Net
class L2Step(nn.Module):
    """Single recurrent Layer-2 step, packaged so it can be exported to TorchScript.

    Reproduces infer_server's per-scene update ``h = cell(x, h); det, vel = decode(h)``:
    ``forward(x, h)`` returns ``(h_new, det, vel)``. The cell uses circular padding and the
    head is a 1x1 conv, so every conv is size-agnostic and the same module runs on any H, W
    (the full field at inference). By Claude on 06/26/2026.
    """

    def __init__(self, net):
        """Borrow the trained cell and head from ``net`` and record its velocity bound."""
        super().__init__()
        # ConvGRUCellTorus: circular padding, fully convolutional -> works for any H, W.
        self.cell = net.cell
        # 1x1 conv readout: hidden -> det (1 channel) + raw Vx, Vy (2 channels).
        self.head = net.head
        # Velocity readout bound in px/level-frame, kept as a plain Python float.
        self.vmax = float(net.vmax)

    def forward(self, x, h):
        """Advance the hidden state by one scene and decode it.

        x: [1,3,H,W] input (s, Vx/vd, Vy/vd); h: [1,Ch,H,W] previous hidden state.
        Returns (h_new [1,Ch,H,W], det [1,1,H,W], vel [1,2,H,W]).
        """
        h_new = self.cell(x, h)  # recurrent update
        readout = self.head(h_new)  # [1,3,H,W]
        # Channel 0 is the raw detection logit; sigmoid is applied downstream, like the server.
        det = readout[:, 0:1]
        # Channels 1..2 are raw velocities, squashed to (-vmax, vmax) px/level-frame.
        vel = torch.tanh(readout[:, 1:3]) * self.vmax
        return h_new, det, vel
def _w(path, t):
    """Dump tensor ``t`` to ``path`` as headerless little-endian float32 values.

    The tensor is detached, made contiguous and flattened, so elements are written in
    row-major order of its logical shape. No shape information is stored in the file.
    """
    values = t.detach().contiguous().flatten().tolist()
    packer = struct.Struct(f"<{len(values)}f")  # '<' = little-endian, 'f' = 4-byte float
    with open(path, "wb") as sink:
        sink.write(packer.pack(*values))
def main():
    """Export a trained Layer2Net checkpoint as a single-step TorchScript module.

    Command line: ``python export_l2_torchscript.py <model.pt> [out.ts.pt]``. When the output
    path is omitted it is derived from the checkpoint path (a trailing ``.pt`` is dropped and
    ``.l2.ts.pt`` appended).

    Steps:
      1. Load the checkpoint on CPU and rebuild Layer2Net from its saved ``args``
         (defaults: ch=24, vmax=1.4, G=32).
      2. Run a T-step eager recurrence of L2Step on a small random field as the reference.
      3. Script and save the step module, reload it, replay the recurrence and print the
         maximum absolute difference against the eager reference.
      4. Write ``<out>.test_{meta,seq,det,vel,h}.bin`` reference vectors for the C++ probe.

    Raises:
        SystemExit: with a usage message when no checkpoint path is given.
    """
    if len(sys.argv) < 2:
        # Fail with a readable usage line instead of an IndexError traceback.
        sys.exit("Usage: python export_l2_torchscript.py <model.pt> [out.ts.pt]")
    ckpt = sys.argv[1]
    if len(sys.argv) > 2:
        out = sys.argv[2]
    else:
        out = (ckpt[:-3] if ckpt.endswith(".pt") else ckpt) + ".l2.ts.pt"

    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so only load
    # checkpoints that come from a trusted source.
    ck = torch.load(ckpt, map_location="cpu", weights_only=False)
    a = ck.get("args", {}) or {}  # a missing or None "args" entry falls back to the defaults
    ch_hidden = a.get("ch", 24)
    vmax = a.get("vmax", 1.4)
    grid = a.get("G", 32)
    net = Layer2Net(ch_in=3, ch_hidden=ch_hidden, grid=grid, vmax=vmax)
    net.load_state_dict(ck["model"])
    net.eval()
    print(f"L2 config: ch_hidden={ch_hidden} vmax={vmax} grid={grid}")
    step = L2Step(net).eval()

    # Reference: a T-step recurrence on a small field (exercises carry-across + cell + head + tanh).
    torch.manual_seed(0)  # fixed seed -> reproducible reference vectors
    H = W = 16
    T = 4
    seq = torch.randn(T, 1, 3, H, W)  # [T,1,3,H,W] fake L1->L2 inputs (s, Vx/vd, Vy/vd)
    with torch.no_grad():
        h = torch.zeros(1, ch_hidden, H, W)
        dets, vels = [], []
        for t in range(T):
            h, det, vel = step(seq[t], h)
            dets.append(det)
            vels.append(vel)
    det_eager = torch.cat(dets, 0)  # [T,1,H,W]
    vel_eager = torch.cat(vels, 0)  # [T,2,H,W]
    h_eager = h                     # [1,Ch,H,W] final hidden

    ts = torch.jit.script(step)
    ts.save(out)

    # Reload from disk so the check covers the serialized artifact, not the in-memory script.
    ts2 = torch.jit.load(out, map_location="cpu")
    ts2.eval()
    with torch.no_grad():
        h = torch.zeros(1, ch_hidden, H, W)
        dd = 0.0  # running max |scripted - eager| over all per-step outputs
        for t in range(T):
            h, det, vel = ts2(seq[t], h)
            dd = max(dd, (det - det_eager[t:t+1]).abs().max().item(),
                     (vel - vel_eager[t:t+1]).abs().max().item())
        # The hidden state is compared only once, after the last step.
        dd = max(dd, (h - h_eager).abs().max().item())
    print(f"saved {out}; {T}-step recurrence reload max|diff| vs eager = {dd}")

    # Reference vectors for the C++ probe: header (T,H,W,Ch) + seq + per-step det/vel + final h.
    with open(out + ".test_meta.bin", "wb") as fh:
        fh.write(struct.pack("<4i", T, H, W, ch_hidden))  # four little-endian int32
    _w(out + ".test_seq.bin", seq)        # [T,1,3,H,W]
    _w(out + ".test_det.bin", det_eager)  # [T,1,H,W]
    _w(out + ".test_vel.bin", vel_eager)  # [T,2,H,W]
    _w(out + ".test_h.bin", h_eager)      # [1,Ch,H,W]
    print(f"wrote reference vectors: seq{tuple(seq.shape)} det{tuple(det_eager.shape)} "
          f"vel{tuple(vel_eager.shape)} h{tuple(h_eager.shape)}")


# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment