Skip to content

Commit f7c942b

Browse files
committed
Refactor check_install tests and model paths
Rename MODELS_FOLDER to MODELS_DIR and update references (TORCH_CONFIG checkpoint and TF_MODEL_DIR) for clearer naming. Change missing-backend errors from NotImplementedError to ImportError to better reflect installation issues. Simplify main(): consolidate arg parsing, consistently create TMP_DIR and MODELS_DIR, and add backend_results tracking to report per-backend SUCCESS/ERROR statuses with a printed summary. Improve error recording for backend failures and adjust cleanup check when removing the temporary directory.
1 parent 0629b15 commit f7c942b

File tree

1 file changed

+57
-38
lines changed

1 file changed

+57
-38
lines changed

dlclive/check_install/check_install.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@
2222
SNAPSHOT_NAME = "snapshot-700000.pb"
2323
TMP_DIR = Path(__file__).parent / "dlc-live-tmp"
2424

25-
MODELS_FOLDER = TMP_DIR / "test_models"
25+
MODELS_DIR = TMP_DIR / "test_models"
2626
TORCH_MODEL = "resnet_50"
2727
TORCH_CONFIG = {
28-
"checkpoint": MODELS_FOLDER / f"exported_quadruped_{TORCH_MODEL}.pt",
28+
"checkpoint": MODELS_DIR / f"exported_quadruped_{TORCH_MODEL}.pt",
2929
"super_animal": "superanimal_quadruped",
3030
}
31-
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
31+
TF_MODEL_DIR = MODELS_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3232

3333

3434
def run_pytorch_test(video_file: str, display: bool = False):
3535
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
3636

3737
if Engine.PYTORCH not in get_available_backends():
38-
raise NotImplementedError(
38+
raise ImportError(
3939
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
4040
)
4141
# Download model from the DeepLabCut Model Zoo
@@ -63,7 +63,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
6363

6464
def run_tensorflow_test(video_file: str, display: bool = False):
6565
if Engine.TENSORFLOW not in get_available_backends():
66-
raise NotImplementedError(
66+
raise ImportError(
6767
"TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
6868
)
6969
model_dir = TF_MODEL_DIR
@@ -92,38 +92,39 @@ def run_tensorflow_test(video_file: str, display: bool = False):
9292

9393

9494
def main():
95-
tmp_dir = None
96-
try:
97-
parser = argparse.ArgumentParser(
98-
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
99-
)
100-
parser.add_argument(
101-
"--display",
102-
action="store_true",
103-
default=False,
104-
help="Run the test and display tracking",
105-
)
106-
parser.add_argument(
107-
"--nodisplay",
108-
action="store_false",
109-
dest="display",
110-
help=argparse.SUPPRESS,
111-
)
95+
backend_results = {}
96+
97+
parser = argparse.ArgumentParser(
98+
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
99+
)
100+
parser.add_argument(
101+
"--display",
102+
action="store_true",
103+
default=False,
104+
help="Run the test and display tracking",
105+
)
106+
parser.add_argument(
107+
"--nodisplay",
108+
action="store_false",
109+
dest="display",
110+
help=argparse.SUPPRESS,
111+
)
112112

113-
args = parser.parse_args()
114-
display = args.display
113+
args = parser.parse_args()
114+
display = args.display
115115

116-
if not display:
117-
print("Running without displaying video")
116+
if not display:
117+
print("Running without displaying video")
118118

119-
# make temporary directory
120-
print("\nCreating temporary directory...\n")
121-
tmp_dir = TMP_DIR
122-
tmp_dir.mkdir(mode=0o775, exist_ok=True)
123-
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
119+
# make temporary directory
120+
print("\nCreating temporary directory...\n")
121+
tmp_dir = TMP_DIR
122+
tmp_dir.mkdir(mode=0o775, exist_ok=True)
123+
MODELS_DIR.mkdir(parents=True, exist_ok=True)
124124

125-
video_file = str(tmp_dir / "dog_clip.avi")
125+
video_file = str(tmp_dir / "dog_clip.avi")
126126

127+
try:
127128
# download dog test video from github:
128129
# Use raw.githubusercontent.com for direct file access
129130
if not Path(video_file).exists():
@@ -147,33 +148,51 @@ def main():
147148
print("\nRunning PyTorch test...\n")
148149
run_pytorch_test(video_file, display=display)
149150
any_backend_succeeded = True
151+
backend_results["pytorch"] = ("SUCCESS", None)
150152
elif backend == Engine.TENSORFLOW:
151153
print("\nRunning TensorFlow test...\n")
152154
run_tensorflow_test(video_file, display=display)
153155
any_backend_succeeded = True
156+
backend_results["tensorflow"] = ("SUCCESS", None)
154157
else:
155158
warnings.warn(
156159
f"Unrecognized backend {backend}, skipping...", UserWarning
157160
)
158161
except Exception as e:
162+
backend_name = (
163+
"pytorch" if backend == Engine.PYTORCH else
164+
"tensorflow" if backend == Engine.TENSORFLOW else
165+
str(backend)
166+
)
167+
backend_results[backend_name] = ("ERROR", str(e))
159168
backend_failures[backend] = e
160169
warnings.warn(
161170
f"Error while running test for backend {backend}: {e}. "
162171
"Continuing to test other available backends.",
163172
UserWarning,
164173
)
165174

166-
if not any_backend_succeeded and backend_failures:
167-
failure_messages = "; ".join(
168-
f"{b}: {exc}" for b, exc in backend_failures.items()
169-
)
170-
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
175+
print("\n---\nBackend test summary:")
176+
for name in ("tensorflow", "pytorch"):
177+
status, _ = backend_results.get(name, ("SKIPPED", None))
178+
print(f"{name:<11} [{status}]")
179+
print("---")
180+
for name, (status, error) in backend_results.items():
181+
if status == "ERROR":
182+
print(f"{name.capitalize()} error:\n{error}\n")
183+
184+
if not any_backend_succeeded and backend_failures:
185+
failure_messages = "; ".join(
186+
f"{b}: {exc}" for b, exc in backend_failures.items()
187+
)
188+
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
189+
171190

172191
finally:
173192
# deleting temporary files
174193
print("\n Deleting temporary files...\n")
175194
try:
176-
if tmp_dir is not None and tmp_dir.exists():
195+
if tmp_dir.exists():
177196
shutil.rmtree(tmp_dir)
178197
except PermissionError:
179198
warnings.warn(

0 commit comments

Comments
 (0)