@@ -119,9 +119,9 @@ class File(dj.Part):
119119
120120 definition = """
121121 -> master
122- file_id: int # Unique ID for each file
122+ file_id: int # Unique ID for each file
123123 ---
124- file_path: varchar(1000) # Filepath of each video, relative to root data directory.
124+ file_path: filepath@moseq-infer-processed # Filepath of each video, relative to root data directory.
125125 """
126126
127127
@@ -196,8 +196,11 @@ class Inference(dj.Computed):
196196 definition = """
197197 -> InferenceTask # `InferenceTask` key
198198 ---
199- syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings
200- inference_duration=NULL : float # Time duration (seconds) of the inference computation
199+ coordinates : longblob # Cleaned coordinates dictionary after outlier removal.
200+ confidences : longblob # Cleaned confidences dictionary after outlier removal.
201+ average_frame_rate : int # Average frame rate of the videos for model training (used for kappa calculation).
202+ syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings
203+ inference_duration=NULL : float # Time duration (seconds) of the inference computation
201204 """
202205
203206 def make_fetch (self , key ):
@@ -261,6 +264,10 @@ def make_compute(
261264 """
262265 Compute model inference results.
263266 """
267+ import glob
268+ import os
269+
270+ import cv2
264271 from keypoint_moseq import (
265272 apply_model ,
266273 format_data ,
@@ -290,8 +297,30 @@ def make_compute(
290297 fullfit_kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
291298 config_path = fullfit_kpms_dj_config_file
292299 )
300+
301+ # calculate average frame rate of all video in keypointset_dir
302+ # Search for multiple common video extensions: mp4, avi, mov, mkv, etc.
303+ video_extensions = [
304+ "*.mp4" ,
305+ "*.avi" ,
306+ "*.mov" ,
307+ "*.mkv" ,
308+ "*.wmv" ,
309+ "*.mpeg" ,
310+ "*.mpg" ,
311+ ]
312+ video_files = []
313+ for ext in video_extensions :
314+ video_files .extend (glob .glob (os .path .join (keypointset_dir , ext )))
315+ frame_rates = []
316+ for video_file in video_files :
317+ cap = cv2 .VideoCapture (video_file )
318+ frame_rate = cap .get (cv2 .CAP_PROP_FPS )
319+ frame_rates .append (frame_rate )
320+ average_frame_rate = np .mean (frame_rates )
321+
322+ # Load fullfit model
293323 fullfit_model , _ , _ , _ = load_checkpoint (path = fullfit_checkpoint_path )
294- # fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb"))
295324
296325 # Load new data
297326 coordinates , confidences , bodyparts = load_keypoints (
@@ -340,22 +369,31 @@ def make_compute(
340369 return (
341370 duration_seconds ,
342371 results_filepath ,
372+ average_frame_rate ,
373+ coordinates ,
374+ confidences ,
343375 )
344376
345377 def make_insert (
346378 self ,
347379 key ,
348380 duration_seconds ,
349381 results_filepath ,
382+ average_frame_rate ,
383+ coordinates ,
384+ confidences ,
350385 ):
351386 """
352387 Insert inference results into the database.
353388 """
354389 self .insert1 (
355390 {
356391 ** key ,
357- "inference_duration" : duration_seconds ,
358392 "syllable_segmentation_file" : results_filepath ,
393+ "coordinates" : coordinates ,
394+ "confidences" : confidences ,
395+ "average_frame_rate" : average_frame_rate ,
396+ "inference_duration" : duration_seconds ,
359397 }
360398 )
361399
0 commit comments