Commit 6c9c2486 authored by Andrey Filippov's avatar Andrey Filippov

CLAUDE: tp_dnn GPU_CHUNK sub-batching - fix full-res OOM

tpdnn_infer processed the whole `count`-scene request in one shot; at full res
(512x640) a 64-scene field tensor is ~10GB (64x124x512x640) + decode p ~10GB ->
CUDA OOM on a 16GB card. The 64x80 synthetic smoke didn't expose it.

Now loops in GPU_CHUNK-sized sub-batches (== infer_server's GPU_CHUNK), writing each
chunk straight to the host output buffers so GPU tensors stay bounded (~3-4GB/chunk
at CHUNK=8). L2 hidden/age/sprev carry across chunks; reset only at the first chunk
when l2_reset (matches the server). Env TPDNN_GPU_CHUNK (default 8).

Validated: parity vs server oracle still EXACT (0.0) at CHUNK=8 (8+4) and CHUNK=4
(4+4+4) - L2 carry across chunk boundaries is correct.
Co-Authored-By: 's avatarClaude Opus 4.8 (1M context) <noreply@anthropic.com>
parent ea9c117a
......@@ -15,6 +15,8 @@
#include <vector>
#include <string>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using torch::Tensor;
......@@ -145,65 +147,74 @@ int tpdnn_infer(void* ctx, int level, int start, int count, int stride,
int N = c->N;
int vdim = 2 * c->vr + 1, nvel = vdim * vdim;
// newest-first window stack [count,N,H,W] (channel 0 = newest), matching the Java order.
std::vector<Tensor> wl;
wl.reserve(count);
for (int j = 0; j < count; j++) {
int newest = start + j * stride;
wl.push_back(lev.slice(0, newest - N + 1, newest + 1).flip(0)); // [N,H,W]
}
Tensor wins = torch::stack(wl, 0); // [count,N,H,W]
Tensor field = shift_stitch(c->l1, wins, c->P, c->out_ch); // [count,C,H,W]
auto dec = decode(field, c->vr, rx, ry, rw, rh, rmax); // ghostbusted
Tensor o5 = dec.first, rf = dec.second;
int nch = 5;
if (use_l2) {
Tensor ong = decode(field, c->vr, rx, ry, rw, rh, 0.0).first; // no ghostbuster (L2 gets full field)
Tensor l2in = torch::stack({ong.select(1, 2),
ong.select(1, 3) / VEL_DECIMATE,
ong.select(1, 4) / VEL_DECIMATE}, 1); // [count,3,H,W]
l2in = torch::nan_to_num(l2in, 0.0, 0.0, 0.0);
int Hf = l2in.size(2), Wf = l2in.size(3);
if (!c->l2_state_init || c->h_l2.size(2) != Hf || c->h_l2.size(3) != Wf || l2_reset) {
auto o = field.options();
c->h_l2 = torch::zeros({1, c->l2_ch, Hf, Wf}, o);
c->age_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->sprev_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->l2_state_init = true;
// GPU_CHUNK sub-batching (== server's GPU_CHUNK loop): process the `count` scenes in chunks so the
// big GPU tensors (field [b,C,H,W], decode p [b,nvel,H,W]) stay bounded -- at full res a 64-scene
// single shot is ~10GB and OOMs a 16GB card. L2 hidden/age/sprev carry across chunks; reset only at
// the FIRST chunk when l2_reset (matches infer_server). Env TPDNN_GPU_CHUNK (default 8). By Claude 06/27/2026.
static int CHUNK = -1;
if (CHUNK < 0) { const char* e = std::getenv("TPDNN_GPU_CHUNK"); CHUNK = e ? std::atoi(e) : 8; if (CHUNK < 1) CHUNK = 8; }
int nch = use_l2 ? 6 : 5;
for (int c0 = 0; c0 < count; c0 += CHUNK) {
int b = std::min(CHUNK, count - c0);
// newest-first window stack [b,N,H,W] (channel 0 = newest), matching the Java order.
std::vector<Tensor> wl;
wl.reserve(b);
for (int j = 0; j < b; j++) {
int newest = start + (c0 + j) * stride;
wl.push_back(lev.slice(0, newest - N + 1, newest + 1).flip(0)); // [N,H,W]
}
std::vector<Tensor> dets, vxs, vys, ages;
dets.reserve(count); vxs.reserve(count); vys.reserve(count); ages.reserve(count);
for (int j = 0; j < count; j++) {
auto out = c->l2.forward({l2in.slice(0, j, j + 1), c->h_l2}).toTuple();
c->h_l2 = out->elements()[0].toTensor(); // h_new
Tensor dlog = out->elements()[1].toTensor(); // [1,1,H,W] raw det logit
Tensor vel = out->elements()[2].toTensor(); // [1,2,H,W] bounded px/level-frame
Tensor s = torch::sigmoid(dlog.slice(1, 0, 1)); // [1,1,H,W]
// AGE (track-before-detect persistence) — exactly the server's 5x5 ancestor-gated update.
Tensor maxS = torch::max_pool2d(c->sprev_l2, 5, 1, 2);
Tensor elig = (c->sprev_l2 >= AGE_K * maxS).logical_and(c->sprev_l2 > AGE_THR);
Tensor prev = torch::where(elig, c->age_l2, torch::zeros_like(c->age_l2));
c->age_l2 = torch::where(s > AGE_THR,
torch::max_pool2d(prev, 5, 1, 2) + 1.0f,
torch::zeros_like(c->age_l2));
c->sprev_l2 = s;
dets.push_back(s.select(1, 0)); // [1,H,W]
ages.push_back(c->age_l2.select(1, 0));
vxs.push_back(vel.select(1, 0));
vys.push_back(vel.select(1, 1));
Tensor wins = torch::stack(wl, 0); // [b,N,H,W]
Tensor field = shift_stitch(c->l1, wins, c->P, c->out_ch); // [b,C,H,W]
auto dec = decode(field, c->vr, rx, ry, rw, rh, rmax); // ghostbusted
Tensor o5 = dec.first, rf = dec.second; // [b,5,H,W], [b,rh,rw,nvel]
if (use_l2) {
Tensor ong = decode(field, c->vr, rx, ry, rw, rh, 0.0).first; // no ghostbuster (L2 gets full field)
Tensor l2in = torch::stack({ong.select(1, 2),
ong.select(1, 3) / VEL_DECIMATE,
ong.select(1, 4) / VEL_DECIMATE}, 1); // [b,3,H,W]
l2in = torch::nan_to_num(l2in, 0.0, 0.0, 0.0);
int Hf = l2in.size(2), Wf = l2in.size(3);
if (!c->l2_state_init || c->h_l2.size(2) != Hf || c->h_l2.size(3) != Wf || (l2_reset && c0 == 0)) {
auto o = field.options();
c->h_l2 = torch::zeros({1, c->l2_ch, Hf, Wf}, o);
c->age_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->sprev_l2 = torch::zeros({1, 1, Hf, Wf}, o);
c->l2_state_init = true;
}
std::vector<Tensor> dets, vxs, vys, ages;
dets.reserve(b); vxs.reserve(b); vys.reserve(b); ages.reserve(b);
for (int j = 0; j < b; j++) {
auto out = c->l2.forward({l2in.slice(0, j, j + 1), c->h_l2}).toTuple();
c->h_l2 = out->elements()[0].toTensor(); // h_new
Tensor dlog = out->elements()[1].toTensor(); // [1,1,H,W] raw det logit
Tensor vel = out->elements()[2].toTensor(); // [1,2,H,W] bounded px/level-frame
Tensor s = torch::sigmoid(dlog.slice(1, 0, 1)); // [1,1,H,W]
// AGE (track-before-detect persistence) — exactly the server's 5x5 ancestor-gated update.
Tensor maxS = torch::max_pool2d(c->sprev_l2, 5, 1, 2);
Tensor elig = (c->sprev_l2 >= AGE_K * maxS).logical_and(c->sprev_l2 > AGE_THR);
Tensor prev = torch::where(elig, c->age_l2, torch::zeros_like(c->age_l2));
c->age_l2 = torch::where(s > AGE_THR,
torch::max_pool2d(prev, 5, 1, 2) + 1.0f,
torch::zeros_like(c->age_l2));
c->sprev_l2 = s;
dets.push_back(s.select(1, 0)); // [1,H,W]
ages.push_back(c->age_l2.select(1, 0));
vxs.push_back(vel.select(1, 0));
vys.push_back(vel.select(1, 1));
}
Tensor l2vx = torch::cat(vxs, 0) * (float)VEL_DECIMATE; // [b,H,W]
Tensor l2vy = torch::cat(vys, 0) * (float)VEL_DECIMATE;
o5 = torch::stack({o5.select(1, 0), o5.select(1, 1), // keep L1 dx,dy
torch::cat(dets, 0), l2vx, l2vy, torch::cat(ages, 0)}, 1); // [b,6,H,W]
}
Tensor l2vx = torch::cat(vxs, 0) * (float)VEL_DECIMATE; // [count,H,W]
Tensor l2vy = torch::cat(vys, 0) * (float)VEL_DECIMATE;
o5 = torch::stack({o5.select(1, 0), o5.select(1, 1), // keep L1 dx,dy
torch::cat(dets, 0), l2vx, l2vy, torch::cat(ages, 0)}, 1); // [count,6,H,W]
nch = 6;
}
Tensor o5c = o5.contiguous().to(torch::kCPU); // [count,nch,H,W]
Tensor rfc = rf.contiguous().to(torch::kCPU); // [count,rh,rw,nvel]
std::memcpy(out_o5, o5c.data_ptr<float>(), (size_t)count * nch * H * W * sizeof(float));
std::memcpy(out_roi, rfc.data_ptr<float>(), (size_t)count * rh * rw * nvel * sizeof(float));
Tensor o5c = o5.contiguous().to(torch::kCPU); // [b,nch,H,W]
Tensor rfc = rf.contiguous().to(torch::kCPU); // [b,rh,rw,nvel]
std::memcpy(out_o5 + (size_t)c0 * nch * H * W, o5c.data_ptr<float>(), (size_t)b * nch * H * W * sizeof(float));
std::memcpy(out_roi + (size_t)c0 * rh * rw * nvel, rfc.data_ptr<float>(), (size_t)b * rh * rw * nvel * sizeof(float));
}
return nch;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment