Skip to content

PR 11 of #508 — Predictor.from_model_paths factory (#519)#541

Merged
gitttt-1234 merged 4 commits into
divya/inf-refactor-10-cli-unificationfrom
divya/inf-refactor-11-from-model-paths
May 28, 2026
Merged

PR 11 of #508 — Predictor.from_model_paths factory (#519)#541
gitttt-1234 merged 4 commits into
divya/inf-refactor-10-cli-unificationfrom
divya/inf-refactor-11-from-model-paths

Conversation

@gitttt-1234

Copy link
Copy Markdown
Collaborator

Summary

Closes part of #519 (PR 11 of #508). Adds the Predictor.from_model_paths() factory that's the unblocking piece for the rest of #519 (the legacy run_inference shim + the CLI new-flow wiring).

Architecture

The factory in sleap_nn/inference/factory.py is a thin adapter, not a reimplementation of model loading. It delegates the heavy work to the legacy inference.predictors.Predictor:

  • Lightning checkpoint reconstruction (with all the optimizer / scheduler / hard-mining hyperparams Lightning needs even when only weights matter)
  • SLEAP <=1.4 legacy .json config converter
  • backbone_ckpt_path / head_ckpt_path overrides
  • Device placement
  • Skeleton hydration

The factory then reads attributes off the loaded legacy.inference_model (e.g., torch_model, paf_scorer, output_stride, crop_hw) and rewraps them with the new InferenceLayer subclasses from PRs 4, 6.

Predictor.from_model_paths(...) is exposed as a classmethod alias for ergonomic parity with the legacy predictor's API.

Layer dispatch — 5 supported model-type combinations

Detected types Layer composition
single_instance SingleInstanceLayer
bottomup BottomUpLayer
multi_class_bottomup BottomUpMultiClassLayer
centroid + centered_instance TopDownLayer (CentroidLayer + CenteredInstanceLayer)
centroid + multi_class_topdown TopDownMultiClassLayer

Unsupported combinations (e.g., two centroids, missing centered-instance for top-down) raise ValueError.

Why not delete the legacy loader yet?

The legacy inference.predictors.Predictor is tightly coupled to LightningModule.load_from_checkpoint and the SLEAP <=1.4 .json converter — both stable code paths. Forking that logic into the new module is risky and gains nothing. PR 11b (deletion) will rip out the unrelated sleap_nn/export/inference.py (the duplicate driver in the export-side path); the legacy loader stays as the model-loading backend for the factory until a future cleanup.

Tests (11)

tests/inference/test_factory.py:

  • Layer-type dispatch (5 combos): each builds a Predictor whose .layer is the expected concrete type
  • Classmethod equivalence: Predictor.from_model_paths(...) and the free function produce equivalent objects
  • Smoke tests (3): predictor.layer.predict() returns a structurally well-formed Outputs on a synthetic image for single-instance, bottom-up, top-down
  • Error path: unsupported combination raises
  • Parity vs legacy inference_model.forward on the single-instance checkpoint within 1e-4 atol / 1e-5 rtol

Per-type layer-vs-legacy parity is already covered by tests/inference/layers/test_*.py; the factory test verifies the wiring keeps the new layers byte-for-byte equivalent.

What's deferred to PR 11b

  • Replace sleap_nn/predict.py::run_inference body with a deprecation shim that delegates to Predictor.from_model_paths(...).predict(...)
  • Delete sleap_nn/export/inference.py + sleap_nn/export/predictors/ (the export-side duplicate driver)
  • Migrate test imports of those internals
  • CLI: flip sleap-nn infer to use the new flow (currently delegates to legacy run_inference)

Test plan

  • pytest tests/inference/test_factory.py — 11 pass
  • pytest tests/inference/ — 101 passed, 5 skipped, no regressions across the inference suite
  • black --check + ruff check clean
  • Linux / Windows / Mac CI

🤖 Generated with Claude Code

