PR 4 of #508 — InferenceLayer ABC + SingleInstanceLayer (#512)#534
Merged
gitttt-1234 merged 2 commits intoMay 28, 2026
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## divya/inf-refactor-03-torch-backend #534 +/- ##
=======================================================================
- Coverage 63.65% 61.66% -1.99%
=======================================================================
Files 109 129 +20
Lines 17918 19607 +1689
=======================================================================
+ Hits 11406 12091 +685
- Misses 6512 7516 +1004 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
1427454 to
90758f0
Compare
aae919a to
5ba3150
Compare
…#508) Lock down the InferenceLayer abstraction with the proof-of-pattern single-instance subclass — and prove parity end-to-end against the PR 0 golden. Every subsequent layer (PR 6) and the Predictor orchestrator (PR 8) follows this template. Layout: sleap_nn/inference/layers/ configs.py attrs.frozen PreprocessConfig + PostprocessConfig base.py InferenceLayer ABC: preprocess / postprocess / predict / __call__ / warmup. Includes the _to_4d_float_tensor helper that accepts (H,W), (H,W,C), (C,H,W), (B,H,W,C), (B,C,H,W) numpy or torch — gives every subclass a uniform input contract for the new direct-numpy API. single_instance.py SingleInstanceLayer (concrete). Decodes confmaps via ops.peaks.find_global_peaks, applies the full coord ladder via ops.coord, returns Outputs. End-to-end parity proven by ``test_single_instance_layer_parity_vs_pr0_golden``: the new layer's pred_keypoints / pred_peak_values match the PR 0 golden within 1e-5 atol/rtol on the captured fixed input. This test is the linchpin of the entire refactor — every PR 5–14 will keep gating on it. Tests (`tests/inference/layers/`, 22 cases): * InferenceLayer ABC enforcement: cannot instantiate directly; rejects non-ModelBackend; predict / __call__ agreement. * _to_4d_float_tensor: 5 input shapes parametrized + numpy/torch + rejection of unsupported types and ranks. * SingleInstanceLayer: - parity vs PR 0 golden (within 1e-5) - direct numpy API - direct torch API - 2D grayscale + 4D batched inputs - return_confmaps off-by-default + opt-in populates Outputs.pred_confmaps - synthetic-input coord ladder verification (output_stride scaling) Full inference suite: 240 passed, 8 skipped, 2 xfailed. Parity regen (RUN_GOLDEN_REGEN_CHECK=1): 30/30 specs match byte-for-byte. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
90758f0 to
00f2d64
Compare
5ba3150 to
ab68fca
Compare
…513) (#535) ## Summary Closes #513. Three ops in `sleap_nn/inference/ops` get rewritten to remove constructs the legacy TorchScript ONNX exporter rejected, while preserving prior behavior bit-exactly (or within 1 ULP). ## Op rewrites - **`morphological_dilation`** — replace `Tensor.unfold` with an explicit 8-shift max stack. Drops the `xfail(strict=True)` marker on its ONNX-export smoke test from PR 1; that test now passes. `find_local_peaks_rough` gains exportability as a side effect (was previously documented as a \"known gap\" expecting a raise; now it succeeds). - **`find_global_peaks_rough`** — replace boolean-mask `index_put_` in-place assignment with `torch.where`. Also splits a single `squeeze(dim=(2, 3))` into two single-dim `squeeze` calls (the legacy exporter does not lower the multi-dim form). PR 1's `xfail(strict=True)` marker dropped; smoke test now passes. - **`crop_bboxes`** — replace per-peak `unfold` + advanced-indexing-on-unfolded-view with direct advanced indexing on a zero-padded image. Bbox top-lefts are floored before extraction (matching the prior `.to(torch.long)` truncation), so integer-aligned bboxes produce bit-exact crops; sub-pixel bboxes (centroid-driven top-down stage 2) reproduce the old \"snap to integer pixel\" behavior exactly. Avoids the bilinear-interp drift that `F.grid_sample` would introduce while still removing `unfold`. ## Tests updated - Drop the two PR 1 `xfail(strict=True)` markers (auto-flipped to passing) - Invert `find_local_peaks_rough_known_export_gap` — the function now exports for fixed-shape examples; test asserts that and verifies output parity. Variable-peak-count output remains a runtime-shape constraint that PR 7 (#515) addresses via `find_top_k_peaks` - Bump `test_golden_is_reproducible` float tolerance from strict zero to `atol=1e-5`, `rtol=1e-6` to absorb 1-ULP drift from the `torch.where` rewrite. Two orders of magnitude tighter than the design-doc budget (1e-4 / 1e-5); integer fields still compared exactly ## CUDA test suite Bonus: `tests/inference/test_cuda.py` — 12 module-level-skipif-gated tests covering pure ops, `Outputs` device transfer, `TorchBackend(device='cuda')`, `SingleInstanceLayer` cross-device parity, FP16 drift budget, pin_memory transfer correctness. Skip cleanly on non-CUDA hosts; on a CUDA box run with: ```bash pytest tests/inference/test_cuda.py -v ``` ## Test plan - [x] `RUN_GOLDEN_REGEN_CHECK=1` — 30/30 specs reproduce within 1e-5 of PR 0 goldens (subprocess-isolated capture) - [x] Full inference suite: 242 passed, 8 skipped (CUDA-only) + 12 skipped (CUDA suite), 0 failed, 0 xfailed - [x] PR 4's `SingleInstanceLayer` parity test still passes at 1e-5 atol/rtol - [x] Both ONNX-export `xfail` markers from PR 1 dropped — corresponding tests now pass cleanly 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Closes #512. Lock down the
InferenceLayerabstraction with the proof-of-pattern single-instance subclass — and prove parity end-to-end against the PR 0 golden. Every subsequent layer (PR 6) and the Predictor orchestrator (PR 8) follows this template.What lands
sleap_nn/inference/layers/configs.py—attrs.frozenPreprocessConfig+PostprocessConfigsleap_nn/inference/layers/base.py—InferenceLayerABC: abstractpreprocess/postprocess, concretepredict/__call__/warmup. Includes_to_4d_float_tensorhelper that accepts(H,W),(H,W,C),(C,H,W),(B,H,W,C),(B,C,H,W)numpy or torch — gives every subclass a uniform input contract for the new direct-numpy APIsleap_nn/inference/layers/single_instance.py—SingleInstanceLayer. Decodes confmaps viaops.peaks.find_global_peaks, applies the full coord ladder viaops.coord, returnsOutputsHeadline result
End-to-end parity proven by
test_single_instance_layer_parity_vs_pr0_golden:The new layer stack —
np.ndarrayin,Outputsdataclass out — matches the captured-from-old-code golden within1e-5atol/rtol on the same fixed input. Linchpin of the refactor: as long as this passes, every subsequent PR (5–14) is verifiably parity-preserving on the single-instance path.Test plan
tests/inference/layers/— 22 cases:ModelBackend)_to_4d_float_tensor: 5 input shapes parametrized + numpy/torch + rejection of unsupported types/ranksreturn_confmapsoff-by-default + opt-in populatesOutputs.pred_confmapsSingleInstanceLayer.from_checkpoint(...)is deferred to PR 8 (#516) where checkpoint-load logic consolidates. For now the constructor takes an already-builtModelBackend; tests build via the existingPredictor.from_model_pathsand pull modules out ofinference_model.Full inference suite: 240 passed, 8 skipped, 2 xfailed.
Parity regen: 30/30 specs match.
🤖 Generated with Claude Code