|
2 | 2 |
|
3 | 3 | import importlib.metadata |
4 | 4 | import inspect |
| 5 | +import os |
| 6 | +from pathlib import Path |
5 | 7 |
|
| 8 | +import numpy as np |
6 | 9 | import pytest |
7 | 10 |
|
8 | 11 |
|
@@ -75,3 +78,44 @@ def test_dlclive_methods_match_gui_usage(): |
75 | 78 | get_pose_params, _ = _get_signature_params(DLCLive.get_pose) |
76 | 79 | get_pose_missing = {name for name in {"frame", "frame_time"} if name not in get_pose_params} |
77 | 80 | assert not get_pose_missing, f"DLCLive.get_pose signature mismatch, missing: {sorted(get_pose_missing)}" |
| 81 | + |
| 82 | + |
| 83 | +@pytest.mark.dlclive_compat |
| 84 | +def test_dlclive_minimal_inference_smoke(): |
| 85 | + """ |
| 86 | + Real runtime smoke test (init + pose call) using a tiny exported model. |
| 87 | +
|
| 88 | + Opt-in via env vars: |
| 89 | + - DLCLIVE_TEST_MODEL_PATH: absolute/relative path to exported model folder/file |
| 90 | + - DLCLIVE_TEST_MODEL_TYPE: optional model type (default: pytorch) |
| 91 | + """ |
| 92 | + model_path_env = os.getenv("DLCLIVE_TEST_MODEL_PATH", "").strip() |
| 93 | + if not model_path_env: |
| 94 | + pytest.skip("Set DLCLIVE_TEST_MODEL_PATH to run real DLCLive inference smoke test.") |
| 95 | + |
| 96 | + model_path = Path(model_path_env).expanduser() |
| 97 | + if not model_path.exists(): |
| 98 | + pytest.skip(f"DLCLIVE_TEST_MODEL_PATH does not exist: {model_path}") |
| 99 | + |
| 100 | + model_type = os.getenv("DLCLIVE_TEST_MODEL_TYPE", "pytorch").strip() or "pytorch" |
| 101 | + |
| 102 | + from dlclive import DLCLive # noqa: PLC0415 |
| 103 | + from dlclivegui.services.dlc_processor import validate_pose_array # noqa: PLC0415 |
| 104 | + |
| 105 | + dlc = DLCLive( |
| 106 | + model_path=str(model_path), |
| 107 | + model_type=model_type, |
| 108 | + dynamic=[False, 0.5, 10], |
| 109 | + resize=1.0, |
| 110 | + precision="FP32", |
| 111 | + single_animal=True, |
| 112 | + ) |
| 113 | + |
| 114 | + frame = np.zeros((64, 64, 3), dtype=np.uint8) |
| 115 | + dlc.init_inference(frame) |
| 116 | + pose = dlc.get_pose(frame, frame_time=0.0) |
| 117 | + pose_arr = validate_pose_array(pose, source_backend="DLCLive.get_pose") |
| 118 | + |
| 119 | + assert pose_arr.ndim in (2, 3) |
| 120 | + assert pose_arr.shape[-1] == 3 |
| 121 | + assert np.isfinite(pose_arr).all() |
0 commit comments