@@ -58,7 +58,7 @@ def activate(
5858 )
5959
6060
61- # -------------- Functions required by the element-moseq ---------------
61+ # -------------- Functions required by element-moseq ---------------
6262
6363
6464def get_kpms_root_data_dir () -> list :
@@ -87,7 +87,7 @@ def get_kpms_processed_data_dir() -> Optional[str]:
8787
8888 Method in parent namespace should provide a string to a directory where KPMS output
8989 files will be stored. If unspecified, output files will be stored in the
90- session directory 'videos' folder, per DeepLabCut default.
90+ session directory 'videos' folder, per Keypoint-MoSeq default.
9191 """
9292 if hasattr (_linking_module , "get_kpms_processed_data_dir" ):
9393 return _linking_module .get_kpms_processed_data_dir ()
@@ -197,14 +197,15 @@ class InferenceTask(dj.Manual):
197197 """
198198
199199 definition = """
200- -> VideoRecording # `VideoRecording` key
201- -> Model # `Model` key
200+ -> VideoRecording # `VideoRecording` key
201+ -> Model # `Model` key
202202 ---
203- -> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
204- keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
205- inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
206- inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
207- num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
203+ -> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
204+ keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
205+ inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
206+ inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
207+ num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
208+ task_mode='load' : enum('load', 'trigger') # Task mode for the inference task
208209 """
209210
210211
@@ -305,12 +306,14 @@ def make(self, key):
305306 num_iterations ,
306307 model_id ,
307308 pose_estimation_method ,
309+ task_mode ,
308310 ) = (InferenceTask & key ).fetch1 (
309311 "keypointset_dir" ,
310312 "inference_output_dir" ,
311313 "num_iterations" ,
312314 "model_id" ,
313315 "pose_estimation_method" ,
316+ "task_mode" ,
314317 )
315318
316319 kpms_root = get_kpms_root_data_dir ()
@@ -322,7 +325,7 @@ def make(self, key):
322325 )
323326 keypointset_dir = find_full_path (kpms_root , keypointset_dir )
324327
325- inference_output_dir = model_dir / inference_output_dir
328+ inference_output_dir = os . path . join ( model_dir , inference_output_dir )
326329
327330 if not os .path .exists (inference_output_dir ):
328331 os .makedirs (model_dir / inference_output_dir )
@@ -366,55 +369,98 @@ def make(self, key):
366369 f"No valid `kpms_dj_config` found in the parent model directory { model_dir .parent } "
367370 )
368371
369- start_time = datetime .utcnow ()
370- results = apply_model (
371- model = model ,
372- data = data ,
373- metadata = metadata ,
374- pca = pca ,
375- project_dir = model_dir .parent .as_posix (),
376- model_name = Path (model_dir ).name ,
377- results_path = (inference_output_dir / "results.h5" ).as_posix (),
378- return_model = False ,
379- num_iters = num_iterations
380- or 50.0 , # default internal value in the keypoint-moseq function
381- ** kpms_dj_config ,
382- )
383- end_time = datetime .utcnow ()
372+ if task_mode == "trigger" :
373+ start_time = datetime .utcnow ()
374+ results = apply_model (
375+ model = model ,
376+ data = data ,
377+ metadata = metadata ,
378+ pca = pca ,
379+ project_dir = model_dir .parent .as_posix (),
380+ model_name = Path (model_dir ).name ,
381+ results_path = (inference_output_dir / "results.h5" ).as_posix (),
382+ return_model = False ,
383+ num_iters = num_iterations
384+ or 50 , # default internal value in the keypoint-moseq function
385+ ** kpms_dj_config ,
386+ )
387+ end_time = datetime .utcnow ()
384388
385- duration_seconds = (end_time - start_time ).total_seconds ()
389+ duration_seconds = (end_time - start_time ).total_seconds ()
386390
387- save_results_as_csv (
388- results = results ,
389- save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
390- )
391+ save_results_as_csv (
392+ results = results ,
393+ save_dir = (inference_output_dir / "results_as_csv" ).as_posix (),
394+ )
391395
392- fig , _ = plot_syllable_frequencies (
393- results = results , path = inference_output_dir .as_posix ()
394- )
395- fig .savefig (inference_output_dir / "syllable_frequencies.png" )
396- plt .close (fig )
397-
398- generate_trajectory_plots (
399- coordinates = coordinates ,
400- results = results ,
401- output_dir = (inference_output_dir / "trajectory_plots" ).as_posix (),
402- ** kpms_dj_config ,
403- )
396+ fig , _ = plot_syllable_frequencies (
397+ results = results , path = inference_output_dir .as_posix ()
398+ )
399+ fig .savefig (inference_output_dir / "syllable_frequencies.png" )
400+ plt .close (fig )
401+
402+ generate_trajectory_plots (
403+ coordinates = coordinates ,
404+ results = results ,
405+ output_dir = (inference_output_dir / "trajectory_plots" ).as_posix (),
406+ ** kpms_dj_config ,
407+ )
404408
405- sampled_instances = generate_grid_movies (
406- coordinates = coordinates ,
407- results = results ,
408- output_dir = (inference_output_dir / "grid_movies" ).as_posix (),
409- ** kpms_dj_config ,
410- )
409+ sampled_instances = generate_grid_movies (
410+ coordinates = coordinates ,
411+ results = results ,
412+ output_dir = (inference_output_dir / "grid_movies" ).as_posix (),
413+ ** kpms_dj_config ,
414+ )
411415
412- plot_similarity_dendrogram (
413- coordinates = coordinates ,
414- results = results ,
415- save_path = (inference_output_dir / "similarity_dendogram" ).as_posix (),
416- ** kpms_dj_config ,
417- )
416+ plot_similarity_dendrogram (
417+ coordinates = coordinates ,
418+ results = results ,
419+ save_path = (inference_output_dir / "similarity_dendogram" ).as_posix (),
420+ ** kpms_dj_config ,
421+ )
422+
423+ else :
424+ from keypoint_moseq import (
425+ load_results ,
426+ filter_centroids_headings ,
427+ get_syllable_instances ,
428+ sample_instances ,
429+ )
430+
431+ # load results
432+ results = load_results (
433+ project_dir = Path (inference_output_dir ).parent ,
434+ model_name = Path (inference_output_dir ).parts [- 1 ],
435+ )
436+
437+ # extract sampled_instances
438+ ## extract syllables from results
439+ syllables = {k : v ["syllable" ] for k , v in results .items ()}
440+
441+ ## extract and smooth centroids and headings
442+ centroids = {k : v ["centroid" ] for k , v in results .items ()}
443+ headings = {k : v ["heading" ] for k , v in results .items ()}
444+
445+ filter_size = 9 # default value
446+ centroids , headings = filter_centroids_headings (
447+ centroids , headings , filter_size = filter_size
448+ )
449+
450+ # sample instances for each syllable
451+ syllable_instances = get_syllable_instances (
452+ syllables , min_duration = 3 , min_frequency = 0.005
453+ )
454+
455+ sampled_instances = sample_instances (
456+ syllable_instances = syllable_instances ,
457+ num_samples = 4 * 6 , # minimum rows * cols
458+ coordinates = coordinates ,
459+ centroids = centroids ,
460+ headings = headings ,
461+ )
462+
463+ duration_seconds = None
418464
419465 self .insert1 ({** key , "inference_duration" : duration_seconds })
420466
0 commit comments