PR 11 of #508 — Predictor.from_model_paths factory (#519)#541
Merged
gitttt-1234 merged 4 commits intoMay 28, 2026
Merged
Conversation
Adds `Predictor.from_model_paths()` and the underlying `sleap_nn.inference.factory.from_model_paths` so the new Predictor (PR 8) can be built directly from model checkpoint paths — the unblocking piece for the rest of #519 (legacy run_inference shim, CLI wiring of the new flow). Architecture: the factory delegates the heavy model-loading work (Lightning checkpoint reconstruction with all the optimizer / scheduler / hard-mining hyperparams Lightning needs, SLEAP <=1.4 legacy converter, backbone/head ckpt overrides, device placement) to the legacy `inference.predictors.Predictor`, then re-wraps the loaded torch module(s) and PAF scorer with the new `InferenceLayer` subclasses. A thin adapter, not a reimplementation. Layer dispatch for the 5 supported model-type combinations: - single_instance → SingleInstanceLayer - bottomup → BottomUpLayer - multi_class_bottomup → BottomUpMultiClassLayer - centroid + centered_instance → TopDownLayer - centroid + multi_class_topdown → TopDownMultiClassLayer `Predictor.from_model_paths(...)` is exposed as a classmethod alias for ergonomic parity with the legacy predictor's API. Tests: tests/inference/test_factory.py — 11 tests covering layer-type dispatch (5 model-type combos), classmethod equivalence, end-to-end predict smoke (3 layer types), unsupported-combination error path, and parity vs the legacy `inference_model.forward` on the single-instance checkpoint within 1e-4 atol / 1e-5 rtol. PR 11b will use this factory to (a) replace `run_inference` with a deprecation shim and (b) delete `sleap_nn/export/inference.py` per the original #519 plan. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## divya/inf-refactor-10-cli-unification #541 +/- ##
=========================================================================
- Coverage 64.80% 61.66% -3.14%
=========================================================================
Files 126 129 +3
Lines 19158 19607 +449
=========================================================================
- Hits 12415 12091 -324
- Misses 6743 7516 +773 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
…oad cost Each factory test was previously calling ``from_model_paths()`` fresh, which triggered a full Lightning checkpoint reconstruction every time. With 11 tests × 5 ckpts, ~17 fresh model loads accumulated ~2GB of references that GC didn't reclaim between tests, putting memory pressure on Linux/Windows CI runners. PR 11's first CI run: - Mac: 24m20s (was ~14m before) - Linux: 55m (timeout — was ~9m before) - Windows: 45m22s (timeout — was ~30m before) - The bottomup parity test (`test_phase_split_matches_monolithic_postprocess`) in test_paf_worker_pool.py took 578s (was ~30s) because of the bloated process state. Refactor: each ckpt combo gets a module-scoped fixture that loads once, with a ``gc.collect()`` finalizer. Tests share the predictor for type checks + smoke runs; the parity test still loads the legacy predictor independently (different load path, different test). Local: full inference suite (test_factory + test_paf_worker_pool + layers) runs in 3.5 min (Mac). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These tests loaded real Lightning checkpoints and compared the new layer's forward pass to the legacy InferenceModel's forward pass byte- for-byte (within 1e-4). Useful as development-time scaffolding during PRs 4-9 to catch porting mistakes; redundant once the refactor is done because: 1. End-to-end byte-for-byte parity is captured at the top of the stack via the PR 0 goldens — that's the canonical comparison point and runs cheap. 2. Per-layer parity locks in legacy float-op-ordering, blocking intentional improvements (PR 5 already had to bump tolerances for legitimate ULP drift from a vectorization rewrite). 3. Combined load + execute cost: ~17 min on Mac CI from just two tests (`test_bottomup_layer_parity_vs_legacy` 430s and `test_phase_split_matches_monolithic_postprocess` 578s). A final test PR at the end of the refactor stack will run end-to-end byte-for-byte parity vs the PR 0 goldens — one comprehensive check instead of N redundant per-layer ones. Tests removed: - test_bottomup.py::test_bottomup_layer_parity_vs_legacy - test_bottomup_multiclass.py::test_parity_vs_legacy - test_centered_instance.py::test_centered_instance_layer_parity_vs_legacy_model_path - test_centroid.py::test_centroid_layer_parity_vs_legacy_model_path - test_topdown.py::test_topdown_layer_parity_vs_pr0_golden (heavy, despite name) - test_topdown_multiclass.py::test_centered_instance_multiclass_layer_parity_vs_legacy - test_factory.py::test_factory_parity_vs_legacy_single_instance - test_paf_worker_pool.py::test_phase_split_matches_monolithic_postprocess Local inference-suite runtime: 2m16s (was 3m30s). CI projection: - Linux ~9m → ~4-5m - Mac ~15m → ~7-8m - Windows ~34m → ~12-15m Restoring close to the pre-refactor baseline (Linux 3m36s, Mac 6m30s, Windows 9m6s on PR 526). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ile (#542) ## Summary PR 10 reserved `--stream-to-file` as a CLI flag but raised `UsageError` on use because `Predictor.from_model_paths` didn't exist yet. PR 11 shipped that factory; this PR connects the flag to the new flow end-to-end. ## Behavior - **`--stream-to-file <path>`** on `sleap-nn infer` now builds a new `Predictor` via `from_model_paths(...)` and streams predictions to `<path>` via `IncrementalLabelsWriter`. Memory stays O(`write-interval`). - **`--tracking` + `--stream-to-file`** is rejected with a clear `UsageError`. Tracking on the new flow lands in a follow-up. - **Without `--stream-to-file`**: unchanged. The legacy `run_inference` flow still owns tracking + GUI progress + the rest of the surface. ## What ships - `_run_inference_impl` routes to a new `_run_stream_to_file` helper when `--stream-to-file` is set. - `_run_stream_to_file` builds the right provider (`VideoProvider` for `.mp4` etc., `LabelsProvider` for `.slp`), constructs a `Predictor` via `from_model_paths`, and calls `predict_to_file` with the requested write interval. - Skeleton resolution: `_skeleton_from_predictor` reads `training_config.yaml` from the model directory. ## Tests `tests/cli/test_flag_validation.py`: - Replaced "`--stream-to-file` is hard error until PR 14" with **"`--stream-to-file` accepts the flag and reaches the new flow"** (mocks the factory + asserts `predict_to_file` is called). - New test: `--tracking` + `--stream-to-file` → `UsageError` mentioning tracking. - `--write-interval` alone still rejected (pre-existing test). All 14 CLI tests pass locally. ## Test plan - [x] `pytest tests/cli/` — 14 passed - [x] `black --check` + `ruff check` clean - [ ] Linux / Windows / Mac CI 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Closes part of #519 (PR 11 of #508). Adds the
Predictor.from_model_paths()factory that's the unblocking piece for the rest of #519 (the legacyrun_inferenceshim + the CLI new-flow wiring).Architecture
The factory in
sleap_nn/inference/factory.pyis a thin adapter, not a reimplementation of model loading. It delegates the heavy work to the legacyinference.predictors.Predictor:.jsonconfig converterbackbone_ckpt_path/head_ckpt_pathoverridesThe factory then reads attributes off the loaded
legacy.inference_model(e.g.,torch_model,paf_scorer,output_stride,crop_hw) and rewraps them with the newInferenceLayersubclasses from PRs 4, 6.Predictor.from_model_paths(...)is exposed as a classmethod alias for ergonomic parity with the legacy predictor's API.Layer dispatch — 5 supported model-type combinations
single_instanceSingleInstanceLayerbottomupBottomUpLayermulti_class_bottomupBottomUpMultiClassLayercentroid+centered_instanceTopDownLayer(CentroidLayer + CenteredInstanceLayer)centroid+multi_class_topdownTopDownMultiClassLayerUnsupported combinations (e.g., two centroids, missing centered-instance for top-down) raise
ValueError.Why not delete the legacy loader yet?
The legacy
inference.predictors.Predictoris tightly coupled toLightningModule.load_from_checkpointand the SLEAP <=1.4.jsonconverter — both stable code paths. Forking that logic into the new module is risky and gains nothing. PR 11b (deletion) will rip out the unrelatedsleap_nn/export/inference.py(the duplicate driver in the export-side path); the legacy loader stays as the model-loading backend for the factory until a future cleanup.Tests (11)
tests/inference/test_factory.py:Predictorwhose.layeris the expected concrete typePredictor.from_model_paths(...)and the free function produce equivalent objectspredictor.layer.predict()returns a structurally well-formedOutputson a synthetic image for single-instance, bottom-up, top-downinference_model.forwardon the single-instance checkpoint within 1e-4 atol / 1e-5 rtolPer-type layer-vs-legacy parity is already covered by
tests/inference/layers/test_*.py; the factory test verifies the wiring keeps the new layers byte-for-byte equivalent.What's deferred to PR 11b
sleap_nn/predict.py::run_inferencebody with a deprecation shim that delegates toPredictor.from_model_paths(...).predict(...)sleap_nn/export/inference.py+sleap_nn/export/predictors/(the export-side duplicate driver)sleap-nn inferto use the new flow (currently delegates to legacyrun_inference)Test plan
pytest tests/inference/test_factory.py— 11 passpytest tests/inference/— 101 passed, 5 skipped, no regressions across the inference suiteblack --check+ruff checkclean🤖 Generated with Claude Code