@@ -223,39 +223,26 @@ def make_fetch(self, key):
223223 ) # model dir relative to processed data directory
224224
225225 model_key = (Model * moseq_train .SelectedFullFit & key ).fetch1 ("KEY" )
226- checkpoint_file_path = (
226+ fullfit_checkpoint_path = (
227227 moseq_train .FullFit .File & model_key & 'file_name="checkpoint.h5"'
228228 ).fetch1 ("file_path" )
229- kpms_dj_config_file = ( moseq_train . FullFit . ConfigFile & model_key ). fetch1 (
230- "config_file"
231- )
232- pca_file_path = (
229+ fullfit_kpms_dj_config_file = (
230+ moseq_train . FullFit . ConfigFile & model_key
231+ ). fetch1 ( "config_file" )
232+ fullfit_pca_file_path = (
233233 moseq_train .PCAFit .File & model_key & 'file_name="pca.p"'
234234 ).fetch1 ("file_path" )
235235
236- data_file_path = (
237- moseq_train .PCAFit .File & model_key & 'file_name="data.pkl"'
238- ).fetch1 ("file_path" )
239- metadata_file_path = (
240- moseq_train .PCAFit .File & model_key & 'file_name="metadata.pkl"'
241- ).fetch1 ("file_path" )
242- coordinates , confidences = (moseq_train .PreProcessing & model_key ).fetch (
243- "coordinates" , "confidences"
244- )
245236 return (
246237 keypointset_dir ,
247238 inference_output_dir ,
248239 num_iterations ,
249240 task_mode ,
250241 model_dir_rel ,
251242 model_file ,
252- checkpoint_file_path ,
253- kpms_dj_config_file ,
254- pca_file_path ,
255- data_file_path ,
256- metadata_file_path ,
257- coordinates ,
258- confidences ,
243+ fullfit_checkpoint_path ,
244+ fullfit_kpms_dj_config_file ,
245+ fullfit_pca_file_path ,
259246 )
260247
261248 def make_compute (
@@ -267,34 +254,12 @@ def make_compute(
267254 task_mode ,
268255 model_dir_rel ,
269256 model_file ,
270- checkpoint_file_path ,
271- kpms_dj_config_file ,
272- pca_file_path ,
273- data_file_path ,
274- metadata_file_path ,
275- coordinates ,
276- confidences ,
257+ fullfit_checkpoint_path ,
258+ fullfit_kpms_dj_config_file ,
259+ fullfit_pca_file_path ,
277260 ):
278261 """
279262 Compute model inference results.
280-
281- Args:
282- key (dict): `InferenceTask` primary key.
283- keypointset_dir (str): Directory containing keypoint data.
284- inference_output_dir (str): Output directory for inference results.
285- num_iterations (int): Number of iterations for model fitting.
286- model_id (int): Model ID.
287- pose_estimation_method (str): Pose estimation method.
288- task_mode (str): Task mode ('trigger' or 'load').
289-
290- Raises:
291- FileNotFoundError: If no pca model (`pca.p`) found in the parent model directory.
292- FileNotFoundError: If no model (`checkpoint.h5`) found in the model directory.
293- NotImplementedError: If the format method is not `deeplabcut`.
294- FileNotFoundError: If no valid `kpms_dj_config` found in the parent model directory.
295-
296- Returns:
297- tuple: Inference results including duration, results data, and sampled instances.
298263 """
299264 from keypoint_moseq import (
300265 apply_model ,
@@ -303,6 +268,7 @@ def make_compute(
303268 load_keypoints ,
304269 load_pca ,
305270 load_results ,
271+ outlier_removal ,
306272 save_results_as_csv ,
307273 )
308274
@@ -311,51 +277,60 @@ def make_compute(
311277
312278 start_time = datetime .now (timezone .utc )
313279
314- # Get directories first
280+ # Get directories for new recordings
315281 kpms_root = moseq_train .get_kpms_root_data_dir ()
316282 kpms_processed = moseq_train .get_kpms_processed_data_dir ()
317283
318284 # Construct the full path to the inference output directory
319285 inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir
320-
321- if task_mode == "trigger" :
322- if not inference_output_dir .exists ():
323- inference_output_dir .mkdir (parents = True , exist_ok = True )
324-
325286 keypointset_dir = find_full_path (kpms_root , keypointset_dir )
326287
327288 if task_mode == "trigger" :
328- kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
329- config_path = kpms_dj_config_file
289+ # load saved model data
290+ fullfit_kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
291+ config_path = fullfit_kpms_dj_config_file
330292 )
293+ fullfit_model , _ , _ , _ = load_checkpoint (path = fullfit_checkpoint_path )
294+ # fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb"))
331295
332- metadata = pickle .load (open (metadata_file_path , "rb" ))
333- data = pickle .load (open (data_file_path , "rb" ))
334- model_data = pickle .load (open (model_file , "rb" ))
335- if task_mode == "trigger" :
336- results = apply_model (
337- model_name = inference_output_dir .name ,
338- model = model_data ,
339- data = data ,
340- metadata = metadata ,
341- pca = pca_file_path ,
342- project_dir = inference_output_dir .parent ,
343- results_path = (inference_output_dir / "results.h5" ),
344- return_model = False ,
345- num_iters = num_iterations or DEFAULT_NUM_ITERS ,
346- overwrite = True ,
347- save_results = True ,
348- ** kpms_dj_config_dict ,
349- )
296+ # Load new data
297+ coordinates , confidences , bodyparts = load_keypoints (
298+ filepath_pattern = keypointset_dir , format = "deeplabcut"
299+ )
300+ coordinates , confidences = outlier_removal (
301+ coordinates ,
302+ confidences ,
303+ inference_output_dir ,
304+ overwrite = False ,
305+ ** fullfit_kpms_dj_config_dict ,
306+ )
307+ data , metadata = format_data (
308+ coordinates , confidences , ** fullfit_kpms_dj_config_dict
309+ )
350310
351- # Create results directory and save CSV files
352- save_results_as_csv (
353- results = results ,
354- save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
355- )
311+ # apply saved model to new data
312+ results = apply_model (
313+ model = fullfit_model ,
314+ data = data ,
315+ metadata = metadata ,
316+ project_dir = inference_output_dir ,
317+ model_name = inference_output_dir .name ,
318+ results_path = (inference_output_dir / "results.h5" ),
319+ return_model = False ,
320+ num_iters = num_iterations or DEFAULT_NUM_ITERS ,
321+ overwrite = True ,
322+ save_results = True ,
323+ ** fullfit_kpms_dj_config_dict ,
324+ )
325+
326+ # Create results directory and save CSV files
327+ save_results_as_csv (
328+ results = results ,
329+ save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
330+ )
356331
357- end_time = datetime .now (timezone .utc )
358- duration_seconds = (end_time - start_time ).total_seconds ()
332+ end_time = datetime .now (timezone .utc )
333+ duration_seconds = (end_time - start_time ).total_seconds ()
359334
360335 else :
361336 duration_seconds = None
@@ -406,7 +381,7 @@ class VideoSequence(dj.Part):
406381 latent_states : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model
407382 centroids : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model
408383 headings : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model
409- file_csv : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format)
384+ file : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format)
410385 """
411386
412387 class SampledInstance (dj .Part ):
@@ -509,7 +484,7 @@ def make(self, key):
509484 "latent_states" : latent_states [vid ],
510485 "centroids" : filtered_centroids [vid ],
511486 "headings" : filtered_headings [vid ],
512- "file_csv " : (
487+ "file " : (
513488 inference_output_dir / "results_as_csv" / f"{ vid } .csv"
514489 ).as_posix (),
515490 }