Skip to content

Commit 7c627e2

Browse files
committed
update Inference table definition
1 parent 2208df5 commit 7c627e2

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

element_moseq/moseq_infer.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ class File(dj.Part):
119119

120120
definition = """
121121
-> master
122-
file_id: int # Unique ID for each file
122+
file_id: int # Unique ID for each file
123123
---
124-
file_path: varchar(1000) # Filepath of each video, relative to root data directory.
124+
file_path: filepath@moseq-infer-processed # Filepath of each video, relative to root data directory.
125125
"""
126126

127127

@@ -196,8 +196,11 @@ class Inference(dj.Computed):
196196
definition = """
197197
-> InferenceTask # `InferenceTask` key
198198
---
199-
syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings
200-
inference_duration=NULL : float # Time duration (seconds) of the inference computation
199+
coordinates : longblob # Cleaned coordinates dictionary after outlier removal.
200+
confidences : longblob # Cleaned confidences dictionary after outlier removal.
201+
average_frame_rate : int # Average frame rate of the videos for model training (used for kappa calculation).
202+
syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings
203+
inference_duration=NULL : float # Time duration (seconds) of the inference computation
201204
"""
202205

203206
def make_fetch(self, key):
@@ -261,6 +264,10 @@ def make_compute(
261264
"""
262265
Compute model inference results.
263266
"""
267+
import glob
268+
import os
269+
270+
import cv2
264271
from keypoint_moseq import (
265272
apply_model,
266273
format_data,
@@ -290,8 +297,30 @@ def make_compute(
290297
fullfit_kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
291298
config_path=fullfit_kpms_dj_config_file
292299
)
300+
301+
# calculate average frame rate of all video in keypointset_dir
302+
# Search for multiple common video extensions: mp4, avi, mov, mkv, etc.
303+
video_extensions = [
304+
"*.mp4",
305+
"*.avi",
306+
"*.mov",
307+
"*.mkv",
308+
"*.wmv",
309+
"*.mpeg",
310+
"*.mpg",
311+
]
312+
video_files = []
313+
for ext in video_extensions:
314+
video_files.extend(glob.glob(os.path.join(keypointset_dir, ext)))
315+
frame_rates = []
316+
for video_file in video_files:
317+
cap = cv2.VideoCapture(video_file)
318+
frame_rate = cap.get(cv2.CAP_PROP_FPS)
319+
frame_rates.append(frame_rate)
320+
average_frame_rate = np.mean(frame_rates)
321+
322+
# Load fullfit model
293323
fullfit_model, _, _, _ = load_checkpoint(path=fullfit_checkpoint_path)
294-
# fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb"))
295324

296325
# Load new data
297326
coordinates, confidences, bodyparts = load_keypoints(
@@ -340,22 +369,31 @@ def make_compute(
340369
return (
341370
duration_seconds,
342371
results_filepath,
372+
average_frame_rate,
373+
coordinates,
374+
confidences,
343375
)
344376

345377
def make_insert(
346378
self,
347379
key,
348380
duration_seconds,
349381
results_filepath,
382+
average_frame_rate,
383+
coordinates,
384+
confidences,
350385
):
351386
"""
352387
Insert inference results into the database.
353388
"""
354389
self.insert1(
355390
{
356391
**key,
357-
"inference_duration": duration_seconds,
358392
"syllable_segmentation_file": results_filepath,
393+
"coordinates": coordinates,
394+
"confidences": confidences,
395+
"average_frame_rate": average_frame_rate,
396+
"inference_duration": duration_seconds,
359397
}
360398
)
361399

0 commit comments

Comments
 (0)