@@ -85,7 +85,7 @@ class Model(dj.Manual):
8585 ---
8686 model_name : varchar(1000) # User-friendly model name
8787 model_dir : varchar(1000) # Model directory relative to root data directory
88- model_file : filepath@moseq-infer-processed # Model pkl file containing states, parameters, hyperparameters, noise prior, and random seed
88+ model_file : filepath@moseq-infer-processed # Checkpoint file (h5 format)
8989 model_desc='' : varchar(1000) # Optional. User-defined description of the model
9090 -> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key.
9191 """
@@ -225,7 +225,7 @@ def make_fetch(self, key):
225225 model_key = (Model * moseq_train .SelectedFullFit & key ).fetch1 ("KEY" )
226226 checkpoint_file_path = (
227227 moseq_train .FullFit .File & model_key & 'file_name="checkpoint.h5"'
228- ).fetch1("file")
228+ ).fetch1("file_path")
229229 kpms_dj_config_file = (moseq_train .FullFit .ConfigFile & model_key ).fetch1 (
230230 "config_file"
231231 )
@@ -318,48 +318,46 @@ def make_compute(
318318 # Construct the full path to the inference output directory
319319 inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir
320320
321- inference_output_dir . mkdir ( parents = True , exist_ok = True )
322- model_dir = find_full_path ( kpms_processed , model_dir_rel )
323- keypointset_dir = find_full_path ( kpms_root , keypointset_dir )
321+ if task_mode == "trigger" :
322+ if not inference_output_dir . exists ():
323+ inference_output_dir . mkdir ( parents = True , exist_ok = True )
324324
325- kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
326- config_path = kpms_dj_config_file
327- )
325+ keypointset_dir = find_full_path (kpms_root , keypointset_dir )
328326
329- metadata = pickle .load (open (metadata_file_path , "rb" ))
330- data = pickle .load (open (data_file_path , "rb" ))
331- model_data = pickle .load (open (model_file , "rb" ))
332327 if task_mode == "trigger" :
333- results = apply_model (
334- model_name = inference_output_dir .name ,
335- model = model_data ,
336- data = data ,
337- metadata = metadata ,
338- pca = pca_file_path ,
339- project_dir = inference_output_dir .parent ,
340- results_path = (inference_output_dir / "results.h5" ),
341- return_model = False ,
342- num_iters = num_iterations or DEFAULT_NUM_ITERS ,
343- overwrite = True ,
344- save_results = True ,
345- ** kpms_dj_config_dict ,
328+ kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
329+ config_path = kpms_dj_config_file
346330 )
347331
348- # Create results directory and save CSV files
349- save_results_as_csv (
350- results = results ,
351- save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
352- )
332+ metadata = pickle .load (open (metadata_file_path , "rb" ))
333+ data = pickle .load (open (data_file_path , "rb" ))
334+ model_data = pickle .load (open (model_file , "rb" ))
335+ if task_mode == "trigger" :
336+ results = apply_model (
337+ model_name = inference_output_dir .name ,
338+ model = model_data ,
339+ data = data ,
340+ metadata = metadata ,
341+ pca = pca_file_path ,
342+ project_dir = inference_output_dir .parent ,
343+ results_path = (inference_output_dir / "results.h5" ),
344+ return_model = False ,
345+ num_iters = num_iterations or DEFAULT_NUM_ITERS ,
346+ overwrite = True ,
347+ save_results = True ,
348+ ** kpms_dj_config_dict ,
349+ )
353350
354- end_time = datetime .now (timezone .utc )
355- duration_seconds = (end_time - start_time ).total_seconds ()
351+ # Create results directory and save CSV files
352+ save_results_as_csv (
353+ results = results ,
354+ save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
355+ )
356+
357+ end_time = datetime .now (timezone .utc )
358+ duration_seconds = (end_time - start_time ).total_seconds ()
356359
357360 else :
358- # For load mode
359- results = load_results (
360- project_dir = model_dir ,
361- model_name = model_dir .name ,
362- )
363361 duration_seconds = None
364362
365363 results_filepath = (inference_output_dir / "results.h5" ).as_posix ()