Skip to content

Commit 2b1aa0c

Browse files
committed
Add optional smoke test for exported dlclive model via env vars
1 parent ba20dac commit 2b1aa0c

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/compat/test_dlclive_package_compat.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import importlib.metadata
44
import inspect
5+
import os
6+
from pathlib import Path
57

8+
import numpy as np
69
import pytest
710

811

@@ -75,3 +78,44 @@ def test_dlclive_methods_match_gui_usage():
7578
get_pose_params, _ = _get_signature_params(DLCLive.get_pose)
7679
get_pose_missing = {name for name in {"frame", "frame_time"} if name not in get_pose_params}
7780
assert not get_pose_missing, f"DLCLive.get_pose signature mismatch, missing: {sorted(get_pose_missing)}"
81+
82+
83+
@pytest.mark.dlclive_compat
84+
def test_dlclive_minimal_inference_smoke():
85+
"""
86+
Real runtime smoke test (init + pose call) using a tiny exported model.
87+
88+
Opt-in via env vars:
89+
- DLCLIVE_TEST_MODEL_PATH: absolute/relative path to exported model folder/file
90+
- DLCLIVE_TEST_MODEL_TYPE: optional model type (default: pytorch)
91+
"""
92+
model_path_env = os.getenv("DLCLIVE_TEST_MODEL_PATH", "").strip()
93+
if not model_path_env:
94+
pytest.skip("Set DLCLIVE_TEST_MODEL_PATH to run real DLCLive inference smoke test.")
95+
96+
model_path = Path(model_path_env).expanduser()
97+
if not model_path.exists():
98+
pytest.skip(f"DLCLIVE_TEST_MODEL_PATH does not exist: {model_path}")
99+
100+
model_type = os.getenv("DLCLIVE_TEST_MODEL_TYPE", "pytorch").strip() or "pytorch"
101+
102+
from dlclive import DLCLive # noqa: PLC0415
103+
from dlclivegui.services.dlc_processor import validate_pose_array # noqa: PLC0415
104+
105+
dlc = DLCLive(
106+
model_path=str(model_path),
107+
model_type=model_type,
108+
dynamic=[False, 0.5, 10],
109+
resize=1.0,
110+
precision="FP32",
111+
single_animal=True,
112+
)
113+
114+
frame = np.zeros((64, 64, 3), dtype=np.uint8)
115+
dlc.init_inference(frame)
116+
pose = dlc.get_pose(frame, frame_time=0.0)
117+
pose_arr = validate_pose_array(pose, source_backend="DLCLive.get_pose")
118+
119+
assert pose_arr.ndim in (2, 3)
120+
assert pose_arr.shape[-1] == 3
121+
assert np.isfinite(pose_arr).all()

0 commit comments

Comments
 (0)