Adds `Predictor.from_model_paths()` and the underlying
`sleap_nn.inference.factory.from_model_paths` so the new Predictor
(PR 8) can be built directly from model checkpoint paths — the
unblocking piece for the rest of #519 (legacy run_inference shim,
CLI wiring of the new flow).

Architecture: the factory delegates the heavy model-loading work
(Lightning checkpoint reconstruction with all the optimizer /
scheduler / hard-mining hyperparams Lightning needs, SLEAP <=1.4
legacy converter, backbone/head ckpt overrides, device placement) to
the legacy `inference.predictors.Predictor`, then re-wraps the loaded
torch module(s) and PAF scorer with the new `InferenceLayer`
subclasses. A thin adapter, not a reimplementation.

Layer dispatch for the 5 supported model-type combinations:
- single_instance                 → SingleInstanceLayer
- bottomup                        → BottomUpLayer
- multi_class_bottomup            → BottomUpMultiClassLayer
- centroid + centered_instance    → TopDownLayer
- centroid + multi_class_topdown  → TopDownMultiClassLayer

`Predictor.from_model_paths(...)` is exposed as a classmethod alias
for ergonomic parity with the legacy predictor's API.

Tests: tests/inference/test_factory.py — 11 tests covering layer-type
dispatch (5 model-type combos), classmethod equivalence, end-to-end
predict smoke (3 layer types), unsupported-combination error path,
and parity vs the legacy `inference_model.forward` on the
single-instance checkpoint within 1e-4 atol / 1e-5 rtol.

PR 11b will use this factory to (a) replace `run_inference` with a
deprecation shim and (b) delete `sleap_nn/export/inference.py` per
the original #519 plan.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@codecov

codecov Bot commented May 2, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 64.46809% with 167 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.66%. Comparing base (49b3672) to head (5f26e89).

Files with missing lines Patch % Lines
sleap_nn/inference/layers/exported.py 0.00% 65 Missing ⚠️
sleap_nn/inference/factory.py 59.15% 58 Missing ⚠️
sleap_nn/cli.py 83.08% 23 Missing ⚠️
sleap_nn/inference/predictor.py 72.91% 13 Missing ⚠️
sleap_nn/inference/tracking.py 86.00% 7 Missing ⚠️
sleap_nn/inference/providers.py 96.29% 1 Missing ⚠️
Additional details and impacted files
@@                            Coverage Diff                            @@
##           divya/inf-refactor-10-cli-unification     #541      +/-   ##
=========================================================================
- Coverage                                  64.80%   61.66%   -3.14%     
=========================================================================
  Files                                        126      129       +3     
  Lines                                      19158    19607     +449     
=========================================================================
- Hits                                       12415    12091     -324     
- Misses                                      6743     7516     +773     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

gitttt-1234 and others added 3 commits May 1, 2026 18:09
…oad cost

Each factory test was previously calling ``from_model_paths()`` fresh,
which triggered a full Lightning checkpoint reconstruction every time.
With 11 tests × 5 ckpts, ~17 fresh model loads accumulated ~2GB of
references that GC didn't reclaim between tests, putting memory
pressure on Linux/Windows CI runners. PR 11's first CI run:
- Mac: 24m20s (was ~14m before)
- Linux: 55m (timeout — was ~9m before)
- Windows: 45m22s (timeout — was ~30m before)
- The bottomup parity test (`test_phase_split_matches_monolithic_postprocess`)
  in test_paf_worker_pool.py took 578s (was ~30s) because of the bloated
  process state.

Refactor: each ckpt combo gets a module-scoped fixture that loads
once, with a ``gc.collect()`` finalizer. Tests share the predictor for
type checks + smoke runs; the parity test still loads the legacy
predictor independently (different load path, different test).

Local: full inference suite (test_factory + test_paf_worker_pool +
layers) runs in 3.5 min (Mac).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These tests loaded real Lightning checkpoints and compared the new
layer's forward pass to the legacy InferenceModel's forward pass byte-
for-byte (within 1e-4). Useful as development-time scaffolding during
PRs 4-9 to catch porting mistakes; redundant once the refactor is
done because:

