2222SNAPSHOT_NAME = "snapshot-700000.pb"
2323TMP_DIR = Path (__file__ ).parent / "dlc-live-tmp"
2424
25- MODELS_FOLDER = TMP_DIR / "test_models"
25+ MODELS_DIR = TMP_DIR / "test_models"
2626TORCH_MODEL = "resnet_50"
2727TORCH_CONFIG = {
28- "checkpoint" : MODELS_FOLDER / f"exported_quadruped_{ TORCH_MODEL } .pt" ,
28+ "checkpoint" : MODELS_DIR / f"exported_quadruped_{ TORCH_MODEL } .pt" ,
2929 "super_animal" : "superanimal_quadruped" ,
3030}
31- TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
31+ TF_MODEL_DIR = MODELS_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3232
3333
3434def run_pytorch_test (video_file : str , display : bool = False ):
3535 from dlclive .modelzoo .pytorch_model_zoo_export import export_modelzoo_model
3636
3737 if Engine .PYTORCH not in get_available_backends ():
38- raise NotImplementedError (
38+ raise ImportError (
3939 "PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
4040 )
4141 # Download model from the DeepLabCut Model Zoo
@@ -63,7 +63,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
6363
6464def run_tensorflow_test (video_file : str , display : bool = False ):
6565 if Engine .TENSORFLOW not in get_available_backends ():
66- raise NotImplementedError (
66+ raise ImportError (
6767 "TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
6868 )
6969 model_dir = TF_MODEL_DIR
@@ -92,38 +92,39 @@ def run_tensorflow_test(video_file: str, display: bool = False):
9292
9393
9494def main ():
95- tmp_dir = None
96- try :
97- parser = argparse .ArgumentParser (
98- description = "Test DLC-Live installation by downloading and evaluating a demo DLC project!"
99- )
100- parser .add_argument (
101- "--display" ,
102- action = "store_true" ,
103- default = False ,
104- help = "Run the test and display tracking" ,
105- )
106- parser .add_argument (
107- "--nodisplay" ,
108- action = "store_false" ,
109- dest = "display" ,
110- help = argparse .SUPPRESS ,
111- )
95+ backend_results = {}
96+
97+ parser = argparse .ArgumentParser (
98+ description = "Test DLC-Live installation by downloading and evaluating a demo DLC project!"
99+ )
100+ parser .add_argument (
101+ "--display" ,
102+ action = "store_true" ,
103+ default = False ,
104+ help = "Run the test and display tracking" ,
105+ )
106+ parser .add_argument (
107+ "--nodisplay" ,
108+ action = "store_false" ,
109+ dest = "display" ,
110+ help = argparse .SUPPRESS ,
111+ )
112112
113- args = parser .parse_args ()
114- display = args .display
113+ args = parser .parse_args ()
114+ display = args .display
115115
116- if not display :
117- print ("Running without displaying video" )
116+ if not display :
117+ print ("Running without displaying video" )
118118
119- # make temporary directory
120- print ("\n Creating temporary directory...\n " )
121- tmp_dir = TMP_DIR
122- tmp_dir .mkdir (mode = 0o775 , exist_ok = True )
123- MODELS_FOLDER .mkdir (parents = True , exist_ok = True )
119+ # make temporary directory
120+ print ("\n Creating temporary directory...\n " )
121+ tmp_dir = TMP_DIR
122+ tmp_dir .mkdir (mode = 0o775 , exist_ok = True )
123+ MODELS_DIR .mkdir (parents = True , exist_ok = True )
124124
125- video_file = str (tmp_dir / "dog_clip.avi" )
125+ video_file = str (tmp_dir / "dog_clip.avi" )
126126
127+ try :
127128 # download dog test video from github:
128129 # Use raw.githubusercontent.com for direct file access
129130 if not Path (video_file ).exists ():
@@ -147,33 +148,51 @@ def main():
147148 print ("\n Running PyTorch test...\n " )
148149 run_pytorch_test (video_file , display = display )
149150 any_backend_succeeded = True
151+ backend_results ["pytorch" ] = ("SUCCESS" , None )
150152 elif backend == Engine .TENSORFLOW :
151153 print ("\n Running TensorFlow test...\n " )
152154 run_tensorflow_test (video_file , display = display )
153155 any_backend_succeeded = True
156+ backend_results ["tensorflow" ] = ("SUCCESS" , None )
154157 else :
155158 warnings .warn (
156159 f"Unrecognized backend { backend } , skipping..." , UserWarning
157160 )
158161 except Exception as e :
162+ backend_name = (
163+ "pytorch" if backend == Engine .PYTORCH else
164+ "tensorflow" if backend == Engine .TENSORFLOW else
165+ str (backend )
166+ )
167+ backend_results [backend_name ] = ("ERROR" , str (e ))
159168 backend_failures [backend ] = e
160169 warnings .warn (
161170 f"Error while running test for backend { backend } : { e } . "
162171 "Continuing to test other available backends." ,
163172 UserWarning ,
164173 )
165174
166- if not any_backend_succeeded and backend_failures :
167- failure_messages = "; " .join (
168- f"{ b } : { exc } " for b , exc in backend_failures .items ()
169- )
170- raise RuntimeError (f"All backend tests failed. Details: { failure_messages } " )
175+ print ("\n ---\n Backend test summary:" )
176+ for name in ("tensorflow" , "pytorch" ):
177+ status , _ = backend_results .get (name , ("SKIPPED" , None ))
178+ print (f"{ name :<11} [{ status } ]" )
179+ print ("---" )
180+ for name , (status , error ) in backend_results .items ():
181+ if status == "ERROR" :
182+ print (f"{ name .capitalize ()} error:\n { error } \n " )
183+
184+ if not any_backend_succeeded and backend_failures :
185+ failure_messages = "; " .join (
186+ f"{ b } : { exc } " for b , exc in backend_failures .items ()
187+ )
188+ raise RuntimeError (f"All backend tests failed. Details: { failure_messages } " )
189+
171190
172191 finally :
173192 # deleting temporary files
174193 print ("\n Deleting temporary files...\n " )
175194 try :
176- if tmp_dir is not None and tmp_dir .exists ():
195+ if tmp_dir .exists ():
177196 shutil .rmtree (tmp_dir )
178197 except PermissionError :
179198 warnings .warn (
0 commit comments