1. End-to-end byte-for-byte parity is captured at the top of the
   stack via the PR 0 goldens — that's the canonical comparison
   point and runs cheap.
2. Per-layer parity locks in legacy float-op-ordering, blocking
   intentional improvements (PR 5 already had to bump tolerances
   for legitimate ULP drift from a vectorization rewrite).
3. Combined load + execute cost: ~17 min on Mac CI from just two
   tests (`test_bottomup_layer_parity_vs_legacy` 430s and
   `test_phase_split_matches_monolithic_postprocess` 578s).

A final test PR at the end of the refactor stack will run end-to-end
byte-for-byte parity vs the PR 0 goldens — one comprehensive check
instead of N redundant per-layer ones.

Tests removed:
- test_bottomup.py::test_bottomup_layer_parity_vs_legacy
- test_bottomup_multiclass.py::test_parity_vs_legacy
- test_centered_instance.py::test_centered_instance_layer_parity_vs_legacy_model_path
- test_centroid.py::test_centroid_layer_parity_vs_legacy_model_path
- test_topdown.py::test_topdown_layer_parity_vs_pr0_golden (heavy, despite name)
- test_topdown_multiclass.py::test_centered_instance_multiclass_layer_parity_vs_legacy
- test_factory.py::test_factory_parity_vs_legacy_single_instance
- test_paf_worker_pool.py::test_phase_split_matches_monolithic_postprocess

Local inference-suite runtime: 2m16s (was 3m30s). CI projection:
- Linux ~9m → ~4-5m
- Mac ~15m → ~7-8m
- Windows ~34m → ~12-15m

Restoring close to the pre-refactor baseline (Linux 3m36s, Mac 6m30s,
Windows 9m6s on PR 526).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ile (#542)

## Summary

PR 10 reserved `--stream-to-file` as a CLI flag but raised `UsageError`
on use because `Predictor.from_model_paths` didn't exist yet. PR 11
shipped that factory; this PR connects the flag to the new flow
end-to-end.

## Behavior

- **`--stream-to-file <path>`** on `sleap-nn infer` now builds a new
`Predictor` via `from_model_paths(...)` and streams predictions to
`<path>` via `IncrementalLabelsWriter`. Memory stays
O(`write-interval`).
- **`--tracking` + `--stream-to-file`** is rejected with a clear
`UsageError`. Tracking on the new flow lands in a follow-up.
- **Without `--stream-to-file`**: unchanged. The legacy `run_inference`
flow still owns tracking + GUI progress + the rest of the surface.

## What ships

- `_run_inference_impl` routes to a new `_run_stream_to_file` helper
when `--stream-to-file` is set.
- `_run_stream_to_file` builds the right provider (`VideoProvider` for
`.mp4` etc., `LabelsProvider` for `.slp`), constructs a `Predictor` via
`from_model_paths`, and calls `predict_to_file` with the requested write
interval.
- Skeleton resolution: `_skeleton_from_predictor` reads
`training_config.yaml` from the model directory.

## Tests

`tests/cli/test_flag_validation.py`:
- Replaced "`--stream-to-file` is hard error until PR 14" with
**"`--stream-to-file` accepts the flag and reaches the new flow"**
(mocks the factory + asserts `predict_to_file` is called).
- New test: `--tracking` + `--stream-to-file` → `UsageError` mentioning
tracking.
- `--write-interval` alone still rejected (pre-existing test).

All 14 CLI tests pass locally.

## Test plan

- [x] `pytest tests/cli/` — 14 passed
- [x] `black --check` + `ruff check` clean
- [ ] Linux / Windows / Mac CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gitttt-1234 gitttt-1234 marked this pull request as ready for review May 28, 2026 18:13
@gitttt-1234 gitttt-1234 merged commit a86b3ff into divya/inf-refactor-10-cli-unification May 28, 2026
@gitttt-1234 gitttt-1234 deleted the divya/inf-refactor-11-from-model-paths branch May 28, 2026 18:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant