From 107f9a5c5126b3e3395ad7d30128ae7c7c1c5169 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 23 Oct 2025 17:56:37 +0200 Subject: [PATCH 01/41] update(kpms_reader): return both base and dj config files with `dj_generate_config` function --- element_moseq/readers/kpms_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 63148ee..7845a83 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -88,7 +88,7 @@ def dj_generate_config(project_dir: str, **kwargs) -> str: with open(dj_cfg_path, "w") as f: yaml.safe_dump(cfg, f, sort_keys=False) - return dj_cfg_path + return dj_cfg_path, base_cfg_path def load_kpms_dj_config( @@ -141,10 +141,10 @@ def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]: ) with open(dj_cfg_path, "r") as f: - cfg = yaml.safe_load(f) or {} + updated_cfg_path = yaml.safe_load(f) or {} - cfg.update(kwargs) + updated_cfg_path.update(kwargs) with open(dj_cfg_path, "w") as f: - yaml.safe_dump(cfg, f, sort_keys=False) - return cfg + yaml.safe_dump(updated_cfg_path, f, sort_keys=False) + return updated_cfg_path From 151f399577d42248b7360f58ccc57aa53ab60716 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 23 Oct 2025 17:58:18 +0200 Subject: [PATCH 02/41] refactor(remove moseq_report to add plots as filepath/attach in train and infer): remove `moseq_report` file --- element_moseq/moseq_report.py | 337 ---------------------------------- 1 file changed, 337 deletions(-) delete mode 100644 element_moseq/moseq_report.py diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py deleted file mode 100644 index 7ba68e1..0000000 --- a/element_moseq/moseq_report.py +++ /dev/null @@ -1,337 +0,0 @@ -import importlib -import inspect -import os -import pathlib -import tempfile -from pathlib import Path - -import datajoint as dj -import matplotlib.pyplot as plt -import numpy as np -from element_interface.utils import find_full_path - -from . import moseq_infer, moseq_train -from .plotting import viz_utils -from .readers import kpms_reader - -schema = dj.schema() -_linking_module = None -logger = dj.logger - - -def activate( - report_schema_name: str, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activate this schema. - - Args: - report_schema_name (str): Schema name on the database server to activate the `moseq_infer` schema. - create_schema (bool): When True (default), create schema in the database if it - does not yet exist. - create_tables (bool): When True (default), create schema tables in the database - if they do not yet exist. - linking_module (str): A module (or name) containing the required dependencies. - """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - assert hasattr( - linking_module, "get_kpms_root_data_dir" - ), "The linking module must specify a lookup function for a root data directory" - - global _linking_module - _linking_module = linking_module - - # activate - schema.activate( - report_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class PreProcessingReport(dj.Imported): - """Store the outlier keypoints plots that are generated in outbox by `moseq_train.PreProcessing`""" - - definition = """ - -> moseq_train.PreProcessing - video_id: int # ID of the matching video file - --- - outlier_plot: attach # A plot of the outlier keypoints - """ - - def make(self, key): - project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir") - kpms_project_output_dir = ( - Path(moseq_train.get_kpms_processed_data_dir()) / project_rel - ) - video_ids, pose_estimation_paths = ( - moseq_train.KeypointSet.VideoFile & key - ).fetch("video_id", "pose_estimation_path") - # Map pose estimation filename (without .h5 extension) to video id - valid_entries = [ - (vid, p) - for vid, p in zip(video_ids, pose_estimation_paths) - if p and p.strip() # Check for non-empty strings - ] - if not valid_entries: - raise ValueError( - "No valid pose_estimation_paths found - all entries are empty" - ) - - posefile2vid = {Path(p).stem: vid for vid, p in valid_entries} - recording_names = list(posefile2vid.keys()) - - if not recording_names: - raise ValueError( - "No recording names found after processing pose estimation paths" - ) - - # Insert one row per recording - for rec in recording_names: - vid = posefile2vid[rec] - - # Look for outlier plot in QA directory - plot_path = ( - kpms_project_output_dir - / "QA" - / "plots" - / "keypoint_distance_outliers" - / f"{rec}.png" - ) - if not plot_path.exists(): - raise FileNotFoundError( - f"Outlier plot not found for {rec} at {plot_path}" - ) - - self.insert1( - { - **key, - "video_id": int(vid), - "outlier_plot": plot_path.as_posix(), - } - ) - - -@schema -class PCAReport(dj.Computed): - """ - Plots the principal components (PCs) from a PCAFit. - """ - - definition = """ - -> moseq_train.LatentDimension - --- - scree_plot: attach # A cumulative scree plot. - pcs_plot: attach # A visualization of each Principal Component (PC). - """ - - def make(self, key): - # Generate and store plots for the user to choose the latent dimensions in the next step - from keypoint_moseq import load_pca - - kpms_project_output_dir = (moseq_train.PCATask & key).fetch1( - "kpms_project_output_dir" - ) - kpms_project_output_dir = ( - moseq_train.get_kpms_processed_data_dir() / kpms_project_output_dir - ) - kpms_dj_config = kpms_reader.load_kpms_dj_config( - project_dir=kpms_project_output_dir - ) - - pca = load_pca(kpms_project_output_dir.as_posix()) - - # Modified version of plot_scree from keypoint_moseq - scree_fig = plt.figure() - num_pcs = len(pca.components_) - plt.plot(np.arange(num_pcs) + 1, np.cumsum(pca.explained_variance_ratio_)) - plt.xlabel("PCs") - plt.ylabel("Explained variance") - plt.gcf().set_size_inches((2.5, 2)) - plt.grid() - plt.tight_layout() - fname = f"{key['kpset_id']}_{key['bodyparts_id']}" - - # Modified version ofplot_pcs from keypoint_moseq to visualize components of PCs - pcs_fig = viz_utils.plot_pcs( - pca, - **kpms_dj_config, - interactive=False, - project_dir=kpms_project_output_dir, - ) - - tmpdir = tempfile.TemporaryDirectory() - - # plot variance summary - scree_path = pathlib.Path(tmpdir.name) / f"{fname}_scree_plot.png" - scree_fig.savefig(scree_path) - - # plot pcs - pcs_path = pathlib.Path(tmpdir.name) / f"{fname}_pcs_plot.png" - pcs_fig.savefig(pcs_path) - - # insert into table - self.insert1({**key, "scree_plot": scree_path, "pcs_plot": pcs_path}) - tmpdir.cleanup() - - -@schema -class PreFitReport(dj.Imported): - definition = """ - -> moseq_train.PreFit - --- - fitting_progress_pdf: attach # fitting_progress.pdf - fitting_progress_png: attach # fitting_progress.png - """ - - def make(self, key): - prefit_model_name = (moseq_train.PreFit & key).fetch1("model_name") - prefit_model_dir = find_full_path( - moseq_train.get_kpms_processed_data_dir(), prefit_model_name - ) - - pdf_path = prefit_model_dir / "fitting_progress.pdf" - png_path = prefit_model_dir / "fitting_progress.png" - - if not pdf_path.exists(): - raise FileNotFoundError( - f"PreFit PDF progress plot not found at {pdf_path}. " - ) - - if not png_path.exists(): - raise FileNotFoundError( - f"PreFit PNG progress plot not found at {png_path}. " - ) - - # Both files exist, insert them - self.insert1( - {**key, "fitting_progress_pdf": pdf_path, "fitting_progress_png": png_path} - ) - - -@schema -class FullFitReport(dj.Imported): - definition = """ - -> moseq_train.FullFit - --- - fitting_progress_pdf: attach # fitting_progress.pdf - fitting_progress_png: attach # fitting_progress.png - """ - - def make(self, key): - fullfit_model_name = (moseq_train.FullFit & key).fetch1("model_name") - fullfit_model_dir = find_full_path( - moseq_train.get_kpms_processed_data_dir(), fullfit_model_name - ) - - pdf_path = fullfit_model_dir / "fitting_progress.pdf" - png_path = fullfit_model_dir / "fitting_progress.png" - - if not pdf_path.exists(): - raise FileNotFoundError( - f"FullFit PDF progress plot not found at {pdf_path}. " - ) - - if not png_path.exists(): - raise FileNotFoundError( - f"FullFit PNG progress plot not found at {png_path}. " - ) - - # Both files exist, insert them - self.insert1( - {**key, "fitting_progress_pdf": pdf_path, "fitting_progress_png": png_path} - ) - - -@schema -class InferenceReport(dj.Imported): - definition = """ - -> moseq_infer.Inference - --- - syllable_frequencies: attach - similarity_dendrogram_png: attach - similarity_dendrogram_pdf: attach - all_trajectories_gif: attach - all_trajectories_pdf: attach - """ - - class Trajectory(dj.Part): - definition = """ - -> master - syllable_id: int - --- - plot_gif: attach - plot_pdf: attach - grid_movie: attach - """ - - def make(self, key): - import imageio - - task_info = (moseq_infer.InferenceTask & key).fetch1() - model = (moseq_infer.Model & {"model_id": task_info["model_id"]}).fetch1() - - model_dir = find_full_path( - moseq_train.get_kpms_processed_data_dir(), model["model_dir"] - ) - output_dir = Path(model_dir) / task_info["inference_output_dir"] - - # Insert per-inference entry - self.insert1( - { - **key, - "syllable_frequencies": output_dir / "syllable_frequencies.png", - "similarity_dendrogram_png": output_dir / "similarity_dendrogram.png", - "similarity_dendrogram_pdf": output_dir / "similarity_dendrogram.pdf", - "all_trajectories_gif": output_dir - / "trajectory_plots" - / "all_trajectories.gif", - "all_trajectories_pdf": output_dir - / "trajectory_plots" - / "all_trajectories.pdf", - } - ) - - # Insert per-syllable visuals - for syllable in (moseq_infer.Inference.GridMoviesSampledInstances & key).fetch( - "syllable" - ): - video_mp4_path = output_dir / "grid_movies" / f"syllable{syllable}.mp4" - video_mp4_to_gif_path = ( - output_dir / "grid_movies" / f"syllable{syllable}_grid_movie.gif" - ) - reader = imageio.get_reader(video_mp4_path) - fps = reader.get_meta_data()["fps"] - writer = imageio.get_writer(video_mp4_to_gif_path, fps=fps, loop=0) - - for frame in reader: - writer.append_data(frame) - - writer.close() - - self.Trajectory.insert1( - { - **key, - "syllable_id": syllable, - "plot_gif": output_dir - / "trajectory_plots" - / f"syllable{syllable}.gif", - "plot_pdf": output_dir - / "trajectory_plots" - / f"syllable{syllable}.pdf", - "grid_movie": video_mp4_to_gif_path, - } - ) From a71a23ed4b4796b5639b9344ff676a700d98b8c7 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 23 Oct 2025 21:48:02 +0200 Subject: [PATCH 03/41] refactor(kpms_reader) --- element_moseq/readers/kpms_reader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 7845a83..464f544 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -141,10 +141,10 @@ def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]: ) with open(dj_cfg_path, "r") as f: - updated_cfg_path = yaml.safe_load(f) or {} + cfg_dict = yaml.safe_load(f) or {} - updated_cfg_path.update(kwargs) + cfg_dict.update(kwargs) with open(dj_cfg_path, "w") as f: - yaml.safe_dump(updated_cfg_path, f, sort_keys=False) - return updated_cfg_path + yaml.safe_dump(cfg_dict, f, sort_keys=False) + return cfg_dict From 9d0e0f15614d9ce7c1960f1417e85acc8679a201 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 24 Oct 2025 00:46:06 +0200 Subject: [PATCH 04/41] refactor(moseq_train): to include plots from `moseq_report` --- element_moseq/moseq_train.py | 659 +++++++++++++++++++++++++++-------- 1 file changed, 519 insertions(+), 140 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 1bad34f..fb4a06f 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -11,10 +11,11 @@ import cv2 import datajoint as dj +import matplotlib.pyplot as plt import numpy as np from element_interface.utils import find_full_path -from .plotting.viz_utils import copy_pdf_to_png +from .plotting import viz_utils from .readers import kpms_reader schema = dj.schema() @@ -169,7 +170,7 @@ class VideoFile(dj.Part): @schema -class Bodyparts(dj.Manual): +class BodyParts(dj.Manual): """Store the body parts to use in the analysis. Attributes: @@ -198,19 +199,18 @@ class PCATask(dj.Manual): Define the Principal Component Analysis (PCA) task for dimensionality reduction of keypoint data. Attributes: - Bodyparts (foreign key) : Unique ID for each `Bodyparts` key - outlier_scale_factor (int) : Scale factor for outlier detection in keypoint data (default: 6) + BodyParts (foreign key) : Unique ID for each `BodyParts` key + outlier_scale_factor (float) : Scale factor for outlier detection in keypoint data (default: 6) kpms_project_output_dir (str) : Optional. Keypoint-MoSeq project output directory, relative to root data directory task_mode (enum) : 'load' to load existing results, 'trigger' to compute new PCA """ definition = """ - -> Bodyparts # Unique ID for each `Bodyparts` key + -> BodyParts # Unique ID for each `BodyParts` key --- - outlier_scale_factor=6 : int # Scale factor for outlier detection in keypoint data (default: 6) + outlier_scale_factor=6 : float # Scale factor for outlier detection in keypoint data (default: 6) kpms_project_output_dir='' : varchar(255) # Optional. Keypoint-MoSeq project output directory, relative to root data directory task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA - """ @@ -234,6 +234,8 @@ class PreProcessing(dj.Computed): confidences : longblob # Dictionary mapping filenames to likelihood scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). + pre_processing_time : datetime # datetime of the preprocessing execution. + pre_processing_duration : int # Execution time of the preprocessing in seconds. """ class Video(dj.Part): @@ -243,28 +245,36 @@ class Video(dj.Part): --- video_duration : int # Duration of each video in minutes frame_rate : float # Frame rate of the video in frames per second (Hz) + outlier_plot=NULL : attach # Plot of the outlier keypoints + """ + + class ConfigFile(dj.Part): + """ + Store the configuration files (first creation of the config file and the updates after processing). + """ + + definition = """ + -> master + --- + base_config_file: attach # the first creation of the config file + updated_config_file: attach # the updated config file after processing """ def make_fetch(self, key): """ Fetch required data for preprocessing from database tables. """ - anterior_bodyparts, posterior_bodyparts, use_bodyparts = ( - Bodyparts & key + BodyParts & key ).fetch1( "anterior_bodyparts", "posterior_bodyparts", "use_bodyparts", ) - pose_estimation_method, kpset_dir = (KeypointSet & key).fetch1( "pose_estimation_method", "kpset_dir" ) - video_paths, video_ids = (KeypointSet.VideoFile & key).fetch( - "video_path", "video_id" - ) - + keypoint_videofile_metadata = (KeypointSet.VideoFile & key).fetch(as_dict=True) kpms_project_output_dir, task_mode, outlier_scale_factor = ( PCATask & key ).fetch1("kpms_project_output_dir", "task_mode", "outlier_scale_factor") @@ -275,8 +285,7 @@ def make_fetch(self, key): use_bodyparts, pose_estimation_method, kpset_dir, - video_paths, - video_ids, + keypoint_videofile_metadata, kpms_project_output_dir, task_mode, outlier_scale_factor, @@ -290,8 +299,7 @@ def make_compute( use_bodyparts, pose_estimation_method, kpset_dir, - video_paths, - video_ids, + keypoint_videofile_metadata, kpms_project_output_dir, task_mode, outlier_scale_factor, @@ -335,6 +343,7 @@ def make_compute( plot_medoid_distance_outliers, ) + execution_time = datetime.now(timezone.utc) if task_mode == "trigger": from keypoint_moseq import setup_project @@ -349,74 +358,79 @@ def make_compute( kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) videos_dir = find_full_path( - get_kpms_root_data_dir(), Path(video_paths[0]).parent + get_kpms_root_data_dir(), + Path(keypoint_videofile_metadata[0]["video_path"]).parent, ) - if pose_estimation_method == "deeplabcut": from .readers.kpms_reader import _base_config_path - cfg_path = _base_config_path(kpset_dir) - cfg = Path(cfg_path) - if not cfg.exists(): - raise FileNotFoundError( - f"No DLC config.(yml|yaml) found in {kpset_dir}" - ) + # Find pose estimation config file + base_config_file = _base_config_path(kpset_dir) + base_config_file = Path(base_config_file) + if not base_config_file.exists(): + raise FileNotFoundError(f"No DLC config file found in {kpset_dir}") + # Create the kpms output diectory and config files setup_project( project_dir=kpms_project_output_dir.as_posix(), - deeplabcut_config=cfg.as_posix(), + deeplabcut_config=base_config_file.as_posix(), ) - else: raise NotImplementedError( "Currently, `deeplabcut` is the only pose estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method." ) - else: kpms_project_output_dir = find_full_path( get_kpms_processed_data_dir(), kpms_project_output_dir ) kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) videos_dir = find_full_path( - get_kpms_root_data_dir(), Path(video_paths[0]).parent + get_kpms_root_data_dir(), + Path(keypoint_videofile_metadata[0]["video_path"]).parent, ) + # Format keypoint data raw_coordinates, raw_confidences, formatted_bodyparts = load_keypoints( filepath_pattern=kpset_dir, format=pose_estimation_method ) - video_metadata_list = [] + # Extract frame rate from keypoint video files + video_metadata_dict = dict() frame_rates = [] - for fp, video_id in zip(video_paths, video_ids): - video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() + for row in keypoint_videofile_metadata: + video_id = int(row["video_id"]) + video_path = find_full_path( + get_kpms_root_data_dir(), row["video_path"] + ).as_posix() cap = cv2.VideoCapture(video_path) fps = float(cap.get(cv2.CAP_PROP_FPS)) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() - duration_minutes = (frame_count / fps) / 60.0 - frame_rates.append(fps) - - # Get video name for the Video part table - video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} - if KeypointSet.VideoFile & video_key: - video_record = (KeypointSet.VideoFile & video_key).fetch1() - video_name = Path( - video_record["video_path"] - ).stem # Get filename without extension - video_metadata_list.append( - { - "video_id": video_id, - "video_name": video_name, - "video_duration": int(duration_minutes), - "frame_rate": fps, - } + if fps <= 0: + raise ValueError( + f"Invalid FPS ({fps}) for video_id {video_id} at {video_path}" ) - else: - logger.warning(f"Video record not found for video_id {video_id}") - + duration_minutes = int((frame_count / fps) / 60.0) + frame_rates.append(fps) + video_metadata_dict[video_id] = { + "video_name": Path(row["video_path"]).stem, + "video_duration": duration_minutes, + "frame_rate": fps, + "outlier_plot": None, + } average_frame_rate = float(np.mean(frame_rates)) - # Generate a copy of config.yml with the generated/updated info after it is known - kpms_reader.dj_generate_config( + # Generate a copy of pose estimation config file + videos_dir = find_full_path( + get_kpms_root_data_dir(), + Path(keypoint_videofile_metadata[0]["video_path"]).parent, + ) + + # Confirm that `use_bodyparts` are a subset of `formatted_bodyparts` + if not set(use_bodyparts).issubset(set(formatted_bodyparts)): + raise ValueError( + f"use_bodyparts ({use_bodyparts}) is not a subset of formatter bodyparts ({formatted_bodyparts})" + ) + base_config, _ = kpms_reader.dj_generate_config( project_dir=kpms_project_output_dir, video_dir=str(videos_dir), use_bodyparts=list(use_bodyparts), @@ -424,64 +438,78 @@ def make_compute( posterior_bodyparts=list(posterior_bodyparts), outlier_scale_factor=float(outlier_scale_factor), ) + # Update the config file content kpms_reader.update_kpms_dj_config( - kpms_project_output_dir, - fps=average_frame_rate, + kpms_project_output_dir, fps=average_frame_rate ) - # Remove outlier keypoints - kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + # load the udpated config file + kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + + # Clean outlier keypoints and generate plots cleaned_coordinates = {} cleaned_confidences = {} - for recording_name in raw_coordinates: - raw_coords = raw_coordinates[recording_name].copy() - raw_conf = raw_confidences[recording_name].copy() + for row in keypoint_videofile_metadata: + video_id = int(row["video_id"]) + pose_estimation_path = row["pose_estimation_path"] + pose_estimation_name = Path(pose_estimation_path).stem + + raw_coords = raw_coordinates[pose_estimation_name].copy() + raw_conf = raw_confidences[pose_estimation_name].copy() - # Find outliers using medoid distance analysis outliers = find_medoid_distance_outliers( raw_coords, outlier_scale_factor=outlier_scale_factor ) - - # Interpolate keypoints to fix outliers cleaned_coords = interpolate_keypoints(raw_coords, outliers["mask"]) - - # Update confidences for outlier points cleaned_conf = np.where(outliers["mask"], 0, raw_conf) - cleaned_coordinates[recording_name] = cleaned_coords - cleaned_confidences[recording_name] = cleaned_conf + # Keep keys as pose file names for downstream format_data + cleaned_coordinates[pose_estimation_name] = cleaned_coords + cleaned_confidences[pose_estimation_name] = cleaned_conf - # Plot outliers - if formatted_bodyparts is not None: - plot_medoid_distance_outliers( - project_dir=kpms_project_output_dir.as_posix(), - recording_name=recording_name, - original_coordinates=raw_coords, - interpolated_coordinates=cleaned_coords, - outlier_mask=outliers["mask"], - outlier_thresholds=outliers["thresholds"], - **kpms_config, + plot_medoid_distance_outliers( + project_dir=kpms_project_output_dir.as_posix(), + recording_name=pose_estimation_name, + original_coordinates=raw_coords, + interpolated_coordinates=cleaned_coords, + outlier_mask=outliers["mask"], + outlier_thresholds=outliers["thresholds"], + **kpms_dj_config, + ) + plot_file_path = ( + kpms_project_output_dir + / "QA" + / "plots" + / "keypoint_distance_outliers" + / f"{pose_estimation_name}.png" + ) + if not plot_file_path.exists(): + raise FileNotFoundError( + f"Outlier plot file not found at {plot_file_path}" ) + video_metadata_dict[video_id]["outlier_plot"] = plot_file_path.as_posix() + + # path to the config files + base_config_filepath = kpms_reader._base_config_path(kpms_project_output_dir) + updated_config_filepath = kpms_reader._dj_config_path(kpms_project_output_dir) + + completion_time = datetime.now(timezone.utc) + if task_mode == "trigger": + duration_seconds = (completion_time - execution_time).total_seconds() + else: + duration_seconds = None - # Check if outlier plot was created in QA directory - plot_path = ( - kpms_project_output_dir - / "QA" - / "plots" - / "keypoint_distance_outliers" - / f"{recording_name}.png" - ) - if not plot_path.exists(): - raise FileNotFoundError( - f"Could not create outlier plot for {recording_name} at {plot_path}" - ) return ( cleaned_coordinates, cleaned_confidences, formatted_bodyparts, average_frame_rate, - video_metadata_list, + video_metadata_dict, + updated_config_filepath, + base_config_filepath, + execution_time, + duration_seconds, ) def make_insert( @@ -491,32 +519,53 @@ def make_insert( cleaned_confidences, formatted_bodyparts, average_frame_rate, - video_metadata_list, + video_metadata_dict, + updated_config_filepath, + base_config_filepath, + execution_time, + duration_seconds, ): """ - Insert processed data into the PreProcessing table and Video part table. + Insert processed data into the PreProcessing table and part tables. """ + # Insert in the main table self.insert1( - dict( + { **key, - coordinates=cleaned_coordinates, - confidences=cleaned_confidences, - formatted_bodyparts=formatted_bodyparts, - average_frame_rate=average_frame_rate, - ) + "coordinates": cleaned_coordinates, + "confidences": cleaned_confidences, + "formatted_bodyparts": formatted_bodyparts, + "average_frame_rate": average_frame_rate, + "pre_processing_time": execution_time, + "pre_processing_duration": duration_seconds, + } ) - for video_metadata in video_metadata_list: - self.Video.insert1( - dict( - **key, - video_name=video_metadata["video_name"], - video_duration=video_metadata["video_duration"], - frame_rate=video_metadata["frame_rate"], - ) + # Insert video metadata in Video table + if video_metadata_dict: + self.Video.insert( + [ + { + **key, + "video_name": meta["video_name"], + "video_duration": meta["video_duration"], + "frame_rate": meta["frame_rate"], + "outlier_plot": meta["outlier_plot"], + } + for vid, meta in video_metadata_dict.items() + ] ) + # Insert configuration files + self.ConfigFile.insert1( + { + **key, + "base_config_file": base_config_filepath, + "updated_config_file": updated_config_filepath, + } + ) + @schema class PCAFit(dj.Computed): @@ -525,14 +574,28 @@ class PCAFit(dj.Computed): Attributes: PreProcessing (foreign key) : `PreProcessing` Key. pca_fit_time (datetime) : datetime of the PCA fitting analysis. + pca_fit_duration (int) : Execution time of the PCA fitting analysis in seconds. """ definition = """ -> PreProcessing # `PreProcessing` Key --- pca_fit_time=NULL : datetime # datetime of the PCA fitting analysis + pca_fit_duration=NULL : int # Execution time of the PCA fitting analysis in seconds. """ + class File(dj.Part): + """ + Store the PCA files (pca.p file). + """ + + definition = """ + -> master + --- + file_name : varchar(1000) # name of the pca file (e.g. 'pca.p'). + file_path : filepath@moseq-train-processed # path to the pca file (relative to the project output directory). + """ + def make(self, key): """ Format keypoint data and fit PCA model for dimensionality reduction. @@ -546,6 +609,8 @@ def make(self, key): 3. Fit PCA model and save as `pca.p` file. 4. Insert creation datetime into table. """ + import tempfile + from keypoint_moseq import fit_pca, format_data, save_pca kpms_project_output_dir, task_mode = (PCATask & key).fetch1( @@ -554,27 +619,63 @@ def make(self, key): kpms_project_output_dir = ( Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - - kpms_default_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) + + # Load the configuration from the file + kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + + execution_time = datetime.now(timezone.utc) + + # Format keypoint data data, _ = format_data( - **kpms_default_config, coordinates=coordinates, confidences=confidences + **kpms_config, coordinates=coordinates, confidences=confidences ) + # Fit PCA model and save as `pca.p` file if task_mode == "trigger": - pca = fit_pca(**data, **kpms_default_config) + pca = fit_pca(**data, **kpms_config) save_pca(pca, kpms_project_output_dir.as_posix()) - creation_datetime = datetime.now(timezone.utc) + + # Check for pca.p file + pca_p_file = kpms_project_output_dir / "pca.p" + + if not pca_p_file.exists(): + raise FileNotFoundError( + f"No pca file (`pca.p`) found in the project directory {kpms_project_output_dir}" + ) + + completion_time = datetime.now(timezone.utc) + + if task_mode == "trigger": + duration_seconds = (completion_time - execution_time).total_seconds() else: - creation_datetime = None + duration_seconds = None - self.insert1(dict(**key, pca_fit_time=creation_datetime)) + # Insert in the main table + self.insert1( + { + **key, + "pca_fit_time": execution_time, + "pca_fit_duration": duration_seconds, + } + ) + + # Insert in File table + self.File.insert( + [ + { + **key, + "file_name": pca_p_file.name, + "file_path": pca_p_file, + } + ] + ) @schema -class LatentDimension(dj.Imported): +class LatentDimension(dj.Computed): """ Determine the optimal latent dimension for model fitting based on variance explained by PCA components. @@ -593,6 +694,19 @@ class LatentDimension(dj.Imported): latent_dim_desc : varchar(1000) # Automated description of the computation result. """ + class Plots(dj.Part): + """ + Store the PCA visualization plots. + """ + + definition = """ + -> master + --- + scree_plot: attach # A cumulative scree plot showing explained variance + pcs_plot: attach # A visualization of each Principal Component (PC) + pcs_xy_plot: attach # A visualization of the Principal Components (PCs) in the XY plane + """ + def make(self, key): """ Compute and store the optimal latent dimension based on 90% variance threshold. @@ -619,14 +733,9 @@ def make(self, key): Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - pca_path = kpms_project_output_dir / "pca.p" - if pca_path.exists(): - pca = load_pca(kpms_project_output_dir.as_posix()) - else: - raise FileNotFoundError( - f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" - ) - + # Fetch PCA file path from upstream PCAFit.File table + pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path") + pca = load_pca(Path(pca_path).parent.as_posix()) cs = np.cumsum( pca.explained_variance_ratio_ ) # explained_variance_ratio_ndarray of shape (n_components,) @@ -642,6 +751,48 @@ def make(self, key): variance_percentage = VARIANCE_THRESHOLD * 100 latent_dim_desc = f">={VARIANCE_THRESHOLD*100}% of variance explained by {(cs>VARIANCE_THRESHOLD).nonzero()[0].min()+1} components." + # Load the configuration from the file + kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + + # Generate scree plot + scree_fig = plt.figure() + num_pcs = len(pca.components_) + plt.plot(np.arange(num_pcs) + 1, np.cumsum(pca.explained_variance_ratio_)) + plt.xlabel("PCs") + plt.ylabel("Explained variance") + plt.gcf().set_size_inches((2.5, 2)) + plt.grid() + plt.tight_layout() + + # Generate PCs plot + pcs_fig = viz_utils.plot_pcs( + pca, + **kpms_config, + interactive=False, + project_dir=kpms_project_output_dir, + ) + + # Load the pcs-xy.pdf file + pcs_xy_file = kpms_project_output_dir / "pcs-xy.pdf" + + if not pcs_xy_file.exists(): + raise FileNotFoundError( + f"No pcs xy file (`pcs-xy.pdf`) found in the project directory {kpms_project_output_dir}" + ) + + # Save plots to temporary directory + import tempfile + + tmpdir = tempfile.TemporaryDirectory() + fname = f"{key['kpset_id']}_{key['bodyparts_id']}" + + scree_path = Path(tmpdir.name) / f"{fname}_scree_plot.png" + scree_fig.savefig(scree_path) + + pcs_path = Path(tmpdir.name) / f"{fname}_pcs_plot.png" + pcs_fig.savefig(pcs_path) + + # Insert main results self.insert1( dict( **key, @@ -651,6 +802,18 @@ def make(self, key): ) ) + # Insert plots + self.Plots.insert1( + { + **key, + "scree_plot": scree_path, + "pcs_plot": pcs_path, + "pcs_xy_plot": pcs_xy_file, + } + ) + + tmpdir.cleanup() + @schema class PreFitTask(dj.Manual): @@ -692,9 +855,47 @@ class PreFit(dj.Computed): -> PreFitTask # `PreFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + pre_fit_time=NULL : datetime # datetime of the model fitting computation. pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ + class ConfigFile(dj.Part): + """ + Store the updated configuration file after PreFit computation. + """ + + definition = """ + -> master + --- + updated_config_file: attach # Updated config file after PreFit computation + """ + + class CheckpointFile(dj.Part): + """ + Store the checkpoint file used for resuming the fitting process. + """ + + definition = """ + -> master + --- + checkpoint_file_name: varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). + checkpoint_file: filepath@moseq-train-processed # path to the checkpoint file. + """ + + class Plots(dj.Part): + """ + Store the fitting progress of the PreFit computation: + - Plots in PDF and PNG formats used for visualization. + - Checkpoint file used for resuming the fitting process (~500MB). + """ + + definition = """ + -> master + --- + fitting_progress_plot_png: attach + fitting_progress_plot_pdf: attach + """ + def make(self, key): """ Fit AR-HMM model for initial behavioral syllable discovery. @@ -747,6 +948,7 @@ def make(self, key): # Load the updated config for use in model fitting kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + # Load the PCA model from the project directory pca_path = kpms_project_output_dir / "pca.p" if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) @@ -755,6 +957,7 @@ def make(self, key): f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) + # Format the data for model fitting coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) @@ -762,6 +965,7 @@ def make(self, key): coordinates=coordinates, confidences=confidences, **kpms_dj_config ) + # Update the kpms_dj_config.yml with the new sigmasq_loc kpms_reader.update_kpms_dj_config( kpms_project_output_dir, sigmasq_loc=estimate_sigmasq_loc( @@ -769,22 +973,30 @@ def make(self, key): ), ) + # Load the updated config for use in model fitting kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) + # Initialize the model model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) + # Update the model hyperparameters model = update_hypparams( model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim) ) - model_name_str = f"latent_dim_{int(pre_latent_dim)}_kappa_{float(pre_kappa)}_iters_{int(pre_num_iterations)}" + # Determine model directory name for outputs + if model_name is None or not str(model_name).strip(): + model_dir_name = f"latent_dim_{float(pre_latent_dim)}_kappa_{float(pre_kappa)}_iters_{float(pre_num_iterations)}" + else: + model_dir_name = str(model_name) - start_time = datetime.now(timezone.utc) + execution_time = datetime.now(timezone.utc) + # Fit the model model, model_name = fit_model( model=model, - model_name=model_name_str, + model_name=model_dir_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), @@ -793,11 +1005,44 @@ def make(self, key): generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ save_every_n_iters=25, ) - end_time = datetime.now(timezone.utc) - duration_seconds = (end_time - start_time).total_seconds() - copy_pdf_to_png(kpms_project_output_dir, model_name) + # Normalize to folder name returned by fit_model + model_dir_name = Path(model_name).name + + # Copy the PDF progress plot to PNG + viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_dir_name) + + else: + # Load mode must specify a model_name + if model_name is None or not str(model_name).strip(): + raise ValueError("model_name is required when task_mode='load'") + model_dir_name = Path(model_name).name + + # Get the path to the updated config file + updated_cfg_path = (kpms_project_output_dir / "kpms_dj_config.yml").as_posix() + + # Check for fitting progress files + prefit_model_dir = kpms_project_output_dir / model_dir_name + pdf_path = prefit_model_dir / "fitting_progress.pdf" + png_path = prefit_model_dir / "fitting_progress.png" + if not pdf_path.exists(): + raise FileNotFoundError(f"PreFit PDF progress plot not found at {pdf_path}") + if not png_path.exists(): + raise FileNotFoundError(f"PreFit PNG progress plot not found at {png_path}") + + # Find checkpoint file + checkpoint_files = [] + for pattern in ("checkpoint*", "*.h5"): + checkpoint_files.extend(prefit_model_dir.glob(pattern)) + if checkpoint_files: + checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) + else: + raise FileNotFoundError(f"No checkpoint files found in {prefit_model_dir}") + completion_time = datetime.now(timezone.utc) + + if task_mode == "trigger": + duration_seconds = (completion_time - execution_time).total_seconds() else: duration_seconds = None @@ -806,12 +1051,35 @@ def make(self, key): **key, "model_name": ( kpms_project_output_dir.relative_to(get_kpms_processed_data_dir()) - / model_name + / model_dir_name ).as_posix(), "pre_fit_duration": duration_seconds, } ) + self.ConfigFile.insert1( + dict( + **key, + updated_config_file=updated_cfg_path, + ) + ) + + self.Plots.insert1( + { + **key, + "fitting_progress_plot_png": png_path, + "fitting_progress_plot_pdf": pdf_path, + } + ) + + self.CheckpointFile.insert1( + { + **key, + "checkpoint_file_name": checkpoint_file.name, + "checkpoint_file": checkpoint_file, + } + ) + @schema class FullFitTask(dj.Manual): @@ -853,9 +1121,46 @@ class FullFit(dj.Computed): -> FullFitTask # `FullFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + full_fit_time=NULL : datetime # datetime of the full fitting computation. full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation """ + class ConfigFile(dj.Part): + """ + Store the updated configuration file after FullFit computation. + """ + + definition = """ + -> master + --- + updated_config_file: attach # the updated config file after FullFit computation + """ + + class CheckpointFile(dj.Part): + """ + Store the checkpoint file used for resuming the fitting process. + """ + + definition = """ + -> master + --- + checkpoint_file_name: varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). + checkpoint_file: filepath@moseq-train-processed # path to the checkpoint file. + """ + + class Plots(dj.Part): + """ + Store the fitting progress of the FullFit computation: + - Plots in PDF and PNG formats used for visualization. + """ + + definition = """ + -> master + --- + fitting_progress_plot_png: attach + fitting_progress_plot_pdf: attach + """ + def make(self, key): """ Fit the complete Keypoint-SLDS model with spatial and temporal dynamics. @@ -888,7 +1193,6 @@ def make(self, key): get_kpms_processed_data_dir(), (PCATask & key).fetch1("kpms_project_output_dir"), ) - full_latent_dim, full_kappa, full_num_iterations, task_mode, model_name = ( FullFitTask & key ).fetch1( @@ -898,17 +1202,22 @@ def make(self, key): "task_mode", "model_name", ) + if task_mode == "trigger": + + # Update the kpms_dj_config.yml with latent dimension and kappa values kpms_reader.update_kpms_dj_config( project_dir=kpms_project_output_dir, latent_dim=int(full_latent_dim), kappa=float(full_kappa), ) + # Load the updated config for data formatting kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) + # Load the PCA model pca_path = kpms_project_output_dir / "pca.p" if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) @@ -917,12 +1226,15 @@ def make(self, key): f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) + # Format the data for model fitting coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) data, metadata = format_data( coordinates=coordinates, confidences=confidences, **kpms_dj_config ) + + # Update the kpms_dj_config.yml with the new sigmasq_loc kpms_reader.update_kpms_dj_config( project_dir=kpms_project_output_dir, sigmasq_loc=estimate_sigmasq_loc( @@ -930,21 +1242,29 @@ def make(self, key): ), ) + # Load the updated config for use in model fitting kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) + # Initialize the model model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) + # Update the model hyperparameters model = update_hypparams( model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) ) + # Generate the model directory name + if model_name is None or not str(model_name).strip(): + model_dir_name = f"latent_dim_{float(full_latent_dim)}_kappa_{float(full_kappa)}_iters_{float(full_num_iterations)}" + else: + model_dir_name = str(model_name) - model_name_str = f"latent_dim_{int(full_latent_dim)}_kappa_{float(full_kappa)}_iters_{int(full_num_iterations)}" + execution_time = datetime.now(timezone.utc) - start_time = datetime.now(timezone.utc) + # Fit the model model, model_name = fit_model( model=model, - model_name=model_name_str, + model_name=model_dir_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), @@ -953,15 +1273,47 @@ def make(self, key): generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ save_every_n_iters=25, ) - end_time = datetime.now(timezone.utc) - duration_seconds = (end_time - start_time).total_seconds() + # Reindex the syllables in the checkpoint file reindex_syllables_in_checkpoint( project_dir=kpms_project_output_dir.as_posix(), - model_name=Path(model_name).parts[-1], + model_name=Path(model_name).name, ) - copy_pdf_to_png(kpms_project_output_dir, model_name) + # Copy the PDF progress plot to PNG + viz_utils.copy_pdf_to_png(kpms_project_output_dir, Path(model_name).name) + + # Get the path to the updated config file + updated_cfg_path = kpms_reader._dj_config_path(kpms_project_output_dir) + + # Get the path to the full fit model directory + fullfit_model_dir = kpms_project_output_dir / Path(model_name).name + + # Check for progress plot files + pdf_path = fullfit_model_dir / "fitting_progress.pdf" + png_path = fullfit_model_dir / "fitting_progress.png" + if not pdf_path.exists(): + raise FileNotFoundError( + f"FullFit PDF progress plot not found at {pdf_path}" + ) + if not png_path.exists(): + raise FileNotFoundError( + f"FullFit PNG progress plot not found at {png_path}" + ) + + # Find checkpoint file + checkpoint_files = [] + for pattern in ("checkpoint*", "*.h5"): + checkpoint_files.extend(fullfit_model_dir.glob(pattern)) + if checkpoint_files: + checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) + else: + raise FileNotFoundError(f"No checkpoint files found in {fullfit_model_dir}") + + completion_time = datetime.now(timezone.utc) + + if task_mode == "trigger": + duration_seconds = (completion_time - execution_time).total_seconds() else: duration_seconds = None @@ -970,12 +1322,39 @@ def make(self, key): **key, "model_name": ( kpms_project_output_dir.relative_to(get_kpms_processed_data_dir()) - / model_name + / Path(model_name).name ).as_posix(), + "full_fit_time": completion_time, "full_fit_duration": duration_seconds, } ) + # Insert config file + self.ConfigFile.insert1( + { + **key, + "updated_config_file": updated_cfg_path, + } + ) + + # Insert plots + self.Plots.insert1( + { + **key, + "fitting_progress_plot_png": png_path, + "fitting_progress_plot_pdf": pdf_path, + } + ) + + # Insert checkpoint file + self.CheckpointFile.insert1( + { + **key, + "checkpoint_file_name": checkpoint_file.name, + "checkpoint_file": checkpoint_file, + } + ) + @schema class SelectedFullFit(dj.Manual): From 0390ee9a8cab6e2ee430a762b208b244a2723461 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 24 Oct 2025 00:46:52 +0200 Subject: [PATCH 05/41] update(tutorial_pipeline): to remove `moseq_report` --- notebooks/tutorial_pipeline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index 0f86f32..5e98937 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -4,7 +4,7 @@ from element_lab import lab from element_animal import subject from element_session import session_with_datetime as session -from element_moseq import moseq_train, moseq_infer, moseq_report +from element_moseq import moseq_train, moseq_infer from element_animal.subject import Subject from element_lab.lab import Source, Lab, Protocol, User, Project @@ -44,7 +44,6 @@ def get_kpms_processed_data_dir() -> str: "session", "moseq_train", "moseq_infer", - "moseq_report", "Device", ] @@ -88,4 +87,3 @@ class Device(dj.Lookup): moseq_train.activate(db_prefix + "moseq_train", linking_module=__name__) moseq_infer.activate(db_prefix + "moseq_infer", linking_module=__name__) -moseq_report.activate(db_prefix + "moseq_report", linking_module=__name__) From 672055b21f0da4fbde2b644ee4c7682328374ffc Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 24 Oct 2025 00:47:29 +0200 Subject: [PATCH 06/41] refactor(moseq_infer): to include plots from `moseq_report` --- element_moseq/moseq_infer.py | 149 ++++++++++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 0919dc4..f6a7fae 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -162,7 +162,7 @@ class Inference(dj.Computed): definition = """ -> InferenceTask # `InferenceTask` key --- - syllable_segmentation_file : attach # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings + syllable_segmentation_file : attach # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings inference_duration=NULL : float # Time duration (seconds) of the inference computation """ @@ -205,6 +205,42 @@ class GridMoviesSampledInstances(dj.Part): instances: longblob # List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame """ + class InferencePlots(dj.Part): + """Store the main inference plots. + + Attributes: + InferenceTask (foreign key) : `InferenceTask` key. + plot_name (varchar) : Name of the plot. + plot_file (attach) : File path of the plot. + """ + + definition = """ + -> master + plot_name: varchar(150) # Name of the plot (e.g. syllable_frequencies, similarity_dendrogram_png, similarity_dendrogram_pdf, all_trajectories_gif, all_trajectories_pdf) + --- + plot_file: attach # File path of the plot + """ + + class TrajectoryPlots(dj.Part): + """Store the per-syllable trajectory plots. + + Attributes: + InferenceTask (foreign key) : `InferenceTask` key. + syllable_id (int) : Syllable ID. + plot_gif (attach) : GIF plot file. + plot_pdf (attach) : PDF plot file. + grid_movie (attach) : Grid movie file. + """ + + definition = """ + -> master + syllable_id: int # Syllable ID + --- + plot_gif: attach # GIF plot file + plot_pdf: attach # PDF plot file + grid_movie: attach # Grid movie file + """ + def make_fetch(self, key): """ Fetch data required for model inference. @@ -387,7 +423,7 @@ def make_compute( ) else: - + # For load mode # load results results = load_results( project_dir=inference_output_dir.parent, @@ -444,10 +480,107 @@ def make_compute( {**key, "syllable": syllable, "instances": sampled_instance} ) + # Prepare inference plots data + inference_plots_data = [] + if task_mode == "trigger": + # Main plots generated during trigger mode + inference_plots_data = [ + { + **key, + "plot_name": "syllable_frequencies", + "plot_file": inference_output_dir / "syllable_frequencies.png", + }, + { + **key, + "plot_name": "similarity_dendrogram_png", + "plot_file": inference_output_dir / "similarity_dendrogram.png", + }, + { + **key, + "plot_name": "similarity_dendrogram_pdf", + "plot_file": inference_output_dir / "similarity_dendrogram.pdf", + }, + { + **key, + "plot_name": "all_trajectories_gif", + "plot_file": inference_output_dir + / "trajectory_plots" + / "all_trajectories.gif", + }, + { + **key, + "plot_name": "all_trajectories_pdf", + "plot_file": inference_output_dir + / "trajectory_plots" + / "all_trajectories.pdf", + }, + ] + else: + # For load mode, check if files exist + main_plots = [ + ("syllable_frequencies", "syllable_frequencies.png"), + ("similarity_dendrogram_png", "similarity_dendrogram.png"), + ("similarity_dendrogram_pdf", "similarity_dendrogram.pdf"), + ("all_trajectories_gif", "trajectory_plots/all_trajectories.gif"), + ("all_trajectories_pdf", "trajectory_plots/all_trajectories.pdf"), + ] + + for plot_name, file_path in main_plots: + full_path = inference_output_dir / file_path + if full_path.exists(): + inference_plots_data.append( + { + **key, + "plot_name": plot_name, + "plot_file": full_path, + } + ) + + # Prepare trajectory plots data + trajectory_plots_data = [] + for syllable in sampled_instances.keys(): + syllable_gif = ( + inference_output_dir / "trajectory_plots" / f"syllable{syllable}.gif" + ) + syllable_pdf = ( + inference_output_dir / "trajectory_plots" / f"syllable{syllable}.pdf" + ) + grid_movie_mp4 = ( + inference_output_dir / "grid_movies" / f"syllable{syllable}.mp4" + ) + grid_movie_gif = ( + inference_output_dir + / "grid_movies" + / f"syllable{syllable}_grid_movie.gif" + ) + + # Convert MP4 to GIF if needed (from InferenceReport logic) + if grid_movie_mp4.exists() and not grid_movie_gif.exists(): + import imageio + + reader = imageio.get_reader(grid_movie_mp4) + fps = reader.get_meta_data()["fps"] + writer = imageio.get_writer(grid_movie_gif, fps=fps, loop=0) + for frame in reader: + writer.append_data(frame) + writer.close() + + trajectory_plots_data.append( + { + **key, + "syllable_id": syllable, + "plot_gif": syllable_gif if syllable_gif.exists() else None, + "plot_pdf": syllable_pdf if syllable_pdf.exists() else None, + "grid_movie": grid_movie_gif if grid_movie_gif.exists() else None, + } + ) + return ( duration_seconds, motion_sequence_data, grid_movie_data, + inference_plots_data, + trajectory_plots_data, inference_output_dir, ) @@ -457,6 +590,8 @@ def make_insert( duration_seconds, motion_sequence_data, grid_movie_data, + inference_plots_data, + trajectory_plots_data, inference_output_dir, ): """ @@ -477,3 +612,13 @@ def make_insert( for grid_record in grid_movie_data: self.GridMoviesSampledInstances.insert1(grid_record) + + for plot_record in inference_plots_data: + self.InferencePlots.insert1(plot_record) + + for trajectory_record in trajectory_plots_data: + if any( + trajectory_record.get(field) + for field in ["plot_gif", "plot_pdf", "grid_movie"] + ): + self.TrajectoryPlots.insert1(trajectory_record) From 116c31cc432f30fd31bb377a6ae9f2814b5bf2e8 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 24 Oct 2025 01:13:06 +0200 Subject: [PATCH 07/41] minor update --- element_moseq/moseq_train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index fb4a06f..7db25e9 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -463,8 +463,6 @@ def make_compute( ) cleaned_coords = interpolate_keypoints(raw_coords, outliers["mask"]) cleaned_conf = np.where(outliers["mask"], 0, raw_conf) - - # Keep keys as pose file names for downstream format_data cleaned_coordinates[pose_estimation_name] = cleaned_coords cleaned_confidences[pose_estimation_name] = cleaned_conf From 9aedae0ea66e99dcd0844331faeae5b5ef8b7b98 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 24 Oct 2025 03:01:35 +0200 Subject: [PATCH 08/41] chore: rename new attrib from `updated_config_file` to `config_file` --- element_moseq/moseq_train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 7db25e9..e3a1384 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -257,7 +257,7 @@ class ConfigFile(dj.Part): -> master --- base_config_file: attach # the first creation of the config file - updated_config_file: attach # the updated config file after processing + config_file: attach # the updated config file after processing """ def make_fetch(self, key): @@ -560,7 +560,7 @@ def make_insert( { **key, "base_config_file": base_config_filepath, - "updated_config_file": updated_config_filepath, + "config_file": updated_config_filepath, } ) @@ -865,7 +865,7 @@ class ConfigFile(dj.Part): definition = """ -> master --- - updated_config_file: attach # Updated config file after PreFit computation + config_file: attach # Updated config file after PreFit computation """ class CheckpointFile(dj.Part): @@ -1058,7 +1058,7 @@ def make(self, key): self.ConfigFile.insert1( dict( **key, - updated_config_file=updated_cfg_path, + config_file=updated_cfg_path, ) ) @@ -1131,7 +1131,7 @@ class ConfigFile(dj.Part): definition = """ -> master --- - updated_config_file: attach # the updated config file after FullFit computation + config_file: attach # the updated config file after FullFit computation """ class CheckpointFile(dj.Part): @@ -1331,7 +1331,7 @@ def make(self, key): self.ConfigFile.insert1( { **key, - "updated_config_file": updated_cfg_path, + "config_file": updated_cfg_path, } ) From 1d420acf9761a6840dc7657e46a4caf62555fc23 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 00:41:11 +0200 Subject: [PATCH 09/41] update(PreProcessing): refactoring and add new useful attributes and part table --- element_moseq/moseq_train.py | 173 ++++++++++++++++++++--------------- 1 file changed, 98 insertions(+), 75 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index e3a1384..9bcc0ad 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -221,18 +221,20 @@ class PreProcessing(dj.Computed): Attributes: PCATask (foreign key) : Unique ID for each `PCATask` key. - coordinates (longblob) : Dictionary mapping filenames to cleaned keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). - confidences (longblob) : Dictionary mapping filenames to updated likelihood scores as ndarrays of shape (n_frames, n_bodyparts). formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. + coordinates (longblob) : Cleaned coordinates dictionary {recording_name: array} after outlier removal. + confidences (longblob) : Cleaned confidences dictionary {recording_name: array} after outlier removal. average_frame_rate (float) : Average frame rate of the videos for model training (used for kappa calculation). + pre_processing_time (datetime) : datetime of the preprocessing execution. + pre_processing_duration (int) : Execution time of the preprocessing in seconds. """ definition = """ -> PCATask # Unique ID for each `PCATask` key --- - coordinates : longblob # Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) - confidences : longblob # Dictionary mapping filenames to likelihood scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. + coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal + confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). pre_processing_time : datetime # datetime of the preprocessing execution. pre_processing_duration : int # Execution time of the preprocessing in seconds. @@ -245,7 +247,17 @@ class Video(dj.Part): --- video_duration : int # Duration of each video in minutes frame_rate : float # Frame rate of the video in frames per second (Hz) - outlier_plot=NULL : attach # Plot of the outlier keypoints + file_size : float # File size of the video in megabytes (MB) + """ + + class OutlierRemoval(dj.Part): + """Store outlier detection QA plots per video.""" + + definition = """ + -> master + video_name: varchar(255) + --- + outlier_plot: attach # QA visualization showing detected outliers and interpolation """ class ConfigFile(dj.Part): @@ -256,8 +268,8 @@ class ConfigFile(dj.Part): definition = """ -> master --- - base_config_file: attach # the first creation of the config file - config_file: attach # the updated config file after processing + base_config_file: attach # the first version of the KPMS config file after setting up the project + config_file: attach # the updated KPMS DJ config file after processing """ def make_fetch(self, key): @@ -340,68 +352,72 @@ def make_compute( find_medoid_distance_outliers, interpolate_keypoints, load_keypoints, - plot_medoid_distance_outliers, ) + from .plotting.viz_utils import plot_medoid_distance_outliers + execution_time = datetime.now(timezone.utc) if task_mode == "trigger": from keypoint_moseq import setup_project + # check if the project output directory exists try: kpms_project_output_dir = find_full_path( get_kpms_processed_data_dir(), kpms_project_output_dir ) + # if the project output directory does not exist, create it except FileNotFoundError: kpms_project_output_dir = ( Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) - videos_dir = find_full_path( - get_kpms_root_data_dir(), - Path(keypoint_videofile_metadata[0]["video_path"]).parent, - ) + # Setup of the project creates KPMS base `config.yml` file copying the pose estimation config file + from .readers.kpms_reader import _pose_estimation_config_path + + pose_estimation_config_file = Path(_pose_estimation_config_path(kpset_dir)) + if not pose_estimation_config_file.exists(): + raise FileNotFoundError( + f"No `config.yml` or `config.yaml` file found in {kpset_dir}" + ) if pose_estimation_method == "deeplabcut": - from .readers.kpms_reader import _base_config_path - - # Find pose estimation config file - base_config_file = _base_config_path(kpset_dir) - base_config_file = Path(base_config_file) - if not base_config_file.exists(): - raise FileNotFoundError(f"No DLC config file found in {kpset_dir}") - # Create the kpms output diectory and config files setup_project( project_dir=kpms_project_output_dir.as_posix(), - deeplabcut_config=base_config_file.as_posix(), + deeplabcut_config=pose_estimation_config_file.as_posix(), ) else: raise NotImplementedError( "Currently, `deeplabcut` is the only pose estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method." ) + # task mode is load else: kpms_project_output_dir = find_full_path( get_kpms_processed_data_dir(), kpms_project_output_dir ) kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) - videos_dir = find_full_path( - get_kpms_root_data_dir(), - Path(keypoint_videofile_metadata[0]["video_path"]).parent, - ) # Format keypoint data raw_coordinates, raw_confidences, formatted_bodyparts = load_keypoints( filepath_pattern=kpset_dir, format=pose_estimation_method ) - # Extract frame rate from keypoint video files + # Confirm that `use_bodyparts` are a subset of `formatted_bodyparts` + if not set(use_bodyparts).issubset(set(formatted_bodyparts)): + raise ValueError( + f"use_bodyparts ({use_bodyparts}) is not a subset of formatted bodyparts ({formatted_bodyparts})" + ) + + # Extract frame rate and file size from keypoint video files video_metadata_dict = dict() frame_rates = [] for row in keypoint_videofile_metadata: video_id = int(row["video_id"]) - video_path = find_full_path( - get_kpms_root_data_dir(), row["video_path"] - ).as_posix() - cap = cv2.VideoCapture(video_path) + video_path = find_full_path(get_kpms_root_data_dir(), row["video_path"]) + + # Get file size in MB (rounded to 2 decimal places) + file_size_mb = round(video_path.stat().st_size / (1024 * 1024), 2) + + # Get video properties + cap = cv2.VideoCapture(video_path.as_posix()) fps = float(cap.get(cv2.CAP_PROP_FPS)) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() @@ -415,36 +431,45 @@ def make_compute( "video_name": Path(row["video_path"]).stem, "video_duration": duration_minutes, "frame_rate": fps, + "file_size": file_size_mb, "outlier_plot": None, } average_frame_rate = float(np.mean(frame_rates)) - # Generate a copy of pose estimation config file + # Get all unique parent directories for all video files + parent_dirs = { + Path(video["video_path"]).parent for video in keypoint_videofile_metadata + } + # Check if there is only one unique parent + if len(parent_dirs) > 1: + raise ValueError( + f"Videos are located in multiple directories: {parent_dirs}. All videos must be in the same directory." + ) videos_dir = find_full_path( get_kpms_root_data_dir(), Path(keypoint_videofile_metadata[0]["video_path"]).parent, ) - # Confirm that `use_bodyparts` are a subset of `formatted_bodyparts` - if not set(use_bodyparts).issubset(set(formatted_bodyparts)): - raise ValueError( - f"use_bodyparts ({use_bodyparts}) is not a subset of formatter bodyparts ({formatted_bodyparts})" - ) - base_config, _ = kpms_reader.dj_generate_config( - project_dir=kpms_project_output_dir, + # Generate a new KPMS DJ config file copying the KPMS base config file in the same kpms project output directory + ( + kpms_dj_config_path, + kpms_dj_config_dict, + kpms_base_config_path, + kpms_base_config_dict, + ) = kpms_reader.dj_generate_config( + kpms_project_dir=kpms_project_output_dir, video_dir=str(videos_dir), use_bodyparts=list(use_bodyparts), anterior_bodyparts=list(anterior_bodyparts), posterior_bodyparts=list(posterior_bodyparts), outlier_scale_factor=float(outlier_scale_factor), ) - # Update the config file content - kpms_reader.update_kpms_dj_config( - kpms_project_output_dir, fps=average_frame_rate - ) - # load the udpated config file - kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + # Update the KPMS DJ config file content with the average frame rate + kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( + config_dict=kpms_dj_config_dict, fps=average_frame_rate + ) + kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir) # Clean outlier keypoints and generate plots cleaned_coordinates = {} @@ -454,10 +479,8 @@ def make_compute( video_id = int(row["video_id"]) pose_estimation_path = row["pose_estimation_path"] pose_estimation_name = Path(pose_estimation_path).stem - raw_coords = raw_coordinates[pose_estimation_name].copy() raw_conf = raw_confidences[pose_estimation_name].copy() - outliers = find_medoid_distance_outliers( raw_coords, outlier_scale_factor=outlier_scale_factor ) @@ -465,32 +488,19 @@ def make_compute( cleaned_conf = np.where(outliers["mask"], 0, raw_conf) cleaned_coordinates[pose_estimation_name] = cleaned_coords cleaned_confidences[pose_estimation_name] = cleaned_conf - - plot_medoid_distance_outliers( + outlier_plot, outlier_plot_path = plot_medoid_distance_outliers( project_dir=kpms_project_output_dir.as_posix(), recording_name=pose_estimation_name, original_coordinates=raw_coords, interpolated_coordinates=cleaned_coords, outlier_mask=outliers["mask"], outlier_thresholds=outliers["thresholds"], - **kpms_dj_config, - ) - plot_file_path = ( - kpms_project_output_dir - / "QA" - / "plots" - / "keypoint_distance_outliers" - / f"{pose_estimation_name}.png" - ) - if not plot_file_path.exists(): - raise FileNotFoundError( - f"Outlier plot file not found at {plot_file_path}" - ) - video_metadata_dict[video_id]["outlier_plot"] = plot_file_path.as_posix() - - # path to the config files - base_config_filepath = kpms_reader._base_config_path(kpms_project_output_dir) - updated_config_filepath = kpms_reader._dj_config_path(kpms_project_output_dir) + **kpms_dj_config_dict, + ) # outlier plot stored at kpms_project_output_dir/QA/plots/keypoint_distance_outliers/f"{pose_estimation_name}.png + video_metadata_dict[video_id] = { + **video_metadata_dict[video_id], + "outlier_plot_path": outlier_plot_path, + } completion_time = datetime.now(timezone.utc) if task_mode == "trigger": @@ -504,8 +514,8 @@ def make_compute( formatted_bodyparts, average_frame_rate, video_metadata_dict, - updated_config_filepath, - base_config_filepath, + kpms_dj_config_path, + kpms_base_config_path, execution_time, duration_seconds, ) @@ -518,8 +528,8 @@ def make_insert( formatted_bodyparts, average_frame_rate, video_metadata_dict, - updated_config_filepath, - base_config_filepath, + kpms_dj_config_path, + kpms_base_config_path, execution_time, duration_seconds, ): @@ -531,9 +541,9 @@ def make_insert( self.insert1( { **key, + "formatted_bodyparts": formatted_bodyparts, "coordinates": cleaned_coordinates, "confidences": cleaned_confidences, - "formatted_bodyparts": formatted_bodyparts, "average_frame_rate": average_frame_rate, "pre_processing_time": execution_time, "pre_processing_duration": duration_seconds, @@ -549,7 +559,20 @@ def make_insert( "video_name": meta["video_name"], "video_duration": meta["video_duration"], "frame_rate": meta["frame_rate"], - "outlier_plot": meta["outlier_plot"], + "file_size": meta["file_size"], + } + for vid, meta in video_metadata_dict.items() + ] + ) + + # Insert outlier removal QA plots + if video_metadata_dict: + self.OutlierRemoval.insert( + [ + { + **key, + "video_name": meta["video_name"], + "outlier_plot": meta["outlier_plot_path"], } for vid, meta in video_metadata_dict.items() ] @@ -559,8 +582,8 @@ def make_insert( self.ConfigFile.insert1( { **key, - "base_config_file": base_config_filepath, - "config_file": updated_config_filepath, + "config_file": kpms_dj_config_path, + "base_config_file": kpms_base_config_path, } ) From 3f3aa5e3a2981cd817646721164953e373dfbbac Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 01:41:12 +0200 Subject: [PATCH 10/41] refactor(preprocessing): minor --- element_moseq/moseq_train.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 9bcc0ad..a5302b4 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -13,6 +13,7 @@ import datajoint as dj import matplotlib.pyplot as plt import numpy as np +import yaml from element_interface.utils import find_full_path from .plotting import viz_utils @@ -232,10 +233,10 @@ class PreProcessing(dj.Computed): definition = """ -> PCATask # Unique ID for each `PCATask` key --- - formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal - confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal - average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). + formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. + coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal + confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal + average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). pre_processing_time : datetime # datetime of the preprocessing execution. pre_processing_duration : int # Execution time of the preprocessing in seconds. """ @@ -257,7 +258,7 @@ class OutlierRemoval(dj.Part): -> master video_name: varchar(255) --- - outlier_plot: attach # QA visualization showing detected outliers and interpolation + outlier_plot: attach # QA visualization showing detected outliers and interpolation. """ class ConfigFile(dj.Part): @@ -268,8 +269,8 @@ class ConfigFile(dj.Part): definition = """ -> master --- - base_config_file: attach # the first version of the KPMS config file after setting up the project - config_file: attach # the updated KPMS DJ config file after processing + base_config_file: attach # KPMS config attachment. Stored as binary in database. + config_file: attach # Updated KPMS DJ config attachment. Stored as binary in database. """ def make_fetch(self, key): @@ -465,11 +466,18 @@ def make_compute( outlier_scale_factor=float(outlier_scale_factor), ) - # Update the KPMS DJ config file content with the average frame rate + # Get absolute paths for attach fields + kpms_dj_config_path = find_full_path( + get_kpms_processed_data_dir(), kpms_dj_config_path + ) + kpms_base_config_path = find_full_path( + get_kpms_processed_data_dir(), kpms_base_config_path + ) + + # Update the KPMS DJ config file on disk with the average frame rate kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( - config_dict=kpms_dj_config_dict, fps=average_frame_rate + kpms_project_dir=kpms_project_output_dir, fps=average_frame_rate ) - kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir) # Clean outlier keypoints and generate plots cleaned_coordinates = {} @@ -488,7 +496,7 @@ def make_compute( cleaned_conf = np.where(outliers["mask"], 0, raw_conf) cleaned_coordinates[pose_estimation_name] = cleaned_coords cleaned_confidences[pose_estimation_name] = cleaned_conf - outlier_plot, outlier_plot_path = plot_medoid_distance_outliers( + _, outlier_plot_path = plot_medoid_distance_outliers( project_dir=kpms_project_output_dir.as_posix(), recording_name=pose_estimation_name, original_coordinates=raw_coords, @@ -577,7 +585,6 @@ def make_insert( for vid, meta in video_metadata_dict.items() ] ) - # Insert configuration files self.ConfigFile.insert1( { From 31aa45a351fe478fa9da3d337712fd5ac0430575 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 01:46:45 +0200 Subject: [PATCH 11/41] update(preprocessing): from `video_name` to `video_id` --- element_moseq/moseq_train.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index a5302b4..9a981fc 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -244,7 +244,7 @@ class PreProcessing(dj.Computed): class Video(dj.Part): definition = """ -> master - video_name: varchar(255) + video_id: varchar(255) --- video_duration : int # Duration of each video in minutes frame_rate : float # Frame rate of the video in frames per second (Hz) @@ -256,7 +256,7 @@ class OutlierRemoval(dj.Part): definition = """ -> master - video_name: varchar(255) + video_id: varchar(255) --- outlier_plot: attach # QA visualization showing detected outliers and interpolation. """ @@ -429,7 +429,6 @@ def make_compute( duration_minutes = int((frame_count / fps) / 60.0) frame_rates.append(fps) video_metadata_dict[video_id] = { - "video_name": Path(row["video_path"]).stem, "video_duration": duration_minutes, "frame_rate": fps, "file_size": file_size_mb, @@ -564,7 +563,7 @@ def make_insert( [ { **key, - "video_name": meta["video_name"], + "video_id": vid, "video_duration": meta["video_duration"], "frame_rate": meta["frame_rate"], "file_size": meta["file_size"], @@ -579,7 +578,7 @@ def make_insert( [ { **key, - "video_name": meta["video_name"], + "video_id": vid, "outlier_plot": meta["outlier_plot_path"], } for vid, meta in video_metadata_dict.items() From 9e5d47661f3d40abba14f45dac33f6e4c93d2bde Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 01:59:20 +0200 Subject: [PATCH 12/41] refactor(PCAFit) --- element_moseq/moseq_train.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 9a981fc..df65456 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -650,25 +650,27 @@ def make(self, key): "coordinates", "confidences" ) - # Load the configuration from the file - kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + # Load the configuration from database + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_path + ) execution_time = datetime.now(timezone.utc) # Format keypoint data data, _ = format_data( - **kpms_config, coordinates=coordinates, confidences=confidences + **kpms_dj_config_dict, coordinates=coordinates, confidences=confidences ) # Fit PCA model and save as `pca.p` file if task_mode == "trigger": - pca = fit_pca(**data, **kpms_config) + pca = fit_pca(**data, **kpms_dj_config_dict) save_pca(pca, kpms_project_output_dir.as_posix()) # Check for pca.p file - pca_p_file = kpms_project_output_dir / "pca.p" - - if not pca_p_file.exists(): + pca_path = kpms_project_output_dir / "pca.p" + if not pca_path.exists(): raise FileNotFoundError( f"No pca file (`pca.p`) found in the project directory {kpms_project_output_dir}" ) @@ -694,8 +696,8 @@ def make(self, key): [ { **key, - "file_name": pca_p_file.name, - "file_path": pca_p_file, + "file_name": pca_path.name, + "file_path": pca_path, } ] ) From 20a40576d6ccef3179cb2305b6a6abe5b6b88581 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 02:36:16 +0200 Subject: [PATCH 13/41] refactor(latent_dimension) --- element_moseq/moseq_train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index df65456..ef61b24 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -752,11 +752,12 @@ def make(self, key): 3. Determine number of components needed for 90% variance. 4. Insert results into table. """ - - VARIANCE_THRESHOLD = 0.90 + import tempfile from keypoint_moseq import load_pca + VARIANCE_THRESHOLD = 0.90 + kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir") kpms_project_output_dir = ( Path(get_kpms_processed_data_dir()) / kpms_project_output_dir @@ -780,8 +781,11 @@ def make(self, key): variance_percentage = VARIANCE_THRESHOLD * 100 latent_dim_desc = f">={VARIANCE_THRESHOLD*100}% of variance explained by {(cs>VARIANCE_THRESHOLD).nonzero()[0].min()+1} components." - # Load the configuration from the file - kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + # Load the configuration from database + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_path + ) # Generate scree plot scree_fig = plt.figure() @@ -796,9 +800,9 @@ def make(self, key): # Generate PCs plot pcs_fig = viz_utils.plot_pcs( pca, - **kpms_config, interactive=False, project_dir=kpms_project_output_dir, + **kpms_dj_config_dict, ) # Load the pcs-xy.pdf file @@ -810,8 +814,6 @@ def make(self, key): ) # Save plots to temporary directory - import tempfile - tmpdir = tempfile.TemporaryDirectory() fname = f"{key['kpset_id']}_{key['bodyparts_id']}" From 8980b3992c368aacb77b3fcd8a3a03deb43a5942 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 05:02:06 +0200 Subject: [PATCH 14/41] refactor(viz_utils) --- element_moseq/plotting/viz_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/element_moseq/plotting/viz_utils.py b/element_moseq/plotting/viz_utils.py index 407a812..b5d9611 100644 --- a/element_moseq/plotting/viz_utils.py +++ b/element_moseq/plotting/viz_utils.py @@ -60,8 +60,11 @@ def plot_medoid_distance_outliers( Returns ------- - None - The plot is saved to 'QA/plots/keypoint_distance_outliers/{recording_name}.png'. + tuple of (fig, plot_path) + fig: matplotlib.figure.Figure + The generated figure object + plot_path: str + Path to the saved plot file: 'QA/plots/keypoint_distance_outliers/{recording_name}.png' """ from keypoint_moseq.util import get_distance_to_medoid, plot_keypoint_traces @@ -98,7 +101,7 @@ def plot_medoid_distance_outliers( logger.info( f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}." ) - return fig + return fig, plot_path def plot_pcs( @@ -298,3 +301,4 @@ def copy_pdf_to_png(project_dir, model_name): images[0].save(png_path, "PNG") logger.info(f"Generated PNG progress plot at {png_path}") + return png_path, pdf_path From 37e875df6ec1fc08be1bc1243cc93580948d69d2 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 05:04:05 +0200 Subject: [PATCH 15/41] refactor(kpms_reader) --- element_moseq/readers/kpms_reader.py | 244 ++++++++++++++++++++------- 1 file changed, 184 insertions(+), 60 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 464f544..d2bf991 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -7,29 +7,73 @@ logger = dj.logger -DJ_CONFIG = "kpms_dj_config.yml" -BASE_CONFIG = "config.yml" +KPMS_DJ_CONFIG = "kpms_dj_config.yml" +CONFIG_FILENAMES = [ + "config.yml", + "config.yaml", +] # Used for both pose estimation and KPMS base configs -def _dj_config_path(project_dir: Union[str, os.PathLike]) -> str: - return str(Path(project_dir) / DJ_CONFIG) +def _pose_estimation_config_path(kpset_dir: Union[str, os.PathLike]) -> str: + """ + Return the path to the pose estimation config file (e.g., DeepLabCut config.yaml) in the keypoint set directory. + + Args: + kpset_dir: Keypoint set directory (where pose estimation files are located) + + Returns: + Path to pose estimation config file (config.yml or config.yaml) + """ + kpset_path = Path(kpset_dir) + # Check for config.yml first (preferred) + config_yml = kpset_path / CONFIG_FILENAMES[0] + if config_yml.exists(): + return str(config_yml) + # Fall back to config.yaml + config_yaml = kpset_path / CONFIG_FILENAMES[1] + if config_yaml.exists(): + return str(config_yaml) + # If neither exists, return the default (config.yml) + return str(config_yml) -def _base_config_path(project_dir: Union[str, os.PathLike]) -> str: - """Return the path to the base config file, checking for both .yml and .yaml extensions.""" - project_path = Path(project_dir) +def _kpms_base_config_path(kpms_project_dir: Union[str, os.PathLike]) -> str: + """ + Return the path to the KPMS base config file (created by keypoint_moseq's setup_project) in the KPMS output directory. + + Args: + kpms_project_dir: KPMS project output directory + + Returns: + Path to KPMS base config file (config.yml or config.yaml) + """ + project_path = Path(kpms_project_dir) # Check for config.yml first (preferred) - config_yml = project_path / "config.yml" + config_yml = project_path / CONFIG_FILENAMES[0] if config_yml.exists(): return str(config_yml) # Fall back to config.yaml - config_yaml = project_path / "config.yaml" + config_yaml = project_path / CONFIG_FILENAMES[1] if config_yaml.exists(): return str(config_yaml) # If neither exists, return the default (config.yml) return str(config_yml) +def _kpms_dj_config_path(kpms_project_dir: Union[str, os.PathLike]) -> str: + """ + Return the path to the KPMS DJ config file (kpms_dj_config.yml) in the KPMS output directory. + This is the DataJoint-specific config file that gets updated during the pipeline. + + Args: + kpms_project_dir: KPMS project output directory + + Returns: + Path to KPMS DJ config file (kpms_dj_config.yml) + """ + return str(Path(kpms_project_dir) / KPMS_DJ_CONFIG) + + def _check_config_validity(config: Dict[str, Any]) -> bool: """ Minimal mirror of keypoint_moseq.io.check_config_validity logic that matters @@ -55,47 +99,82 @@ def _check_config_validity(config: Dict[str, Any]) -> bool: return True -def dj_generate_config(project_dir: str, **kwargs) -> str: +def dj_generate_config(kpms_project_dir: str, **kwargs) -> tuple: """ - Generate or refresh `/kpms_dj_config.yml`. + Generate or refresh `/kpms_dj_config.yml` from the KPMS base config. Behavior: - - If the DJ config doesn't exist, start from the **base** `/config.yml` - created by upstream `setup_project`, then overlay kwargs and write DJ config. - - If the DJ config exists, load it, overlay kwargs, and rewrite it. + - If the KPMS DJ config doesn't exist, start from the KPMS base `/config.yml` + (created by keypoint_moseq's `setup_project`), then overlay kwargs and write KPMS DJ config. + - If the KPMS DJ config exists, load it, overlay kwargs, and rewrite it. + + Args: + kpms_project_dir: KPMS project output directory + **kwargs: Key-value pairs to update in the config - Returns the path to `kpms_dj_config.yml`. + Returns: + Tuple of (kpms_dj_config_path, kpms_dj_config_dict, kpms_base_config_path, kpms_base_config_dict) """ - project_dir = str(project_dir) - base_cfg_path = _base_config_path(project_dir) - dj_cfg_path = _dj_config_path(project_dir) + kpms_project_dir = str(kpms_project_dir) + kpms_base_config_path = _kpms_base_config_path(kpms_project_dir) + kpms_dj_config_path = _kpms_dj_config_path(kpms_project_dir) + + # Load KPMS base config if it exists + kpms_base_config_dict = None + if Path(kpms_base_config_path).exists(): + with open(kpms_base_config_path, "r") as f: + kpms_base_config_dict = yaml.safe_load(f) or {} - if Path(dj_cfg_path).exists(): - with open(dj_cfg_path, "r") as f: - cfg = yaml.safe_load(f) or {} + # Generate or update KPMS DJ config + if Path(kpms_dj_config_path).exists(): + with open(kpms_dj_config_path, "r") as f: + kpms_dj_config_dict = yaml.safe_load(f) or {} else: - if not Path(base_cfg_path).exists(): + if not Path(kpms_base_config_path).exists(): raise FileNotFoundError( - f"Missing base config at {base_cfg_path}. Run upstream setup_project first. " - f"Expected either config.yml or config.yaml in {project_dir}." + f"Missing KPMS base config at {kpms_base_config_path}. " + f"Run keypoint_moseq's setup_project first. " + f"Expected either config.yml or config.yaml in {kpms_project_dir}." ) - with open(base_cfg_path, "r") as f: - cfg = yaml.safe_load(f) or {} - cfg.update(kwargs) + kpms_dj_config_dict = kpms_base_config_dict.copy() - if "skeleton" not in cfg or cfg["skeleton"] is None: - cfg["skeleton"] = [] + kpms_dj_config_dict.update(kwargs) - with open(dj_cfg_path, "w") as f: - yaml.safe_dump(cfg, f, sort_keys=False) - return dj_cfg_path, base_cfg_path + if "skeleton" not in kpms_dj_config_dict or kpms_dj_config_dict["skeleton"] is None: + kpms_dj_config_dict["skeleton"] = [] + + with open(kpms_dj_config_path, "w") as f: + yaml.safe_dump(kpms_dj_config_dict, f, sort_keys=False) + + return ( + kpms_dj_config_path, + kpms_dj_config_dict, + kpms_base_config_path, + kpms_base_config_dict, + ) def load_kpms_dj_config( - project_dir: str, check_if_valid: bool = True, build_indexes: bool = True + kpms_project_dir: str = None, + config_path: str = None, + check_if_valid: bool = True, + build_indexes: bool = True, ) -> Dict[str, Any]: """ - Load `/kpms_dj_config.yml`. + Load kpms_dj_config.yml from either a KPMS project directory or a direct file path. + + Args: + kpms_project_dir: KPMS project output directory containing kpms_dj_config.yml (optional) + config_path: Direct path to kpms_dj_config.yml file (optional) + check_if_valid: Check anatomy subset validity + build_indexes: Add jax arrays 'anterior_idxs' and 'posterior_idxs' + + Returns: + Configuration dictionary + + Raises: + ValueError: If neither or both kpms_project_dir and config_path are provided + FileNotFoundError: If the config file doesn't exist Mirrors keypoint_moseq.io.load_config behavior: - check_if_valid -> anatomy subset checks @@ -104,47 +183,92 @@ def load_kpms_dj_config( """ import jax.numpy as jnp - dj_cfg_path = _dj_config_path(project_dir) - if not Path(dj_cfg_path).exists(): + # Validate input parameters + if kpms_project_dir is None and config_path is None: + raise ValueError("Either 'kpms_project_dir' or 'config_path' must be provided.") + if kpms_project_dir is not None and config_path is not None: + raise ValueError( + "Cannot provide both 'kpms_project_dir' and 'config_path'. Choose one." + ) + + # Determine the config file path + if config_path is not None: + kpms_dj_cfg_path = config_path + else: + kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir) + + if not Path(kpms_dj_cfg_path).exists(): raise FileNotFoundError( - f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()." + f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()." ) - with open(dj_cfg_path, "r") as f: - cfg = yaml.safe_load(f) or {} + with open(kpms_dj_cfg_path, "r") as f: + cfg_dict = yaml.safe_load(f) or {} if check_if_valid: - _check_config_validity(cfg) + _check_config_validity(cfg_dict) if build_indexes: - anterior = cfg.get("anterior_bodyparts", []) - posterior = cfg.get("posterior_bodyparts", []) - use_bps = cfg.get("use_bodyparts", []) - cfg["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior]) - cfg["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior]) + anterior = cfg_dict.get("anterior_bodyparts", []) + posterior = cfg_dict.get("posterior_bodyparts", []) + use_bps = cfg_dict.get("use_bodyparts", []) + cfg_dict["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior]) + cfg_dict["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior]) - if "skeleton" not in cfg or cfg["skeleton"] is None: - cfg["skeleton"] = [] + if "skeleton" not in cfg_dict or cfg_dict["skeleton"] is None: + cfg_dict["skeleton"] = [] - return cfg + return cfg_dict -def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]: +def update_kpms_dj_config( + kpms_project_dir: str = None, config_dict: Dict[str, Any] = None, **kwargs +) -> Dict[str, Any]: """ - Update `kpms_dj_config.yml` with provided top-level kwargs (same pattern as - keypoint_moseq.io.update_config), then rewrite the file and return the dict. + Update kpms_dj_config with provided kwargs. + This function updates the file on disk. + This function returns the updated config dictionary. + + Args: + kpms_project_dir: KPMS project output directory containing kpms_dj_config.yml (optional) + config_dict: Existing config dictionary to update (optional) + **kwargs: Key-value pairs to update in the config + + Returns: + Updated configuration dictionary + + Raises: + ValueError: If neither or both kpms_project_dir and config_dict are provided + + If kpms_project_dir is provided, loads the config from file, updates it, saves it back, and returns it. + If config_dict is provided, updates it directly and returns it (no file I/O). """ - dj_cfg_path = _dj_config_path(project_dir) - if not Path(dj_cfg_path).exists(): - raise FileNotFoundError( - f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()." + # Validate input parameters + if kpms_project_dir is None and config_dict is None: + raise ValueError("Either 'kpms_project_dir' or 'config_dict' must be provided.") + if kpms_project_dir is not None and config_dict is not None: + raise ValueError( + "Cannot provide both 'kpms_project_dir' and 'config_dict'. Choose one." ) - with open(dj_cfg_path, "r") as f: - cfg_dict = yaml.safe_load(f) or {} + # Load from file if kpms_project_dir is provided + if kpms_project_dir is not None: + kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir) + if not Path(kpms_dj_cfg_path).exists(): + raise FileNotFoundError( + f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()." + ) + + with open(kpms_dj_cfg_path, "r") as f: + cfg_dict = yaml.safe_load(f) or {} - cfg_dict.update(kwargs) + cfg_dict.update(kwargs) + + with open(kpms_dj_cfg_path, "w") as f: + yaml.safe_dump(cfg_dict, f, sort_keys=False) + else: + # Update the provided dict directly (no file I/O) + cfg_dict = config_dict.copy() # Make a copy to avoid mutating the input + cfg_dict.update(kwargs) - with open(dj_cfg_path, "w") as f: - yaml.safe_dump(cfg_dict, f, sort_keys=False) return cfg_dict From 7391789492be529ddc2bfff29f225913f892984b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 05:05:00 +0200 Subject: [PATCH 16/41] refactor(moseq_train) --- element_moseq/moseq_train.py | 298 +++++++++++++++++------------------ 1 file changed, 143 insertions(+), 155 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index ef61b24..93a2ddf 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -5,6 +5,7 @@ import importlib import inspect +import pickle from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -13,7 +14,6 @@ import datajoint as dj import matplotlib.pyplot as plt import numpy as np -import yaml from element_interface.utils import find_full_path from .plotting import viz_utils @@ -225,20 +225,20 @@ class PreProcessing(dj.Computed): formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. coordinates (longblob) : Cleaned coordinates dictionary {recording_name: array} after outlier removal. confidences (longblob) : Cleaned confidences dictionary {recording_name: array} after outlier removal. - average_frame_rate (float) : Average frame rate of the videos for model training (used for kappa calculation). + average_frame_rate (int) : Average frame rate of the videos for model training (used for kappa calculation). pre_processing_time (datetime) : datetime of the preprocessing execution. pre_processing_duration (int) : Execution time of the preprocessing in seconds. """ definition = """ - -> PCATask # Unique ID for each `PCATask` key + -> PCATask # Unique ID for each `PCATask` key --- - formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal - confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal - average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). - pre_processing_time : datetime # datetime of the preprocessing execution. - pre_processing_duration : int # Execution time of the preprocessing in seconds. + formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. + coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal + confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal + average_frame_rate : int # Average frame rate of the videos for model training (used for kappa calculation). + pre_processing_time=NULL : datetime # datetime of the preprocessing execution. + pre_processing_duration=NULL : int # Execution time of the preprocessing in seconds. """ class Video(dj.Part): @@ -256,7 +256,7 @@ class OutlierRemoval(dj.Part): definition = """ -> master - video_id: varchar(255) + video_id : varchar(255) --- outlier_plot: attach # QA visualization showing detected outliers and interpolation. """ @@ -269,8 +269,8 @@ class ConfigFile(dj.Part): definition = """ -> master --- - base_config_file: attach # KPMS config attachment. Stored as binary in database. - config_file: attach # Updated KPMS DJ config attachment. Stored as binary in database. + base_config_file=NULL : attach # Base KPMS config attachment. + config_file : attach # Updated KPMS DJ config attachment. """ def make_fetch(self, key): @@ -434,8 +434,7 @@ def make_compute( "file_size": file_size_mb, "outlier_plot": None, } - average_frame_rate = float(np.mean(frame_rates)) - + average_frame_rate = int(round(np.mean(frame_rates))) # Get all unique parent directories for all video files parent_dirs = { Path(video["video_path"]).parent for video in keypoint_videofile_metadata @@ -618,8 +617,8 @@ class File(dj.Part): definition = """ -> master + file_name : varchar(255) # name of the pca file (e.g. 'pca.p'). --- - file_name : varchar(1000) # name of the pca file (e.g. 'pca.p'). file_path : filepath@moseq-train-processed # path to the pca file (relative to the project output directory). """ @@ -659,21 +658,35 @@ def make(self, key): execution_time = datetime.now(timezone.utc) # Format keypoint data - data, _ = format_data( + data, metadata = format_data( **kpms_dj_config_dict, coordinates=coordinates, confidences=confidences ) + # Save data and metadata as pickle files + data_filename = "data.pkl" + metadata_filename = "metadata.pkl" + data_path = kpms_project_output_dir / data_filename + metadata_path = kpms_project_output_dir / metadata_filename + with open(data_path, "wb") as f: + pickle.dump(data, f) + with open(metadata_path, "wb") as f: + pickle.dump(metadata, f) + # Fit PCA model and save as `pca.p` file if task_mode == "trigger": pca = fit_pca(**data, **kpms_dj_config_dict) save_pca(pca, kpms_project_output_dir.as_posix()) + pca_filename = "pca.p" + pca_path = kpms_project_output_dir / pca_filename + # Check for pca.p file - pca_path = kpms_project_output_dir / "pca.p" if not pca_path.exists(): raise FileNotFoundError( f"No pca file (`pca.p`) found in the project directory {kpms_project_output_dir}" ) + # Insert all files in a single operation + file_paths = [pca_path, data_path, metadata_path] completion_time = datetime.now(timezone.utc) @@ -691,14 +704,14 @@ def make(self, key): } ) - # Insert in File table self.File.insert( [ { **key, - "file_name": pca_path.name, - "file_path": pca_path, + "file_name": file_path.name, + "file_path": file_path, } + for file_path in file_paths ] ) @@ -720,7 +733,7 @@ class LatentDimension(dj.Computed): --- variance_percentage : float # Variance threshold. Fixed value to 90 percent. latent_dimension : int # Number of principal components required to explain the specified variance. - latent_dim_desc : varchar(1000) # Automated description of the computation result. + latent_dim_desc='' : varchar(1000) # Automated description of the computation result. """ class Plots(dj.Part): @@ -731,9 +744,9 @@ class Plots(dj.Part): definition = """ -> master --- - scree_plot: attach # A cumulative scree plot showing explained variance - pcs_plot: attach # A visualization of each Principal Component (PC) - pcs_xy_plot: attach # A visualization of the Principal Components (PCs) in the XY plane + scree_plot : attach # A cumulative scree plot showing explained variance + pcs_plot : attach # A visualization of each Principal Component (PC) + pcs_xy_plot : attach # A visualization of the Principal Components (PCs) in the XY plane """ def make(self, key): @@ -867,7 +880,7 @@ class PreFitTask(dj.Manual): pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). --- model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` - task_mode='load' :enum('load','trigger') # 'load': load computed analysis results, 'trigger': trigger computation + task_mode='load' :enum('trigger','load') # 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @@ -898,19 +911,19 @@ class ConfigFile(dj.Part): definition = """ -> master --- - config_file: attach # Updated config file after PreFit computation + config_file: attach # Updated KPMS DJ config file after PreFit computation. """ class CheckpointFile(dj.Part): """ - Store the checkpoint file used for resuming the fitting process. + Store the checkpoint file used for resuming the pre-fitting process. """ definition = """ -> master --- - checkpoint_file_name: varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). - checkpoint_file: filepath@moseq-train-processed # path to the checkpoint file. + checkpoint_file_name : varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.h5'). + checkpoint_file : filepath@moseq-train-processed # Path to the checkpoint file. """ class Plots(dj.Part): @@ -969,93 +982,80 @@ def make(self, key): if task_mode == "trigger": from keypoint_moseq import estimate_sigmasq_loc - # Update the existing kpms_dj_config.yml with new latent_dim and kappa values - kpms_reader.update_kpms_dj_config( - kpms_project_output_dir, - latent_dim=int(pre_latent_dim), - kappa=float(pre_kappa), - ) - - # Load the updated config for use in model fitting - kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) - - # Load the PCA model from the project directory - pca_path = kpms_project_output_dir / "pca.p" - if pca_path.exists(): - pca = load_pca(kpms_project_output_dir.as_posix()) - else: - raise FileNotFoundError( - f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" - ) - - # Format the data for model fitting + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path") + pca = load_pca(Path(pca_path).parent.as_posix()) coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data( - coordinates=coordinates, confidences=confidences, **kpms_dj_config + data_path = (PCAFit.File & key & 'file_name="data.pkl"').fetch1("file_path") + metadata_path = (PCAFit.File & key & 'file_name="metadata.pkl"').fetch1( + "file_path" ) + data = pickle.load(open(data_path, "rb")) + metadata = pickle.load(open(metadata_path, "rb")) + average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate") - # Update the kpms_dj_config.yml with the new sigmasq_loc - kpms_reader.update_kpms_dj_config( - kpms_project_output_dir, + # Update kpms_dj_config file in disk with new latent_dim and kappa values + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_path + ) + kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( + config_dict=kpms_dj_config_dict, + latent_dim=pre_latent_dim, + kappa=pre_kappa, sigmasq_loc=estimate_sigmasq_loc( - data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + data["Y"], data["mask"], filter_size=average_frame_rate ), ) - # Load the updated config for use in model fitting - kpms_dj_config = kpms_reader.load_kpms_dj_config( - project_dir=kpms_project_output_dir - ) - # Initialize the model - model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) - + model = init_model( + data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict + ) # Update the model hyperparameters model = update_hypparams( - model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim) + model, + kappa=float(pre_kappa.item()), + latent_dim=int(pre_latent_dim.item()), ) - # Determine model directory name for outputs if model_name is None or not str(model_name).strip(): - model_dir_name = f"latent_dim_{float(pre_latent_dim)}_kappa_{float(pre_kappa)}_iters_{float(pre_num_iterations)}" + model_name = f"latent_dim_{pre_latent_dim.item()}_kappa_{pre_kappa.item()}_iters_{pre_num_iterations.item()}" else: - model_dir_name = str(model_name) + model_name = str(model_name) execution_time = datetime.now(timezone.utc) # Fit the model - model, model_name = fit_model( + model, _ = fit_model( model=model, - model_name=model_dir_name, + model_name=model_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), ar_only=True, num_iters=pre_num_iterations, generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ - save_every_n_iters=25, + save_every_n_iters=10, + ) # model_name is already the correct directory name + # Create a PNG version fo the PDF progress plot + png_path, pdf_path = viz_utils.copy_pdf_to_png( + kpms_project_output_dir, model_name ) - - # Normalize to folder name returned by fit_model - model_dir_name = Path(model_name).name - - # Copy the PDF progress plot to PNG - viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_dir_name) + # Define model_name_full_path for checkpoint file search + model_name_full_path = find_full_path(kpms_project_output_dir, model_name) else: # Load mode must specify a model_name if model_name is None or not str(model_name).strip(): raise ValueError("model_name is required when task_mode='load'") - model_dir_name = Path(model_name).name + model_name_full_path = find_full_path(kpms_project_output_dir, model_name) + pdf_path = model_name_full_path / "fitting_progress.pdf" + png_path = model_name_full_path / "fitting_progress.png" # Get the path to the updated config file - updated_cfg_path = (kpms_project_output_dir / "kpms_dj_config.yml").as_posix() + kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir) - # Check for fitting progress files - prefit_model_dir = kpms_project_output_dir / model_dir_name - pdf_path = prefit_model_dir / "fitting_progress.pdf" - png_path = prefit_model_dir / "fitting_progress.png" if not pdf_path.exists(): raise FileNotFoundError(f"PreFit PDF progress plot not found at {pdf_path}") if not png_path.exists(): @@ -1064,11 +1064,13 @@ def make(self, key): # Find checkpoint file checkpoint_files = [] for pattern in ("checkpoint*", "*.h5"): - checkpoint_files.extend(prefit_model_dir.glob(pattern)) + checkpoint_files.extend(model_name_full_path.glob(pattern)) if checkpoint_files: checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) else: - raise FileNotFoundError(f"No checkpoint files found in {prefit_model_dir}") + raise FileNotFoundError( + f"No checkpoint files found in {model_name_full_path}" + ) completion_time = datetime.now(timezone.utc) @@ -1080,10 +1082,7 @@ def make(self, key): self.insert1( { **key, - "model_name": ( - kpms_project_output_dir.relative_to(get_kpms_processed_data_dir()) - / model_dir_name - ).as_posix(), + "model_name": model_name, "pre_fit_duration": duration_seconds, } ) @@ -1091,7 +1090,7 @@ def make(self, key): self.ConfigFile.insert1( dict( **key, - config_file=updated_cfg_path, + config_file=kpms_dj_config_path, ) ) @@ -1164,25 +1163,24 @@ class ConfigFile(dj.Part): definition = """ -> master --- - config_file: attach # the updated config file after FullFit computation + config_file: attach # Updated KPMS DJ config attachment. """ class CheckpointFile(dj.Part): """ - Store the checkpoint file used for resuming the fitting process. + Store the checkpoint file used for resuming the full-fitting process. """ definition = """ -> master --- - checkpoint_file_name: varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). - checkpoint_file: filepath@moseq-train-processed # path to the checkpoint file. + checkpoint_file_name : varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). + checkpoint_file : filepath@moseq-train-processed # Path to the checkpoint file in the processed data directory. """ class Plots(dj.Part): """ - Store the fitting progress of the FullFit computation: - - Plots in PDF and PNG formats used for visualization. + Store the fitting progress of the FullFit computation. """ definition = """ @@ -1235,111 +1233,101 @@ def make(self, key): ) if task_mode == "trigger": - - # Update the kpms_dj_config.yml with latent dimension and kappa values - kpms_reader.update_kpms_dj_config( - project_dir=kpms_project_output_dir, - latent_dim=int(full_latent_dim), - kappa=float(full_kappa), - ) - - # Load the updated config for data formatting - kpms_dj_config = kpms_reader.load_kpms_dj_config( - project_dir=kpms_project_output_dir - ) - - # Load the PCA model - pca_path = kpms_project_output_dir / "pca.p" - if pca_path.exists(): - pca = load_pca(kpms_project_output_dir.as_posix()) - else: - raise FileNotFoundError( - f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" - ) - - # Format the data for model fitting + pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path") + pca = load_pca(Path(pca_path).parent.as_posix()) coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data( - coordinates=coordinates, confidences=confidences, **kpms_dj_config + data_path = (PCAFit.File & key & 'file_name="data.pkl"').fetch1("file_path") + data = pickle.load(open(data_path, "rb")) + metadata_path = (PCAFit.File & key & 'file_name="metadata.pkl"').fetch1( + "file_path" ) + metadata = pickle.load(open(metadata_path, "rb")) + average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate") - # Update the kpms_dj_config.yml with the new sigmasq_loc - kpms_reader.update_kpms_dj_config( - project_dir=kpms_project_output_dir, + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_path + ) + # Update kpms_dj_config file in disk with new latent_dim and kappa values + kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( + config_dict=kpms_dj_config_dict, + latent_dim=full_latent_dim, + kappa=full_kappa, sigmasq_loc=estimate_sigmasq_loc( - data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + data["Y"], data["mask"], filter_size=average_frame_rate ), ) - # Load the updated config for use in model fitting - kpms_dj_config = kpms_reader.load_kpms_dj_config( - project_dir=kpms_project_output_dir - ) - # Initialize the model - model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) + model = init_model( + data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict + ) # Update the model hyperparameters model = update_hypparams( - model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) + model, + kappa=float(full_kappa.item()), + latent_dim=int(full_latent_dim.item()), ) - # Generate the model directory name + # Determine model directory name for outputs if model_name is None or not str(model_name).strip(): - model_dir_name = f"latent_dim_{float(full_latent_dim)}_kappa_{float(full_kappa)}_iters_{float(full_num_iterations)}" + model_name = f"latent_dim_{full_latent_dim.item()}_kappa_{full_kappa.item()}_iters_{full_num_iterations.item()}" else: - model_dir_name = str(model_name) + model_name = str(model_name) execution_time = datetime.now(timezone.utc) - # Fit the model model, model_name = fit_model( model=model, - model_name=model_dir_name, + model_name=model_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), ar_only=False, num_iters=full_num_iterations, generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ - save_every_n_iters=25, + save_every_n_iters=10, ) # Reindex the syllables in the checkpoint file reindex_syllables_in_checkpoint( project_dir=kpms_project_output_dir.as_posix(), - model_name=Path(model_name).name, + model_name=model_name, ) - # Copy the PDF progress plot to PNG - viz_utils.copy_pdf_to_png(kpms_project_output_dir, Path(model_name).name) + # Create a PNG version fo the PDF progress plot + png_path, pdf_path = viz_utils.copy_pdf_to_png( + kpms_project_output_dir, model_name + ) + # Define model_name_full_path for checkpoint file search + model_name_full_path = find_full_path(kpms_project_output_dir, model_name) + else: + # Load mode must specify a model_name + if model_name is None or not str(model_name).strip(): + raise ValueError("model_name is required when task_mode='load'") + model_name_full_path = find_full_path(kpms_project_output_dir, model_name) + pdf_path = model_name_full_path / "fitting_progress.pdf" + png_path = model_name_full_path / "fitting_progress.png" # Get the path to the updated config file - updated_cfg_path = kpms_reader._dj_config_path(kpms_project_output_dir) + kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir) - # Get the path to the full fit model directory - fullfit_model_dir = kpms_project_output_dir / Path(model_name).name - - # Check for progress plot files - pdf_path = fullfit_model_dir / "fitting_progress.pdf" - png_path = fullfit_model_dir / "fitting_progress.png" if not pdf_path.exists(): - raise FileNotFoundError( - f"FullFit PDF progress plot not found at {pdf_path}" - ) + raise FileNotFoundError(f"PreFit PDF progress plot not found at {pdf_path}") if not png_path.exists(): - raise FileNotFoundError( - f"FullFit PNG progress plot not found at {png_path}" - ) + raise FileNotFoundError(f"PreFit PNG progress plot not found at {png_path}") # Find checkpoint file checkpoint_files = [] for pattern in ("checkpoint*", "*.h5"): - checkpoint_files.extend(fullfit_model_dir.glob(pattern)) + checkpoint_files.extend(model_name_full_path.glob(pattern)) if checkpoint_files: checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) else: - raise FileNotFoundError(f"No checkpoint files found in {fullfit_model_dir}") + raise FileNotFoundError( + f"No checkpoint files found in {model_name_full_path}" + ) completion_time = datetime.now(timezone.utc) @@ -1364,7 +1352,7 @@ def make(self, key): self.ConfigFile.insert1( { **key, - "config_file": updated_cfg_path, + "config_file": kpms_dj_config_path, } ) From e737a6c1de195124f8444ddf98c468cfa2b0520d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 05:59:40 +0200 Subject: [PATCH 17/41] refactor(moseq_train) --- element_moseq/moseq_train.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 93a2ddf..7001937 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1,5 +1,5 @@ """ -Code adapted from the Datta Lab +Code adapted from the Datta Lab: https://dattalab.github.io/moseq2-website/index.html DataJoint Schema for Keypoint-MoSeq training pipeline """ @@ -899,6 +899,7 @@ class PreFit(dj.Computed): -> PreFitTask # `PreFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + model_dict : longblob # Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed pre_fit_time=NULL : datetime # datetime of the model fitting computation. pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ @@ -914,16 +915,16 @@ class ConfigFile(dj.Part): config_file: attach # Updated KPMS DJ config file after PreFit computation. """ - class CheckpointFile(dj.Part): + class File(dj.Part): """ Store the checkpoint file used for resuming the pre-fitting process. """ definition = """ -> master + file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5'). --- - checkpoint_file_name : varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.h5'). - checkpoint_file : filepath@moseq-train-processed # Path to the checkpoint file. + file : filepath@moseq-train-processed # Path to the file in the processed data directory. """ class Plots(dj.Part): @@ -1037,7 +1038,7 @@ def make(self, key): num_iters=pre_num_iterations, generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ save_every_n_iters=10, - ) # model_name is already the correct directory name + ) # Create a PNG version fo the PDF progress plot png_path, pdf_path = viz_utils.copy_pdf_to_png( kpms_project_output_dir, model_name @@ -1066,7 +1067,7 @@ def make(self, key): for pattern in ("checkpoint*", "*.h5"): checkpoint_files.extend(model_name_full_path.glob(pattern)) if checkpoint_files: - checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) + checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_mtime) else: raise FileNotFoundError( f"No checkpoint files found in {model_name_full_path}" @@ -1084,6 +1085,7 @@ def make(self, key): **key, "model_name": model_name, "pre_fit_duration": duration_seconds, + "model_dict": model, } ) @@ -1150,9 +1152,10 @@ class FullFit(dj.Computed): definition = """ -> FullFitTask # `FullFitTask` Key --- - model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name". + model_dict : longblob # Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed. full_fit_time=NULL : datetime # datetime of the full fitting computation. - full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation + full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation. """ class ConfigFile(dj.Part): @@ -1166,16 +1169,16 @@ class ConfigFile(dj.Part): config_file: attach # Updated KPMS DJ config attachment. """ - class CheckpointFile(dj.Part): + class File(dj.Part): """ Store the checkpoint file used for resuming the full-fitting process. """ definition = """ -> master + file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5'). --- - checkpoint_file_name : varchar(1000) # Name of the checkpoint file (e.g. 'checkpoint.p'). - checkpoint_file : filepath@moseq-train-processed # Path to the checkpoint file in the processed data directory. + file : filepath@moseq-train-processed # Path to the file in the processed data directory. """ class Plots(dj.Part): @@ -1323,7 +1326,7 @@ def make(self, key): for pattern in ("checkpoint*", "*.h5"): checkpoint_files.extend(model_name_full_path.glob(pattern)) if checkpoint_files: - checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_size) + checkpoint_file = max(checkpoint_files, key=lambda f: f.stat().st_mtime) else: raise FileNotFoundError( f"No checkpoint files found in {model_name_full_path}" @@ -1345,6 +1348,7 @@ def make(self, key): ).as_posix(), "full_fit_time": completion_time, "full_fit_duration": duration_seconds, + "model_dict": model, } ) @@ -1366,7 +1370,7 @@ def make(self, key): ) # Insert checkpoint file - self.CheckpointFile.insert1( + self.File.insert1( { **key, "checkpoint_file_name": checkpoint_file.name, From 9de67bd2879a45872a30dc63b709ad729ba298e4 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 06:57:25 +0200 Subject: [PATCH 18/41] refactor(moseq_train) --- element_moseq/moseq_train.py | 76 ++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 7001937..78b54a8 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -612,14 +612,14 @@ class PCAFit(dj.Computed): class File(dj.Part): """ - Store the PCA files (pca.p file). + Store the PCA files (pca.p, data.pkl, metadata.pkl files). """ definition = """ -> master - file_name : varchar(255) # name of the pca file (e.g. 'pca.p'). + file_name : varchar(255) # name of the file (e.g. 'pca.p', 'data.pkl', 'metadata.pkl'). --- - file_path : filepath@moseq-train-processed # path to the pca file (relative to the project output directory). + file_path : filepath@moseq-train-processed # path to the file (relative to the project output directory). """ def make(self, key): @@ -685,7 +685,6 @@ def make(self, key): raise FileNotFoundError( f"No pca file (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - # Insert all files in a single operation file_paths = [pca_path, data_path, metadata_path] completion_time = datetime.now(timezone.utc) @@ -899,7 +898,6 @@ class PreFit(dj.Computed): -> PreFitTask # `PreFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" - model_dict : longblob # Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed pre_fit_time=NULL : datetime # datetime of the model fitting computation. pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ @@ -917,12 +915,12 @@ class ConfigFile(dj.Part): class File(dj.Part): """ - Store the checkpoint file used for resuming the pre-fitting process. + Store the checkpoint file and model data file used for resuming the pre-fitting process. """ definition = """ -> master - file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5'). + file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl'). --- file : filepath@moseq-train-processed # Path to the file in the processed data directory. """ @@ -1080,15 +1078,33 @@ def make(self, key): else: duration_seconds = None + # Save model dictionary as pickle file + model_data_filename = "model_data.pkl" + model_data_file = kpms_project_output_dir / model_data_filename + with open(model_data_file, "wb") as f: + pickle.dump(model, f) + + file_paths = [checkpoint_file, model_data_file] + self.insert1( { **key, "model_name": model_name, "pre_fit_duration": duration_seconds, - "model_dict": model, } ) + self.File.insert( + [ + { + **key, + "file_name": file.name, + "file": file.as_posix(), + } + for file in file_paths + ] + ) + self.ConfigFile.insert1( dict( **key, @@ -1104,14 +1120,6 @@ def make(self, key): } ) - self.CheckpointFile.insert1( - { - **key, - "checkpoint_file_name": checkpoint_file.name, - "checkpoint_file": checkpoint_file, - } - ) - @schema class FullFitTask(dj.Manual): @@ -1153,7 +1161,6 @@ class FullFit(dj.Computed): -> FullFitTask # `FullFitTask` Key --- model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name". - model_dict : longblob # Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed. full_fit_time=NULL : datetime # datetime of the full fitting computation. full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation. """ @@ -1171,12 +1178,12 @@ class ConfigFile(dj.Part): class File(dj.Part): """ - Store the checkpoint file used for resuming the full-fitting process. + Store the checkpoint file and model data file used for resuming the full-fitting process. """ definition = """ -> master - file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5'). + file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl'). --- file : filepath@moseq-train-processed # Path to the file in the processed data directory. """ @@ -1242,10 +1249,10 @@ def make(self, key): "coordinates", "confidences" ) data_path = (PCAFit.File & key & 'file_name="data.pkl"').fetch1("file_path") - data = pickle.load(open(data_path, "rb")) metadata_path = (PCAFit.File & key & 'file_name="metadata.pkl"').fetch1( "file_path" ) + data = pickle.load(open(data_path, "rb")) metadata = pickle.load(open(metadata_path, "rb")) average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate") @@ -1339,6 +1346,14 @@ def make(self, key): else: duration_seconds = None + # Save model dictionary as pickle file + model_data_filename = "model_data.pkl" + model_data_file = kpms_project_output_dir / model_data_filename + with open(model_data_file, "wb") as f: + pickle.dump(model, f) + + file_paths = [checkpoint_file, model_data_file] + self.insert1( { **key, @@ -1348,10 +1363,20 @@ def make(self, key): ).as_posix(), "full_fit_time": completion_time, "full_fit_duration": duration_seconds, - "model_dict": model, } ) + self.File.insert( + [ + { + **key, + "file_name": file.name, + "file": file.as_posix(), + } + for file in file_paths + ] + ) + # Insert config file self.ConfigFile.insert1( { @@ -1369,15 +1394,6 @@ def make(self, key): } ) - # Insert checkpoint file - self.File.insert1( - { - **key, - "checkpoint_file_name": checkpoint_file.name, - "checkpoint_file": checkpoint_file, - } - ) - @schema class SelectedFullFit(dj.Manual): From 6efbb21ef6af8c78ffc2cf8eb2a93a9e4cf643ab Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 07:33:29 +0200 Subject: [PATCH 19/41] chore(moseq_train): minor fix --- element_moseq/moseq_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 78b54a8..0a5fc23 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1080,7 +1080,7 @@ def make(self, key): # Save model dictionary as pickle file model_data_filename = "model_data.pkl" - model_data_file = kpms_project_output_dir / model_data_filename + model_data_file = model_name_full_path / model_data_filename with open(model_data_file, "wb") as f: pickle.dump(model, f) @@ -1348,7 +1348,7 @@ def make(self, key): # Save model dictionary as pickle file model_data_filename = "model_data.pkl" - model_data_file = kpms_project_output_dir / model_data_filename + model_data_file = model_name_full_path / model_data_filename with open(model_data_file, "wb") as f: pickle.dump(model, f) From d9a7049f9bc9ff7146ab77ac3b3a41f3be3e3e94 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 21:55:47 +0200 Subject: [PATCH 20/41] feat(moseq_train): add `ModelScore` --- element_moseq/moseq_train.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 0a5fc23..2219166 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1395,6 +1395,54 @@ def make(self, key): ) +@schema +class ModelScore(dj.Computed): + """Compute model scores for trained models. + + Computes Marginal Log Likelihood (MLL) for a single model. + """ + + definition = """ + -> FullFit + --- + score=NULL : float # Model score (MLL for single model) + std_error=NULL : float # Standard error of the model score + """ + + def make(self, key): + import jax.numpy as jnp + from keypoint_moseq import load_checkpoint + from keypoint_moseq.fitting import marginal_log_likelihood + + # Get checkpoint file for this specific model + checkpoint_file = ( + FullFit.File & key & 'file_name LIKE "%checkpoint.h5"' + ).fetch1("file") + + # Load the checkpoint to get model data + model, data, _, _ = load_checkpoint(path=checkpoint_file) + + # Compute marginal log likelihood for this model + mask = jnp.array(data["mask"]) + x = jnp.array(model["states"]["x"]) + Ab = jnp.array(model["params"]["Ab"]) + Q = jnp.array(model["params"]["Q"]) + pi = jnp.array(model["params"]["pi"]) + + # Compute marginal log likelihood - this is the correct metric for single models + mll = marginal_log_likelihood(mask, x, Ab, Q, pi) + score = float(mll) # Store as "score" - this is MLL + std_error = 0.0 # No standard error for single model MLL + + self.insert1( + { + **key, + "score": score, + "std_error": std_error, + } + ) + + @schema class SelectedFullFit(dj.Manual): """Register selected FullFit models for use in the inference pipeline. From 978bb1d3aeda2560fe0abc94cb7f45c4c5c3901d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 21:59:24 +0200 Subject: [PATCH 21/41] fix(Inference): refactor and solve `lost of connection` --- element_moseq/moseq_infer.py | 618 +++++++++++++++-------------------- 1 file changed, 267 insertions(+), 351 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index f6a7fae..9d8a0cf 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -1,14 +1,16 @@ """ -Code adapted from the Datta Lab +Code adapted from the Datta Lab: https://dattalab.github.io/moseq2-website/index.html DataJoint Schema for Keypoint-MoSeq inference pipeline """ import importlib import inspect +import pickle from datetime import datetime, timezone from pathlib import Path import datajoint as dj +import numpy as np from element_interface.utils import find_full_path from matplotlib import pyplot as plt @@ -83,6 +85,7 @@ class Model(dj.Manual): --- model_name : varchar(1000) # User-friendly model name model_dir : varchar(1000) # Model directory relative to root data directory + model_file : filepath@moseq-infer-processed # Model pkl file containing states, parameters, hyperparameters, noise prior, and random seed model_desc='' : varchar(1000) # Optional. User-defined description of the model -> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key. """ @@ -148,6 +151,37 @@ class InferenceTask(dj.Manual): task_mode='load' : enum('load', 'trigger') # Task mode for the inference task """ + @classmethod + def infer_output_dir(cls, key: dict, relative: bool = False, mkdir: bool = False): + """Return the expected inference_output_dir. + + Based on convention: model_dir / inference_output_dir + If inference_output_dir is empty, generates a default based on model and recording. + + Args: + key: DataJoint key specifying a pairing of VideoRecording and Model. + relative (bool): Report directory relative to processed data directory. + mkdir (bool): Default False. Make directory if it doesn't exist. + """ + # Get model directory + model_dir_rel, model_file = (Model * moseq_train.SelectedFullFit & key).fetch1( + "model_dir", "model_file" + ) + kpms_processed = moseq_train.get_kpms_processed_data_dir() + + # Get recording info for default naming + recording_id = (VideoRecording & key).fetch1("recording_id") + + # Generate default output directory name + default_output_dir = f"inference_recording_id_{recording_id}" + + if mkdir: + # Create directory in the processed directory, not inside model directory + output_dir = Path(kpms_processed) / model_dir_rel / default_output_dir + output_dir.mkdir(parents=True, exist_ok=True) + + return default_output_dir + @schema class Inference(dj.Computed): @@ -155,119 +189,66 @@ class Inference(dj.Computed): Attributes: InferenceTask (foreign_key) : `InferenceTask` key. - syllable_segmentation_file (attach) : File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings. + syllable_segmentation_file (filepath): File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings. inference_duration (float) : Time duration (seconds) of the inference computation. """ definition = """ -> InferenceTask # `InferenceTask` key --- - syllable_segmentation_file : attach # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings + syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings inference_duration=NULL : float # Time duration (seconds) of the inference computation """ - class MotionSequence(dj.Part): - """Store the results of the model inference. - - Attributes: - InferenceTask (foreign key) : `InferenceTask` key. - video_name (varchar) : Name of the video. - syllable (longblob) : Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model). - latent_state (longblob) : Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model. - centroid (longblob) : Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model. - heading (longblob) : Inferred heading (h). The heading of the animal in each frame, as estimated by the model. - motion_sequence_file (attach) : File path of the temporal sequence of motion data (CSV format). - """ - - definition = """ - -> master - video_name : varchar(150) # Name of the video - --- - syllable : longblob # Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model) - latent_state : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model - centroid : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model - heading : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model - motion_sequence_file: attach # File path of the temporal sequence of motion data (CSV format) - """ - - class GridMoviesSampledInstances(dj.Part): - """Store the sampled instances of the grid movies. - - Attributes: - syllable (int) : Syllable label. - instances (longblob) : List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame. - """ - - definition = """ - -> master - syllable: int # Syllable label - --- - instances: longblob # List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame - """ - - class InferencePlots(dj.Part): - """Store the main inference plots. - - Attributes: - InferenceTask (foreign key) : `InferenceTask` key. - plot_name (varchar) : Name of the plot. - plot_file (attach) : File path of the plot. - """ - - definition = """ - -> master - plot_name: varchar(150) # Name of the plot (e.g. syllable_frequencies, similarity_dendrogram_png, similarity_dendrogram_pdf, all_trajectories_gif, all_trajectories_pdf) - --- - plot_file: attach # File path of the plot - """ - - class TrajectoryPlots(dj.Part): - """Store the per-syllable trajectory plots. - - Attributes: - InferenceTask (foreign key) : `InferenceTask` key. - syllable_id (int) : Syllable ID. - plot_gif (attach) : GIF plot file. - plot_pdf (attach) : PDF plot file. - grid_movie (attach) : Grid movie file. - """ - - definition = """ - -> master - syllable_id: int # Syllable ID - --- - plot_gif: attach # GIF plot file - plot_pdf: attach # PDF plot file - grid_movie: attach # Grid movie file - """ - def make_fetch(self, key): - """ - Fetch data required for model inference. - """ - ( - keypointset_dir, - inference_output_dir, - num_iterations, - model_id, - pose_estimation_method, - task_mode, - ) = (InferenceTask & key).fetch1( + + (keypointset_dir, inference_output_dir, num_iterations, task_mode,) = ( + InferenceTask & key + ).fetch1( "keypointset_dir", "inference_output_dir", "num_iterations", - "model_id", - "pose_estimation_method", "task_mode", ) + model_dir_rel, model_file = (Model * moseq_train.SelectedFullFit & key).fetch1( + "model_dir", "model_file" + ) # model dir relative to processed data directory + + model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY") + checkpoint_file_path = ( + moseq_train.FullFit.File & model_key & 'file_name="checkpoint.h5"' + ).fetch1("file") + kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( + "config_file" + ) + pca_file_path = ( + moseq_train.PCAFit.File & model_key & 'file_name="pca.p"' + ).fetch1("file_path") + + data_file_path = ( + moseq_train.PCAFit.File & model_key & 'file_name="data.pkl"' + ).fetch1("file_path") + metadata_file_path = ( + moseq_train.PCAFit.File & model_key & 'file_name="metadata.pkl"' + ).fetch1("file_path") + coordinates, confidences = (moseq_train.PreProcessing & model_key).fetch( + "coordinates", "confidences" + ) return ( keypointset_dir, inference_output_dir, num_iterations, - model_id, - pose_estimation_method, task_mode, + model_dir_rel, + model_file, + checkpoint_file_path, + kpms_dj_config_file, + pca_file_path, + data_file_path, + metadata_file_path, + coordinates, + confidences, ) def make_compute( @@ -276,9 +257,16 @@ def make_compute( keypointset_dir, inference_output_dir, num_iterations, - model_id, - pose_estimation_method, task_mode, + model_dir_rel, + model_file, + checkpoint_file_path, + kpms_dj_config_file, + pca_file_path, + data_file_path, + metadata_file_path, + coordinates, + confidences, ): """ Compute model inference results. @@ -303,296 +291,88 @@ def make_compute( """ from keypoint_moseq import ( apply_model, - filter_centroids_headings, format_data, - generate_grid_movies, - generate_trajectory_plots, - get_syllable_instances, load_checkpoint, load_keypoints, load_pca, load_results, - plot_similarity_dendrogram, - plot_syllable_frequencies, - sample_instances, save_results_as_csv, ) # Constants used by default as in kpms - DEFAULT_NUM_ITERS = 50 - FILTER_SIZE = 9 - MIN_DURATION = 3 - MIN_FREQUENCY = 0.005 - GRID_SAMPLES = 4 * 6 # minimum rows * cols + DEFAULT_NUM_ITERS = 500 + # Get directories first kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() - model_dir = find_full_path( - kpms_processed, - (Model & f"model_id = {model_id}").fetch1("model_dir"), - ) - keypointset_dir = find_full_path(kpms_root, keypointset_dir) - - inference_output_dir = Path(model_dir) / inference_output_dir - - if not inference_output_dir.exists(): - inference_output_dir.mkdir(parents=True, exist_ok=True) - - pca_path = model_dir.parent / "pca.p" - if pca_path: - pca = load_pca(model_dir.parent.as_posix()) - else: - raise FileNotFoundError( - f"No pca model (`pca.p`) found in the parent model directory {model_dir.parent}" - ) - - model_path = model_dir / "checkpoint.h5" - if model_path: - model = load_checkpoint( - project_dir=model_dir.parent, model_name=model_dir.parts[-1] - )[0] - else: - raise FileNotFoundError( - f"No model (`checkpoint.h5`) found in the model directory {model_dir}" + if not inference_output_dir: + inference_output_dir = InferenceTask.infer_output_dir( + key, relative=True, mkdir=True ) + # Update the inference_output_dir in the database + InferenceTask.update1({**key, "inference_output_dir": inference_output_dir}) - if pose_estimation_method == "deeplabcut": - coordinates, confidences, _ = load_keypoints( - filepath_pattern=keypointset_dir, format=pose_estimation_method - ) + # Construct the full path to the inference output directory + inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir - kpms_dj_config = kpms_reader.load_kpms_dj_config(model_dir.parent) + inference_output_dir.mkdir(parents=True, exist_ok=True) + model_dir = find_full_path(kpms_processed, model_dir_rel) + keypointset_dir = find_full_path(kpms_root, keypointset_dir) - if kpms_dj_config: - data, metadata = format_data(coordinates, confidences, **kpms_dj_config) - else: - raise FileNotFoundError( - f"No valid `kpms_dj_config` found in the parent model directory {model_dir.parent}" - ) + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_file + ) + metadata = pickle.load(open(metadata_file_path, "rb")) + data = pickle.load(open(data_file_path, "rb")) + model_data = pickle.load(open(model_file, "rb")) if task_mode == "trigger": start_time = datetime.now(timezone.utc) results = apply_model( - model=model, + model_name=inference_output_dir.name, + model=model_data, data=data, metadata=metadata, - pca=pca, - project_dir=model_dir.parent.as_posix(), - model_name=Path(model_dir).name, - results_path=(inference_output_dir / "results.h5").as_posix(), + pca=pca_file_path, + project_dir=inference_output_dir.parent, + results_path=(inference_output_dir / "results.h5"), return_model=False, num_iters=num_iterations or DEFAULT_NUM_ITERS, overwrite=True, - **kpms_dj_config, + save_results=True, + **kpms_dj_config_dict, ) - end_time = datetime.now(timezone.utc) + end_time = datetime.now(timezone.utc) duration_seconds = (end_time - start_time).total_seconds() + # Create results directory and save CSV files save_results_as_csv( results=results, save_dir=(inference_output_dir / "results_as_csv").as_posix(), ) - fig, _ = plot_syllable_frequencies( - results=results, path=inference_output_dir.as_posix() - ) - fig.savefig(inference_output_dir / "syllable_frequencies.png") - plt.close(fig) - - generate_trajectory_plots( - coordinates=coordinates, - results=results, - output_dir=(inference_output_dir / "trajectory_plots").as_posix(), - **kpms_dj_config, - ) - - sampled_instances = generate_grid_movies( - coordinates=coordinates, - results=results, - output_dir=(inference_output_dir / "grid_movies").as_posix(), - **kpms_dj_config, - ) - - plot_similarity_dendrogram( - coordinates=coordinates, - results=results, - save_path=(inference_output_dir / "similarity_dendrogram").as_posix(), - **kpms_dj_config, - ) - else: # For load mode - # load results results = load_results( - project_dir=inference_output_dir.parent, - model_name=inference_output_dir.parts[-1], - ) - - # extract syllables from results - syllables = {k: v["syllable"] for k, v in results.items()} - - # extract and smooth centroids and headings - centroids = {k: v["centroid"] for k, v in results.items()} - headings = {k: v["heading"] for k, v in results.items()} - - centroids, headings = filter_centroids_headings( - centroids, headings, filter_size=FILTER_SIZE - ) - - # extract sample instances for each syllable - syllable_instances = get_syllable_instances( - syllables, min_duration=MIN_DURATION, min_frequency=MIN_FREQUENCY - ) - # Map each syllable to a list of its sampled events. - sampled_instances = sample_instances( - syllable_instances=syllable_instances, - num_samples=GRID_SAMPLES, - coordinates=coordinates, - centroids=centroids, - headings=headings, - ) - - duration_seconds = None - - # Prepare motion sequence data - motion_sequence_data = [] - for result_idx, result in results.items(): - motion_sequence_data.append( - { - **key, - "video_name": result_idx, - "syllable": result["syllable"], - "latent_state": result["latent_state"], - "centroid": result["centroid"], - "heading": result["heading"], - "motion_sequence_file": ( - inference_output_dir / "results_as_csv" / f"{result_idx}.csv" - ).as_posix(), - } - ) - - # Prepare grid movie data - grid_movie_data = [] - for syllable, sampled_instance in sampled_instances.items(): - grid_movie_data.append( - {**key, "syllable": syllable, "instances": sampled_instance} - ) - - # Prepare inference plots data - inference_plots_data = [] - if task_mode == "trigger": - # Main plots generated during trigger mode - inference_plots_data = [ - { - **key, - "plot_name": "syllable_frequencies", - "plot_file": inference_output_dir / "syllable_frequencies.png", - }, - { - **key, - "plot_name": "similarity_dendrogram_png", - "plot_file": inference_output_dir / "similarity_dendrogram.png", - }, - { - **key, - "plot_name": "similarity_dendrogram_pdf", - "plot_file": inference_output_dir / "similarity_dendrogram.pdf", - }, - { - **key, - "plot_name": "all_trajectories_gif", - "plot_file": inference_output_dir - / "trajectory_plots" - / "all_trajectories.gif", - }, - { - **key, - "plot_name": "all_trajectories_pdf", - "plot_file": inference_output_dir - / "trajectory_plots" - / "all_trajectories.pdf", - }, - ] - else: - # For load mode, check if files exist - main_plots = [ - ("syllable_frequencies", "syllable_frequencies.png"), - ("similarity_dendrogram_png", "similarity_dendrogram.png"), - ("similarity_dendrogram_pdf", "similarity_dendrogram.pdf"), - ("all_trajectories_gif", "trajectory_plots/all_trajectories.gif"), - ("all_trajectories_pdf", "trajectory_plots/all_trajectories.pdf"), - ] - - for plot_name, file_path in main_plots: - full_path = inference_output_dir / file_path - if full_path.exists(): - inference_plots_data.append( - { - **key, - "plot_name": plot_name, - "plot_file": full_path, - } - ) - - # Prepare trajectory plots data - trajectory_plots_data = [] - for syllable in sampled_instances.keys(): - syllable_gif = ( - inference_output_dir / "trajectory_plots" / f"syllable{syllable}.gif" - ) - syllable_pdf = ( - inference_output_dir / "trajectory_plots" / f"syllable{syllable}.pdf" - ) - grid_movie_mp4 = ( - inference_output_dir / "grid_movies" / f"syllable{syllable}.mp4" - ) - grid_movie_gif = ( - inference_output_dir - / "grid_movies" - / f"syllable{syllable}_grid_movie.gif" + project_dir=model_dir, + model_name=model_dir.name, ) + duration_seconds = None - # Convert MP4 to GIF if needed (from InferenceReport logic) - if grid_movie_mp4.exists() and not grid_movie_gif.exists(): - import imageio - - reader = imageio.get_reader(grid_movie_mp4) - fps = reader.get_meta_data()["fps"] - writer = imageio.get_writer(grid_movie_gif, fps=fps, loop=0) - for frame in reader: - writer.append_data(frame) - writer.close() - - trajectory_plots_data.append( - { - **key, - "syllable_id": syllable, - "plot_gif": syllable_gif if syllable_gif.exists() else None, - "plot_pdf": syllable_pdf if syllable_pdf.exists() else None, - "grid_movie": grid_movie_gif if grid_movie_gif.exists() else None, - } - ) + results_filepath = (inference_output_dir / "results.h5").as_posix() return ( duration_seconds, - motion_sequence_data, - grid_movie_data, - inference_plots_data, - trajectory_plots_data, - inference_output_dir, + results_filepath, ) def make_insert( self, key, duration_seconds, - motion_sequence_data, - grid_movie_data, - inference_plots_data, - trajectory_plots_data, - inference_output_dir, + results_filepath, ): """ Insert inference results into the database. @@ -601,24 +381,160 @@ def make_insert( { **key, "inference_duration": duration_seconds, - "syllable_segmentation_file": ( - inference_output_dir / "results.h5" - ).as_posix(), + "syllable_segmentation_file": results_filepath, } ) - for motion_record in motion_sequence_data: - self.MotionSequence.insert1(motion_record) - for grid_record in grid_movie_data: - self.GridMoviesSampledInstances.insert1(grid_record) +@schema +class MotionSequence(dj.Computed): + """Expand inference results into per-video sequences and sampled instances.""" + + definition = """ + -> Inference + --- + motion_sequence_duration=NULL : float + """ + + class VideoSequence(dj.Part): + """Store the per-video sequences.""" + + definition = """ + -> master + -> VideoRecording.File # Foreign key to VideoRecording.File + --- + syllables : longblob # Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model) + latent_states : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model + centroids : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model + headings : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model + file_csv : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format) + """ + + class SampledInstance(dj.Part): + """Store the sampled instances of the grid movies.""" + + definition = """ + -> master + syllable: int + --- + instances: longblob + """ + + def make(self, key): + import h5py + from keypoint_moseq import ( + filter_centroids_headings, + get_syllable_instances, + load_keypoints, + load_results, + sample_instances, + ) + + execution_time = datetime.now(timezone.utc) + + # Constants used by default as in kpms + FILTER_SIZE = 9 + MIN_DURATION = 3 + MIN_FREQUENCY = 0.005 + GRID_SAMPLES = 4 * 6 # minimum rows * cols + # Fetch base params + ( + keypointset_dir, + inference_output_dir, + model_dir, + num_iterations, + task_mode, + ) = (InferenceTask * Model & key).fetch1( + "keypointset_dir", + "inference_output_dir", + "model_dir", + "num_iterations", + "task_mode", + ) + kpms_root = moseq_train.get_kpms_root_data_dir() + kpms_processed = moseq_train.get_kpms_processed_data_dir() + + # Get the full paths + keypointset_dir = find_full_path(kpms_root, keypointset_dir) + + # Handle default inference_output_dir if not provided + if not inference_output_dir: + inference_output_dir = InferenceTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # Update the inference_output_dir in the database + InferenceTask.update1({**key, "inference_output_dir": inference_output_dir}) + + inference_output_dir = Path(model_dir) / inference_output_dir + inference_output_dir = find_full_path(kpms_processed, inference_output_dir) + + model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY") + coordinates, confidences = (moseq_train.PreProcessing & model_key).fetch( + "coordinates", "confidences" + ) + + results_file = (Inference & key).fetch1("syllable_segmentation_file") + + file_ids, file_paths = (VideoRecording.File & key).fetch("file_id", "file_path") + + video_name_to_file_id = {} + for file_id, file_path in zip(file_ids, file_paths): + base_video_name = Path(file_path).stem + video_name_to_file_id[base_video_name] = file_id + + with h5py.File(results_file, "r") as results: + syllables = {k: np.array(v["syllable"]) for k, v in results.items()} + latent_states = {k: np.array(v["latent_state"]) for k, v in results.items()} + centroids = {k: np.array(v["centroid"]) for k, v in results.items()} + headings = {k: np.array(v["heading"]) for k, v in results.items()} + video_keys = list(results.keys()) + + filtered_centroids, filtered_headings = filter_centroids_headings( + centroids, headings, filter_size=FILTER_SIZE + ) + + motion_rows = [] + for vid in video_keys: + matched_file_id = None + for base_video_name, file_id in video_name_to_file_id.items(): + if vid.startswith(base_video_name): + matched_file_id = file_id + break + + if matched_file_id is not None: + motion_rows.append( + { + **key, + "file_id": matched_file_id, + "syllables": syllables[vid], + "latent_states": latent_states[vid], + "centroids": filtered_centroids[vid], + "headings": filtered_headings[vid], + "file_csv": ( + inference_output_dir / "results_as_csv" / f"{vid}.csv" + ).as_posix(), + } + ) + syllable_instances = get_syllable_instances( + syllables, min_duration=MIN_DURATION, min_frequency=MIN_FREQUENCY + ) + sampled = sample_instances( + syllable_instances=syllable_instances, + num_samples=GRID_SAMPLES, + coordinates=coordinates, + centroids=filtered_centroids, + headings=filtered_headings, + ) + + sampled_rows = [ + {**key, "syllable": s, "instances": inst} for s, inst in sampled.items() + ] + + completion_time = datetime.now(timezone.utc) + duration_seconds = (completion_time - execution_time).total_seconds() + + self.insert1({**key, "motion_sequence_duration": duration_seconds}) - for plot_record in inference_plots_data: - self.InferencePlots.insert1(plot_record) + self.VideoSequence.insert(motion_rows) - for trajectory_record in trajectory_plots_data: - if any( - trajectory_record.get(field) - for field in ["plot_gif", "plot_pdf", "grid_movie"] - ): - self.TrajectoryPlots.insert1(trajectory_record) + self.SampledInstance.insert(sampled_rows) From 3e7981da214559bf87904b584bb5b3642ae1d6e5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 22:45:55 +0200 Subject: [PATCH 22/41] fix(moseq_infer) --- element_moseq/moseq_infer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 9d8a0cf..5cd9b19 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -211,6 +211,13 @@ def make_fetch(self, key): "task_mode", ) + if not inference_output_dir: + inference_output_dir = InferenceTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # Update the inference_output_dir in the database + InferenceTask.update1({**key, "inference_output_dir": inference_output_dir}) + model_dir_rel, model_file = (Model * moseq_train.SelectedFullFit & key).fetch1( "model_dir", "model_file" ) # model dir relative to processed data directory @@ -302,17 +309,12 @@ def make_compute( # Constants used by default as in kpms DEFAULT_NUM_ITERS = 500 + start_time = datetime.now(timezone.utc) + # Get directories first kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() - if not inference_output_dir: - inference_output_dir = InferenceTask.infer_output_dir( - key, relative=True, mkdir=True - ) - # Update the inference_output_dir in the database - InferenceTask.update1({**key, "inference_output_dir": inference_output_dir}) - # Construct the full path to the inference output directory inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir @@ -328,7 +330,6 @@ def make_compute( data = pickle.load(open(data_file_path, "rb")) model_data = pickle.load(open(model_file, "rb")) if task_mode == "trigger": - start_time = datetime.now(timezone.utc) results = apply_model( model_name=inference_output_dir.name, model=model_data, @@ -344,15 +345,15 @@ def make_compute( **kpms_dj_config_dict, ) - end_time = datetime.now(timezone.utc) - duration_seconds = (end_time - start_time).total_seconds() - # Create results directory and save CSV files save_results_as_csv( results=results, save_dir=(inference_output_dir / "results_as_csv").as_posix(), ) + end_time = datetime.now(timezone.utc) + duration_seconds = (end_time - start_time).total_seconds() + else: # For load mode results = load_results( From 83a825c28d891897bc71203906ac4aeeaafac351 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sat, 25 Oct 2025 22:47:29 +0200 Subject: [PATCH 23/41] feat(SelectedFullFit): classmethod for automated insertion of best model --- element_moseq/moseq_train.py | 42 ++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 2219166..56e572f 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1459,3 +1459,45 @@ class SelectedFullFit(dj.Manual): registered_model_name : varchar(1000) # User-friendly model name registered_model_desc='' : varchar(1000) # Optional user-defined description """ + + @classmethod + def select_best_model(cls, key, model_desc="Best model based on MLL score"): + """Automatically select the best model for a FullFit based on highest MLL score. + + Args: + pcafit_key (dict): PCAFit key to filter models + model_desc (str): Description for the selected model + + Returns: + dict: The key of the selected model + """ + # Get all models with their scores for this specific PCAFit + models_with_scores = (FullFit * ModelScore & key).fetch() + + if len(models_with_scores) == 0: + raise ValueError( + f"No models with scores found for PCAFit {key}. Run ModelScore.populate() first." + ) + + # Find the model with the highest score (best MLL) + best_model_idx = models_with_scores["score"].argmax() + best_model_key = { + k: models_with_scores[k][best_model_idx] for k in FullFit.primary_key + } + best_model_name = models_with_scores["model_name"][best_model_idx] + best_score = models_with_scores["score"][best_model_idx] + + print(f"Selected best model: {best_model_name}") + print(f"Model score (MLL): {best_score:.2f}") + + # Insert the best model into SelectedFullFit + cls.insert1( + { + **best_model_key, + "registered_model_name": best_model_name, + "registered_model_desc": f"{model_desc} (MLL: {best_score:.2f})", + }, + skip_duplicates=True, + ) + + return best_model_key From 4949452cf71b58687ec7f1ef462f26c7ee57da6a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sun, 26 Oct 2025 01:05:51 +0200 Subject: [PATCH 24/41] chore(PCATask): add classmethod to infer output dir when empty --- element_moseq/moseq_train.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 56e572f..5d324ef 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -214,6 +214,33 @@ class PCATask(dj.Manual): task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA """ + @classmethod + def infer_output_dir(cls, key: dict, relative: bool = False, mkdir: bool = False): + """Return the expected kpms_project_output_dir. + + If kpms_project_output_dir is empty, generates a default based on keypointset and session info. + + Args: + key: DataJoint key specifying a PCATask. + relative (bool): Report directory relative to processed data directory. + mkdir (bool): Default False. Make directory if it doesn't exist. + """ + # Get keypointset info for default naming + kpset_id = (KeypointSet & key).fetch1("kpset_id") + + # Get bodyparts info for unique naming + bodyparts_id = (BodyParts & key).fetch1("bodyparts_id") + + # Generate default output directory name + default_output_dir = f"kpset_id_{kpset_id}_bodyparts_id_{bodyparts_id}" + + if mkdir: + # Create directory in the processed directory + output_dir = Path(get_kpms_processed_data_dir()) / default_output_dir + output_dir.mkdir(parents=True, exist_ok=True) + + return default_output_dir + @schema class PreProcessing(dj.Computed): @@ -292,6 +319,14 @@ def make_fetch(self, key): PCATask & key ).fetch1("kpms_project_output_dir", "task_mode", "outlier_scale_factor") + # Handle default kpms_project_output_dir if not provided + if not kpms_project_output_dir: + kpms_project_output_dir = PCATask.infer_output_dir( + key, relative=True, mkdir=True + ) + # Update the kpms_project_output_dir in the database + PCATask.update1({**key, "kpms_project_output_dir": kpms_project_output_dir}) + return ( anterior_bodyparts, posterior_bodyparts, From 32964ac602c8be0801cb54c47de436e3c9b65e67 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sun, 26 Oct 2025 01:10:11 +0200 Subject: [PATCH 25/41] add `moseq_report` --- element_moseq/moseq_report.py | 240 ++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 element_moseq/moseq_report.py diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py new file mode 100644 index 0000000..94999a5 --- /dev/null +++ b/element_moseq/moseq_report.py @@ -0,0 +1,240 @@ +""" +DataJoint Schema for Keypoint-MoSeq reporting and visualization +""" + +import importlib +import inspect +from datetime import datetime, timezone +from pathlib import Path + +import datajoint as dj +import h5py +import numpy as np +from element_interface.utils import find_full_path +from matplotlib import pyplot as plt + +from . import moseq_infer, moseq_train +from .readers import kpms_reader + +schema = dj.schema() +_linking_module = None +logger = dj.logger + + +def activate( + report_schema_name: str, + *, + create_schema: bool = True, + create_tables: bool = True, + linking_module: str = None, +): + """Activate this schema. + + Args: + report_schema_name (str): Schema name on the database server to activate the `moseq_report` schema. + create_schema (bool): When True (default), create schema in the database if it + does not yet exist. + create_tables (bool): When True (default), create schema tables in the database + if they do not yet exist. + linking_module (str): A module (or name) containing the required dependencies. + + Functions: + get_kpms_root_data_dir(): Returns absolute path for root data director(y/ies) with all behavioral recordings, as (list of) string(s) + get_kpms_processed_data_dir(): Optional. Returns absolute path for processed data. + """ + + if isinstance(linking_module, str): + linking_module = importlib.import_module(linking_module) + assert inspect.ismodule( + linking_module + ), "The argument 'linking_module' must be a module or module name" + + # activate + schema.activate( + report_schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=linking_module.__dict__, + ) + + +@schema +class BehavioralSummary(dj.Computed): + """Generate and store behavioral analysis visualizations from Keypoint-MoSeq inference.""" + + definition = """ + -> moseq_infer.Inference + --- + syllable_frequencies_plot : attach # File path of the syllable frequencies plot + similarity_dendrogram_png : attach # File path of the similarity dendrogram plot (PNG) + similarity_dendrogram_pdf : attach # File path of the similarity dendrogram plot (PDF) + """ + + def make(self, key): + + from keypoint_moseq import ( + format_data, + plot_similarity_dendrogram, + plot_syllable_frequencies, + ) + + model_dir = (moseq_infer.Model & key).fetch1("model_dir") + kpms_processed = moseq_train.get_kpms_processed_data_dir() + inference_output_dir = (moseq_infer.InferenceTask & key).fetch1( + "inference_output_dir" + ) + inference_output_dir = Path(model_dir) / inference_output_dir + inference_output_dir = find_full_path(kpms_processed, inference_output_dir) + + # Get inference data from upstream tables + results_file = (moseq_infer.Inference & key).fetch1( + "syllable_segmentation_file" + ) + + # Load results from H5 file + results = h5py.File(results_file, "r") + + # Generate syllable frequencies plot + fig, _ = plot_syllable_frequencies(results=results, path=inference_output_dir) + fig.savefig(inference_output_dir / "syllable_frequencies.png") + plt.close(fig) + + # Get coordinates and config for similarity dendrogram + model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( + "KEY" + ) + coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates") + + # Get fps from config + config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1("config_file") + kpms_dj_config = kpms_reader.load_kpms_dj_config(config_path=config_file) + + # Generate similarity dendrogram plots + plot_similarity_dendrogram( + coordinates=coordinates, + results=results, + save_path=(inference_output_dir / "similarity_dendrogram").as_posix(), + **kpms_dj_config, + ) + + # Insert the record + self.insert1( + { + **key, + "syllable_frequencies_plot": inference_output_dir + / "syllable_frequencies.png", + "similarity_dendrogram_png": inference_output_dir + / "similarity_dendrogram.png", + "similarity_dendrogram_pdf": inference_output_dir + / "similarity_dendrogram.png", # Same file for now + } + ) + + +@schema +class TrajectoryPlot(dj.Computed): + """Generate per-syllable trajectory plots and grid movies for behavioral syllable analysis.""" + + definition = """ + -> moseq_infer.Inference + --- + all_trajectories_gif : attach # File path of the all trajectories GIF plot + all_trajectories_pdf : attach # File path of the all trajectories PDF plot + traj_duration=NULL : float # Time duration (seconds) + """ + + class Syllable(dj.Part): + definition = """ + -> master + syllable_id: int # Syllable ID + --- + plot_gif: attach # GIF plot file + plot_pdf: attach # PDF plot file + grid_movie: attach # Grid movie file + """ + + def make(self, key): + """Generate trajectory plots and grid movies.""" + from keypoint_moseq import generate_grid_movies, generate_trajectory_plots + + start_time = datetime.now(timezone.utc) + + # Get inference data + results_file = (moseq_infer.Inference & key).fetch1( + "syllable_segmentation_file" + ) + model_dir = (moseq_infer.Model & key).fetch1("model_dir") + inference_output_dir = (moseq_infer.InferenceTask & key).fetch1( + "inference_output_dir" + ) + + # Get model data from training schema + model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( + "KEY" + ) + coordinates_dict = (moseq_train.PreProcessing & model_key).fetch1("coordinates") + + # Get config + kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( + "config_file" + ) + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_file + ) + + # Construct output directory + kpms_processed = moseq_train.get_kpms_processed_data_dir() + output_dir = Path(model_dir) / inference_output_dir + output_dir = find_full_path(kpms_processed, output_dir) + + # Create output directories + trajectory_dir = output_dir / "trajectory_plots" + grid_movies_dir = output_dir / "grid_movies" + trajectory_dir.mkdir(parents=True, exist_ok=True) + grid_movies_dir.mkdir(parents=True, exist_ok=True) + + # Load results + results = h5py.File(results_file, "r") + + # Generate trajectory plots + generate_trajectory_plots( + coordinates=coordinates_dict, + results=results, + output_dir=trajectory_dir.as_posix(), + **kpms_dj_config_dict, + ) + + # Generate grid movies + generate_grid_movies( + coordinates=coordinates_dict, + results=results, + output_dir=grid_movies_dir.as_posix(), + **kpms_dj_config_dict, + ) + + # Calculate duration + duration_seconds = (datetime.now(timezone.utc) - start_time).total_seconds() + + # Insert main record + self.insert1( + { + **key, + "all_trajectories_gif": trajectory_dir / "all_trajectories.gif", + "all_trajectories_pdf": trajectory_dir / "all_trajectories.pdf", + "traj_duration": duration_seconds, + } + ) + + # Insert per-syllable visuals + for syllable in (moseq_infer.MotionSequence.SampledInstance & key).fetch( + "syllable" + ): + self.Syllable.insert1( + { + **key, + "syllable_id": syllable, + "plot_gif": trajectory_dir / f"syllable{syllable}.gif", + "plot_pdf": trajectory_dir / f"syllable{syllable}.pdf", + "grid_movie": grid_movies_dir / f"syllable{syllable}.mp4", + } + ) From e444a158947662ac5f23a223330c22403289f944 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sun, 26 Oct 2025 01:13:52 +0200 Subject: [PATCH 26/41] cleanup --- element_moseq/moseq_report.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py index 94999a5..3d9c8b1 100644 --- a/element_moseq/moseq_report.py +++ b/element_moseq/moseq_report.py @@ -90,8 +90,6 @@ def make(self, key): results_file = (moseq_infer.Inference & key).fetch1( "syllable_segmentation_file" ) - - # Load results from H5 file results = h5py.File(results_file, "r") # Generate syllable frequencies plot @@ -99,22 +97,18 @@ def make(self, key): fig.savefig(inference_output_dir / "syllable_frequencies.png") plt.close(fig) - # Get coordinates and config for similarity dendrogram + # Generate similarity dendrogram plots model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( "KEY" ) coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates") - - # Get fps from config config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1("config_file") - kpms_dj_config = kpms_reader.load_kpms_dj_config(config_path=config_file) - - # Generate similarity dendrogram plots + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(config_path=config_file) plot_similarity_dendrogram( coordinates=coordinates, results=results, save_path=(inference_output_dir / "similarity_dendrogram").as_posix(), - **kpms_dj_config, + **kpms_dj_config_dict, ) # Insert the record @@ -159,7 +153,6 @@ def make(self, key): start_time = datetime.now(timezone.utc) - # Get inference data results_file = (moseq_infer.Inference & key).fetch1( "syllable_segmentation_file" ) @@ -167,14 +160,10 @@ def make(self, key): inference_output_dir = (moseq_infer.InferenceTask & key).fetch1( "inference_output_dir" ) - - # Get model data from training schema model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( "KEY" ) coordinates_dict = (moseq_train.PreProcessing & model_key).fetch1("coordinates") - - # Get config kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( "config_file" ) @@ -215,7 +204,6 @@ def make(self, key): # Calculate duration duration_seconds = (datetime.now(timezone.utc) - start_time).total_seconds() - # Insert main record self.insert1( { **key, @@ -224,8 +212,6 @@ def make(self, key): "traj_duration": duration_seconds, } ) - - # Insert per-syllable visuals for syllable in (moseq_infer.MotionSequence.SampledInstance & key).fetch( "syllable" ): From 471d105f8778a60ecf2b8ddbbc586acbf3105ca4 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Sun, 26 Oct 2025 01:15:51 +0200 Subject: [PATCH 27/41] update tutorial_pipeline --- notebooks/tutorial_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index 5e98937..0f86f32 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -4,7 +4,7 @@ from element_lab import lab from element_animal import subject from element_session import session_with_datetime as session -from element_moseq import moseq_train, moseq_infer +from element_moseq import moseq_train, moseq_infer, moseq_report from element_animal.subject import Subject from element_lab.lab import Source, Lab, Protocol, User, Project @@ -44,6 +44,7 @@ def get_kpms_processed_data_dir() -> str: "session", "moseq_train", "moseq_infer", + "moseq_report", "Device", ] @@ -87,3 +88,4 @@ class Device(dj.Lookup): moseq_train.activate(db_prefix + "moseq_train", linking_module=__name__) moseq_infer.activate(db_prefix + "moseq_infer", linking_module=__name__) +moseq_report.activate(db_prefix + "moseq_report", linking_module=__name__) From bccc03d08872171bd611f4b4a9914518e1721188 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 27 Oct 2025 18:31:50 +0100 Subject: [PATCH 28/41] refactor(kpms_reader): path functions --- element_moseq/readers/kpms_reader.py | 30 ++++++++++------------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index d2bf991..fdd9d9a 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -25,16 +25,11 @@ def _pose_estimation_config_path(kpset_dir: Union[str, os.PathLike]) -> str: Path to pose estimation config file (config.yml or config.yaml) """ kpset_path = Path(kpset_dir) - # Check for config.yml first (preferred) - config_yml = kpset_path / CONFIG_FILENAMES[0] - if config_yml.exists(): - return str(config_yml) - # Fall back to config.yaml - config_yaml = kpset_path / CONFIG_FILENAMES[1] - if config_yaml.exists(): - return str(config_yaml) - # If neither exists, return the default (config.yml) - return str(config_yml) + for filename in CONFIG_FILENAMES: + config_path = kpset_path / filename + if config_path.exists(): + return str(config_path) + return str(kpset_path / CONFIG_FILENAMES[0]) def _kpms_base_config_path(kpms_project_dir: Union[str, os.PathLike]) -> str: @@ -48,16 +43,11 @@ def _kpms_base_config_path(kpms_project_dir: Union[str, os.PathLike]) -> str: Path to KPMS base config file (config.yml or config.yaml) """ project_path = Path(kpms_project_dir) - # Check for config.yml first (preferred) - config_yml = project_path / CONFIG_FILENAMES[0] - if config_yml.exists(): - return str(config_yml) - # Fall back to config.yaml - config_yaml = project_path / CONFIG_FILENAMES[1] - if config_yaml.exists(): - return str(config_yaml) - # If neither exists, return the default (config.yml) - return str(config_yml) + for filename in CONFIG_FILENAMES: + config_path = project_path / filename + if config_path.exists(): + return str(config_path) + return str(project_path / CONFIG_FILENAMES[0]) def _kpms_dj_config_path(kpms_project_dir: Union[str, os.PathLike]) -> str: From 89dfe66bcc2886ccafcdbd02bb3491d0d9274f0e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 04:19:08 +0100 Subject: [PATCH 29/41] fix(kpms_readers): copy config.yml appropriately --- element_moseq/readers/kpms_reader.py | 123 ++++++++++++++++++++------- 1 file changed, 90 insertions(+), 33 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index fdd9d9a..885e5c5 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union import datajoint as dj import yaml @@ -82,9 +82,10 @@ def _check_config_validity(config: Dict[str, Any]) -> bool: f"ACTION REQUIRED: `posterior_bodyparts` contains {bp} " "which is not one of the options in `use_bodyparts`." ) + if errors: - for e in errors: - print(e) + for error in errors: + logger.warning(error) return False return True @@ -122,19 +123,44 @@ def dj_generate_config(kpms_project_dir: str, **kwargs) -> tuple: else: if not Path(kpms_base_config_path).exists(): raise FileNotFoundError( - f"Missing KPMS base config at {kpms_base_config_path}. " - f"Run keypoint_moseq's setup_project first. " - f"Expected either config.yml or config.yaml in {kpms_project_dir}." + f"Missing KPMS base config at {kpms_base_config_path}" ) kpms_dj_config_dict = kpms_base_config_dict.copy() + # Update bodyparts if provided + if "bodyparts" in kwargs: + kpms_dj_config_dict["bodyparts"] = list(kwargs["bodyparts"]) + + if "use_bodyparts" in kwargs: + use_bodyparts = list(kwargs["use_bodyparts"]) + kpms_dj_config_dict["use_bodyparts"] = use_bodyparts + + # Filter anterior/posterior to be subsets of use_bodyparts + if "anterior_bodyparts" in kwargs: + anterior = [ + bp for bp in kwargs["anterior_bodyparts"] if bp in use_bodyparts + ] + kwargs["anterior_bodyparts"] = anterior + + if "posterior_bodyparts" in kwargs: + posterior = [ + bp for bp in kwargs["posterior_bodyparts"] if bp in use_bodyparts + ] + kwargs["posterior_bodyparts"] = posterior + kpms_dj_config_dict.update(kwargs) - if "skeleton" not in kpms_dj_config_dict or kpms_dj_config_dict["skeleton"] is None: + if "skeleton" not in kpms_dj_config_dict: kpms_dj_config_dict["skeleton"] = [] with open(kpms_dj_config_path, "w") as f: - yaml.safe_dump(kpms_dj_config_dict, f, sort_keys=False) + yaml.safe_dump( + kpms_dj_config_dict, + f, + sort_keys=False, + default_flow_style=False, + allow_unicode=True, + ) return ( kpms_dj_config_path, @@ -173,13 +199,10 @@ def load_kpms_dj_config( """ import jax.numpy as jnp - # Validate input parameters if kpms_project_dir is None and config_path is None: - raise ValueError("Either 'kpms_project_dir' or 'config_path' must be provided.") + raise ValueError("Either 'kpms_project_dir' or 'config_path' must be provided") if kpms_project_dir is not None and config_path is not None: - raise ValueError( - "Cannot provide both 'kpms_project_dir' and 'config_path'. Choose one." - ) + raise ValueError("Cannot provide both 'kpms_project_dir' and 'config_path'") # Determine the config file path if config_path is not None: @@ -188,9 +211,7 @@ def load_kpms_dj_config( kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir) if not Path(kpms_dj_cfg_path).exists(): - raise FileNotFoundError( - f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()." - ) + raise FileNotFoundError(f"Missing DJ config at {kpms_dj_cfg_path}") with open(kpms_dj_cfg_path, "r") as f: cfg_dict = yaml.safe_load(f) or {} @@ -202,17 +223,28 @@ def load_kpms_dj_config( anterior = cfg_dict.get("anterior_bodyparts", []) posterior = cfg_dict.get("posterior_bodyparts", []) use_bps = cfg_dict.get("use_bodyparts", []) - cfg_dict["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior]) - cfg_dict["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior]) - if "skeleton" not in cfg_dict or cfg_dict["skeleton"] is None: + valid_anterior = [bp for bp in anterior if bp in use_bps] + valid_posterior = [bp for bp in posterior if bp in use_bps] + + cfg_dict["anterior_idxs"] = jnp.array( + [use_bps.index(bp) for bp in valid_anterior] + ) + cfg_dict["posterior_idxs"] = jnp.array( + [use_bps.index(bp) for bp in valid_posterior] + ) + + if "skeleton" not in cfg_dict: cfg_dict["skeleton"] = [] return cfg_dict def update_kpms_dj_config( - kpms_project_dir: str = None, config_dict: Dict[str, Any] = None, **kwargs + kpms_project_dir: str = None, + config_dict: Dict[str, Any] = None, + config_path: str = None, + **kwargs, ) -> Dict[str, Any]: """ Update kpms_dj_config with provided kwargs. @@ -233,32 +265,57 @@ def update_kpms_dj_config( If kpms_project_dir is provided, loads the config from file, updates it, saves it back, and returns it. If config_dict is provided, updates it directly and returns it (no file I/O). """ - # Validate input parameters + if kpms_project_dir is None and config_dict is None: - raise ValueError("Either 'kpms_project_dir' or 'config_dict' must be provided.") - if kpms_project_dir is not None and config_dict is not None: - raise ValueError( - "Cannot provide both 'kpms_project_dir' and 'config_dict'. Choose one." - ) + raise ValueError("Either 'kpms_project_dir' or 'config_dict' must be provided") - # Load from file if kpms_project_dir is provided if kpms_project_dir is not None: kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir) if not Path(kpms_dj_cfg_path).exists(): - raise FileNotFoundError( - f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()." - ) + raise FileNotFoundError(f"Missing DJ config at {kpms_dj_cfg_path}") with open(kpms_dj_cfg_path, "r") as f: cfg_dict = yaml.safe_load(f) or {} + if "bodyparts" in kwargs: + cfg_dict["bodyparts"] = list(kwargs.get("bodyparts")) + + if "use_bodyparts" in kwargs: + use_bodyparts = list(kwargs.get("use_bodyparts")) + cfg_dict["use_bodyparts"] = use_bodyparts + # NOTE: skeleton is NOT modified - it remains from the base config + cfg_dict.update(kwargs) with open(kpms_dj_cfg_path, "w") as f: - yaml.safe_dump(cfg_dict, f, sort_keys=False) + yaml.safe_dump( + cfg_dict, + f, + sort_keys=False, + default_flow_style=False, + allow_unicode=True, + ) else: - # Update the provided dict directly (no file I/O) - cfg_dict = config_dict.copy() # Make a copy to avoid mutating the input + cfg_dict = config_dict.copy() + + if "bodyparts" in kwargs: + cfg_dict["bodyparts"] = list(kwargs.get("bodyparts")) + + if "use_bodyparts" in kwargs: + use_bodyparts = list(kwargs.get("use_bodyparts")) + cfg_dict["use_bodyparts"] = use_bodyparts + # NOTE: skeleton is NOT modified - it remains from the base config + cfg_dict.update(kwargs) + if config_path is not None: + with open(config_path, "w") as f: + yaml.safe_dump( + cfg_dict, + f, + sort_keys=False, + default_flow_style=False, + allow_unicode=True, + ) + return cfg_dict From 865dcbcef16e6d70327b5a5b675c264d25f9074f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 04:21:55 +0100 Subject: [PATCH 30/41] refactor(moseq_infer) --- element_moseq/moseq_infer.py | 70 ++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 5cd9b19..e7e9f10 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -85,7 +85,7 @@ class Model(dj.Manual): --- model_name : varchar(1000) # User-friendly model name model_dir : varchar(1000) # Model directory relative to root data directory - model_file : filepath@moseq-infer-processed # Model pkl file containing states, parameters, hyperparameters, noise prior, and random seed + model_file : filepath@moseq-infer-processed # Checkpoint file (h5 format) model_desc='' : varchar(1000) # Optional. User-defined description of the model -> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key. """ @@ -225,7 +225,7 @@ def make_fetch(self, key): model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY") checkpoint_file_path = ( moseq_train.FullFit.File & model_key & 'file_name="checkpoint.h5"' - ).fetch1("file") + ).fetch1("file_path") kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( "config_file" ) @@ -318,48 +318,46 @@ def make_compute( # Construct the full path to the inference output directory inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir - inference_output_dir.mkdir(parents=True, exist_ok=True) - model_dir = find_full_path(kpms_processed, model_dir_rel) - keypointset_dir = find_full_path(kpms_root, keypointset_dir) + if task_mode == "trigger": + if not inference_output_dir.exists(): + inference_output_dir.mkdir(parents=True, exist_ok=True) - kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_file - ) + keypointset_dir = find_full_path(kpms_root, keypointset_dir) - metadata = pickle.load(open(metadata_file_path, "rb")) - data = pickle.load(open(data_file_path, "rb")) - model_data = pickle.load(open(model_file, "rb")) if task_mode == "trigger": - results = apply_model( - model_name=inference_output_dir.name, - model=model_data, - data=data, - metadata=metadata, - pca=pca_file_path, - project_dir=inference_output_dir.parent, - results_path=(inference_output_dir / "results.h5"), - return_model=False, - num_iters=num_iterations or DEFAULT_NUM_ITERS, - overwrite=True, - save_results=True, - **kpms_dj_config_dict, + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_file ) - # Create results directory and save CSV files - save_results_as_csv( - results=results, - save_dir=(inference_output_dir / "results_as_csv").as_posix(), - ) + metadata = pickle.load(open(metadata_file_path, "rb")) + data = pickle.load(open(data_file_path, "rb")) + model_data = pickle.load(open(model_file, "rb")) + if task_mode == "trigger": + results = apply_model( + model_name=inference_output_dir.name, + model=model_data, + data=data, + metadata=metadata, + pca=pca_file_path, + project_dir=inference_output_dir.parent, + results_path=(inference_output_dir / "results.h5"), + return_model=False, + num_iters=num_iterations or DEFAULT_NUM_ITERS, + overwrite=True, + save_results=True, + **kpms_dj_config_dict, + ) - end_time = datetime.now(timezone.utc) - duration_seconds = (end_time - start_time).total_seconds() + # Create results directory and save CSV files + save_results_as_csv( + results=results, + save_dir=(inference_output_dir / "results_as_csv").as_posix(), + ) + + end_time = datetime.now(timezone.utc) + duration_seconds = (end_time - start_time).total_seconds() else: - # For load mode - results = load_results( - project_dir=model_dir, - model_name=model_dir.name, - ) duration_seconds = None results_filepath = (inference_output_dir / "results.h5").as_posix() From ba2d3c44f4d0c017c8a57c7df53ae05e2b65c3ad Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 04:53:36 +0100 Subject: [PATCH 31/41] refactor(moseq_train): add new table & minor fixes --- element_moseq/moseq_train.py | 546 ++++++++++++++++++++++++++--------- 1 file changed, 408 insertions(+), 138 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 5d324ef..feb60bc 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -5,6 +5,7 @@ import importlib import inspect +import os import pickle from datetime import datetime, timezone from pathlib import Path @@ -14,6 +15,8 @@ import datajoint as dj import matplotlib.pyplot as plt import numpy as np +import pandas as pd +import yaml from element_interface.utils import find_full_path from .plotting import viz_utils @@ -249,9 +252,8 @@ class PreProcessing(dj.Computed): Attributes: PCATask (foreign key) : Unique ID for each `PCATask` key. - formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - coordinates (longblob) : Cleaned coordinates dictionary {recording_name: array} after outlier removal. - confidences (longblob) : Cleaned confidences dictionary {recording_name: array} after outlier removal. + coordinates (longblob) : Cleaned coordinates dictionary after outlier removal. + confidences (longblob) : Cleaned confidences dictionary after outlier removal. average_frame_rate (int) : Average frame rate of the videos for model training (used for kappa calculation). pre_processing_time (datetime) : datetime of the preprocessing execution. pre_processing_duration (int) : Execution time of the preprocessing in seconds. @@ -260,9 +262,8 @@ class PreProcessing(dj.Computed): definition = """ -> PCATask # Unique ID for each `PCATask` key --- - formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - coordinates : longblob # Cleaned coordinates dictionary (recording_name: array) after outlier removal - confidences : longblob # Cleaned confidences dictionary (recording_name: array) after outlier removal + coordinates : longblob # Cleaned coordinates dictionary after outlier removal. + confidences : longblob # Cleaned confidences dictionary after outlier removal. average_frame_rate : int # Average frame rate of the videos for model training (used for kappa calculation). pre_processing_time=NULL : datetime # datetime of the preprocessing execution. pre_processing_duration=NULL : int # Execution time of the preprocessing in seconds. @@ -278,16 +279,6 @@ class Video(dj.Part): file_size : float # File size of the video in megabytes (MB) """ - class OutlierRemoval(dj.Part): - """Store outlier detection QA plots per video.""" - - definition = """ - -> master - video_id : varchar(255) - --- - outlier_plot: attach # QA visualization showing detected outliers and interpolation. - """ - class ConfigFile(dj.Part): """ Store the configuration files (first creation of the config file and the updates after processing). @@ -319,14 +310,6 @@ def make_fetch(self, key): PCATask & key ).fetch1("kpms_project_output_dir", "task_mode", "outlier_scale_factor") - # Handle default kpms_project_output_dir if not provided - if not kpms_project_output_dir: - kpms_project_output_dir = PCATask.infer_output_dir( - key, relative=True, mkdir=True - ) - # Update the kpms_project_output_dir in the database - PCATask.update1({**key, "kpms_project_output_dir": kpms_project_output_dir}) - return ( anterior_bodyparts, posterior_bodyparts, @@ -384,41 +367,57 @@ def make_compute( 6. Detect and remove outlier keypoints using medoid distance analysis, then interpolate missing values. 7. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. """ + import jax from keypoint_moseq import ( find_medoid_distance_outliers, interpolate_keypoints, load_keypoints, + outlier_removal, + overlay_keypoints_on_video, ) from .plotting.viz_utils import plot_medoid_distance_outliers + jax.config.update("jax_enable_x64", True) + execution_time = datetime.now(timezone.utc) + if task_mode == "trigger": from keypoint_moseq import setup_project - # check if the project output directory exists - try: - kpms_project_output_dir = find_full_path( - get_kpms_processed_data_dir(), kpms_project_output_dir + # Resolve kpms_project_output_dir to absolute and create it if needed + if not kpms_project_output_dir: + kpms_project_output_dir = PCATask.infer_output_dir( + key, relative=True, mkdir=True ) - # if the project output directory does not exist, create it - except FileNotFoundError: - kpms_project_output_dir = ( - Path(get_kpms_processed_data_dir()) / kpms_project_output_dir + PCATask.update1( + {**key, "kpms_project_output_dir": kpms_project_output_dir} ) + kpms_project_output_dir = ( + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir + ) + + # Resolve kpset_dir to absolute and check if it exists kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) - # Setup of the project creates KPMS base `config.yml` file copying the pose estimation config file - from .readers.kpms_reader import _pose_estimation_config_path + if not kpset_dir.exists(): + raise FileNotFoundError( + f"No keypoint set directory found in {kpset_dir}" + ) - pose_estimation_config_file = Path(_pose_estimation_config_path(kpset_dir)) + # Find the pose estimation config file + pose_estimation_config_file = Path( + kpms_reader._pose_estimation_config_path(kpset_dir) + ) if not pose_estimation_config_file.exists(): raise FileNotFoundError( - f"No `config.yml` or `config.yaml` file found in {kpset_dir}" + f"No config file (`config.yml` or `config.yaml`) found in {kpset_dir}" ) + if pose_estimation_method == "deeplabcut": setup_project( project_dir=kpms_project_output_dir.as_posix(), deeplabcut_config=pose_estimation_config_file.as_posix(), + overwrite=True, # Allow regenerating config if directory exists ) else: raise NotImplementedError( @@ -442,7 +441,35 @@ def make_compute( f"use_bodyparts ({use_bodyparts}) is not a subset of formatted bodyparts ({formatted_bodyparts})" ) - # Extract frame rate and file size from keypoint video files + # Use only the bodyparts in `use_bodyparts` + use_bodypart_indices = [formatted_bodyparts.index(bp) for bp in use_bodyparts] + filtered_coordinates = {} + filtered_confidences = {} + for recording_name in raw_coordinates: + filtered_coordinates[recording_name] = raw_coordinates[recording_name][ + :, use_bodypart_indices, : + ] + filtered_confidences[recording_name] = raw_confidences[recording_name][ + :, use_bodypart_indices + ] + + # Sanity check of the number of features (dimensions) in the filtered_coordinates + num_bodyparts = len(use_bodyparts) + num_features = num_bodyparts * 2 # (x, y) coordinates for each bodypart + for recording_name in filtered_coordinates: + coords_shape = filtered_coordinates[ + recording_name + ].shape # Typically (frames, bodyparts, 2) + actual_features = ( + coords_shape[1] * coords_shape[2] + ) # bodyparts * 2 (for x/y) + if actual_features != num_features: + raise ValueError( + f"Feature mismatch in {recording_name}: " + f"expected {num_features}, got {actual_features}" + ) + + # Get frame rates and file sizes for each video video_metadata_dict = dict() frame_rates = [] for row in keypoint_videofile_metadata: @@ -470,6 +497,7 @@ def make_compute( "outlier_plot": None, } average_frame_rate = int(round(np.mean(frame_rates))) + # Get all unique parent directories for all video files parent_dirs = { Path(video["video_path"]).parent for video in keypoint_videofile_metadata @@ -484,7 +512,21 @@ def make_compute( Path(keypoint_videofile_metadata[0]["video_path"]).parent, ) - # Generate a new KPMS DJ config file copying the KPMS base config file in the same kpms project output directory + # Filter anterior/posterior to only include those present in use_bodyparts + filtered_anterior = [bp for bp in anterior_bodyparts if bp in use_bodyparts] + filtered_posterior = [bp for bp in posterior_bodyparts if bp in use_bodyparts] + + # Sanity check: Ensure all filtered anterior/posterior bodyparts are in use_bodyparts + if not set(filtered_anterior).issubset(set(use_bodyparts)): + raise ValueError( + f"Filtered anterior bodyparts contain elements not in use_bodyparts: {set(filtered_anterior) - set(use_bodyparts)}" + ) + if not set(filtered_posterior).issubset(set(use_bodyparts)): + raise ValueError( + f"Filtered posterior bodyparts contain elements not in use_bodyparts: {set(filtered_posterior) - set(use_bodyparts)}" + ) + + # Generate KPMS DJ config file with all new parameters ( kpms_dj_config_path, kpms_dj_config_dict, @@ -494,9 +536,10 @@ def make_compute( kpms_project_dir=kpms_project_output_dir, video_dir=str(videos_dir), use_bodyparts=list(use_bodyparts), - anterior_bodyparts=list(anterior_bodyparts), - posterior_bodyparts=list(posterior_bodyparts), + anterior_bodyparts=filtered_anterior, + posterior_bodyparts=filtered_posterior, outlier_scale_factor=float(outlier_scale_factor), + fps=average_frame_rate, # Pass fps directly to avoid redundant update ) # Get absolute paths for attach fields @@ -507,52 +550,29 @@ def make_compute( get_kpms_processed_data_dir(), kpms_base_config_path ) - # Update the KPMS DJ config file on disk with the average frame rate - kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( - kpms_project_dir=kpms_project_output_dir, fps=average_frame_rate - ) - - # Clean outlier keypoints and generate plots - cleaned_coordinates = {} - cleaned_confidences = {} + logger.info("Starting outlier removal...") - for row in keypoint_videofile_metadata: - video_id = int(row["video_id"]) - pose_estimation_path = row["pose_estimation_path"] - pose_estimation_name = Path(pose_estimation_path).stem - raw_coords = raw_coordinates[pose_estimation_name].copy() - raw_conf = raw_confidences[pose_estimation_name].copy() - outliers = find_medoid_distance_outliers( - raw_coords, outlier_scale_factor=outlier_scale_factor - ) - cleaned_coords = interpolate_keypoints(raw_coords, outliers["mask"]) - cleaned_conf = np.where(outliers["mask"], 0, raw_conf) - cleaned_coordinates[pose_estimation_name] = cleaned_coords - cleaned_confidences[pose_estimation_name] = cleaned_conf - _, outlier_plot_path = plot_medoid_distance_outliers( - project_dir=kpms_project_output_dir.as_posix(), - recording_name=pose_estimation_name, - original_coordinates=raw_coords, - interpolated_coordinates=cleaned_coords, - outlier_mask=outliers["mask"], - outlier_thresholds=outliers["thresholds"], - **kpms_dj_config_dict, - ) # outlier plot stored at kpms_project_output_dir/QA/plots/keypoint_distance_outliers/f"{pose_estimation_name}.png - video_metadata_dict[video_id] = { - **video_metadata_dict[video_id], - "outlier_plot_path": outlier_plot_path, - } + # Apply outlier removal to all recordings at once + cleaned_coordinates, cleaned_confidences = outlier_removal( + coordinates=filtered_coordinates, + confidences=filtered_confidences, + project_dir=kpms_project_output_dir.as_posix(), + overwrite=False, + outlier_scale_factor=outlier_scale_factor, + bodyparts=list(use_bodyparts), + ) + logger.info("...Outlier removal completed") completion_time = datetime.now(timezone.utc) - if task_mode == "trigger": - duration_seconds = (completion_time - execution_time).total_seconds() - else: - duration_seconds = None + duration_seconds = ( + (completion_time - execution_time).total_seconds() + if task_mode == "trigger" + else None + ) return ( cleaned_coordinates, cleaned_confidences, - formatted_bodyparts, average_frame_rate, video_metadata_dict, kpms_dj_config_path, @@ -566,7 +586,6 @@ def make_insert( key, cleaned_coordinates, cleaned_confidences, - formatted_bodyparts, average_frame_rate, video_metadata_dict, kpms_dj_config_path, @@ -582,7 +601,6 @@ def make_insert( self.insert1( { **key, - "formatted_bodyparts": formatted_bodyparts, "coordinates": cleaned_coordinates, "confidences": cleaned_confidences, "average_frame_rate": average_frame_rate, @@ -606,18 +624,6 @@ def make_insert( ] ) - # Insert outlier removal QA plots - if video_metadata_dict: - self.OutlierRemoval.insert( - [ - { - **key, - "video_id": vid, - "outlier_plot": meta["outlier_plot_path"], - } - for vid, meta in video_metadata_dict.items() - ] - ) # Insert configuration files self.ConfigFile.insert1( { @@ -628,6 +634,260 @@ def make_insert( ) +@schema +class PreProcessingQA(dj.Computed): + """ + Check if any bodyparts have a high proportion of NaNs and generate and store QA materials (outlier removal plots and overlay videos). + Attributes: + PreProcessing (foreign key) : `PreProcessing` Key. + nan_df (longblob) : DataFrame containing NaN proportion breakdown by bodypart. + """ + + definition = """ + -> PreProcessing # `PreProcessing` Key + --- + nan_breakdown : attach # HTML table containing NaN proportion breakdown by bodypart + qa_duration : float # Duration of QA in seconds + """ + + class VideoQA(dj.Part): + """Store QA materials: outlier plots and overlay videos.""" + + definition = """ + -> master + video_id : varchar(255) + --- + outlier_plot : attach # QA visualization showing detected outliers and interpolation. + overlay_video : attach # Overlay keypoints on the video attachment. + """ + + def make_fetch(self, key): + """ + Fetch required data for QA processing from database tables. + """ + use_bodyparts = (BodyParts & key).fetch1("use_bodyparts") + coordinates = (PreProcessing & key).fetch1("coordinates") + kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir") + kpms_project_output_dir = find_full_path( + get_kpms_processed_data_dir(), kpms_project_output_dir + ) + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + keypoint_videofile_metadata = (KeypointSet.VideoFile & key).fetch(as_dict=True) + fps_lookup = (PreProcessing.Video & key).fetch( + "video_id", "frame_rate", as_dict=True + ) + return ( + use_bodyparts, + coordinates, + kpms_project_output_dir, + kpms_dj_config_path, + keypoint_videofile_metadata, + fps_lookup, + ) + + def make_compute( + self, + key, + use_bodyparts, + coordinates, + kpms_project_output_dir, + kpms_dj_config_path, + keypoint_videofile_metadata, + fps_lookup, + ): + """ + Compute QA materials including NaN analysis and overlay video generation. + + Args: + key (dict): Primary key from the `PreProcessing` table. + coordinates (dict): Cleaned coordinates dictionary. + kpms_project_output_dir (Path): Project output directory path. + kpms_dj_config_dict (dict): KPMS configuration dictionary. + keypoint_videofile_metadata (list): Video metadata list. + + Returns: + tuple: QA data including breakdown data, QA materials, and duration. + """ + from keypoint_moseq import overlay_keypoints_on_video + + fps_lookup = {v["video_id"]: float(v["frame_rate"]) for v in fps_lookup} + + kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_path, + build_indexes=True, + ) + + execution_time = datetime.now(timezone.utc) + + # Calculate NaN proportions breakdown for each recording and bodypart + # Replicating keypoint-moseq's check_nan_proportions logic + keys = sorted(coordinates.keys()) + nan_props = [np.isnan(coordinates[k]).any(-1).mean(0) for k in keys] + + # Create the DataFrame for HTML table + nan_df = pd.DataFrame(data=nan_props, index=keys, columns=use_bodyparts) + nan_styler = ( + nan_df.style.background_gradient(cmap="RdYlBu_r", axis=None) + .format("{:.1%}") + .set_caption( + '

NaN Proportion Breakdown

' + ) + .set_table_styles( + [ + { + "selector": "th", + "props": [ + ("font-size", "11pt"), + ("background-color", "#F2F2F2"), + ("color", "#222"), + ("font-weight", "bold"), + ("padding", "8px"), + ], + }, + { + "selector": "td", + "props": [ + ("font-size", "10pt"), + ("padding", "6px 12px"), + ("text-align", "center"), + ], + }, + { + "selector": "caption", + "props": [ + ("caption-side", "top"), + ("font-size", "14pt"), + ("color", "#29487d"), + ("font-weight", "bold"), + ("margin-bottom", "12px"), + ], + }, + { + "selector": "", + "props": [ + ("border-collapse", "collapse"), + ("margin", "25px auto"), + ], + }, + ] + ) + .set_properties(**{"border": "1px solid #ddd"}) + ) + + # Save HTML to temporary file for DataJoint attach + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: + f.write(nan_styler.to_html()) + nan_html_path = f.name + + # Generate QA materials (plots and videos) for each recording + qa_data = [] + for row in keypoint_videofile_metadata: + video_id = int(row["video_id"]) + pose_estimation_path = row["pose_estimation_path"] + pose_estimation_name = Path(pose_estimation_path).stem + + # Construct the expected plot path (already generated by outlier_removal) + qa_dir = ( + kpms_project_output_dir / "QA" / "plots" / "keypoint_distance_outliers" + ) + outlier_plot_path = qa_dir / f"{pose_estimation_name}.png" + + # Create overlay video output path: project_dir/QA/videos/overlay_keypoints/ + overlay_video_dir = ( + kpms_project_output_dir / "QA" / "videos" / "overlay_keypoints" + ) + overlay_video_dir.mkdir(parents=True, exist_ok=True) + overlay_video_path = ( + overlay_video_dir / f"{pose_estimation_name}_overlay.mp4" + ) + + # Get the full path to the video file + video_file_path = find_full_path( + get_kpms_root_data_dir(), row["video_path"] + ) + + # Generate overlay video for this specific recording (skip if already exists) + if not overlay_video_path.exists(): + # Calculate frames for 1 minute of video + frame_rate = fps_lookup.get( + video_id, 30.0 + ) # TODO: Default to 30fps if not found + frames_for_dur = int(frame_rate * 6) + + logger.info( + f"Processing video {video_id}: {frame_rate}fps -> {frames_for_dur} frames for 1min" + ) + + overlay_keypoints_on_video( + video_path=video_file_path.as_posix(), + coordinates=coordinates[ + pose_estimation_name + ], # Pass coordinates for this specific recording + skeleton=kpms_dj_config_dict["skeleton"], + bodyparts=list(use_bodyparts), + output_path=overlay_video_path.as_posix(), + frames=range(frames_for_dur), + ) + logger.info(f"Generated overlay video: {overlay_video_path}") + else: + logger.info( + f"Overlay video already exists, skipping: {overlay_video_path}" + ) + + qa_data.append( + { + "video_id": video_id, + "outlier_plot_path": outlier_plot_path, + "overlay_video_path": overlay_video_path, + } + ) + + completion_time = datetime.now(timezone.utc) + duration_seconds = (completion_time - execution_time).total_seconds() + + return ( + nan_html_path, + qa_data, + duration_seconds, + ) + + def make_insert( + self, + key, + nan_html_path, + qa_data, + duration_seconds, + ): + """ + Insert QA data into the PreProcessingQA table and part tables. + """ + # Insert in the main table + self.insert1( + { + **key, + "nan_breakdown": nan_html_path, + "qa_duration": duration_seconds, + } + ) + + # Insert QA materials (plots and videos) + if qa_data: + self.VideoQA.insert( + [ + { + **key, + "video_id": data["video_id"], + "outlier_plot": data["outlier_plot_path"], + "overlay_video": data["overlay_video_path"], + } + for data in qa_data + ] + ) + + @schema class PCAFit(dj.Computed): """Fit Principal Component Analysis (PCA) model for dimensionality reduction of keypoint data. @@ -670,9 +930,10 @@ def make(self, key): 3. Fit PCA model and save as `pca.p` file. 4. Insert creation datetime into table. """ - import tempfile + import jax + from keypoint_moseq import fit_pca, format_data, load_pca, save_pca - from keypoint_moseq import fit_pca, format_data, save_pca + jax.config.update("jax_enable_x64", True) kpms_project_output_dir, task_mode = (PCATask & key).fetch1( "kpms_project_output_dir", "task_mode" @@ -680,28 +941,30 @@ def make(self, key): kpms_project_output_dir = ( Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) + use_bodyparts = (BodyParts & key).fetch1("use_bodyparts") coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - # Load the configuration from database + # Load config file kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_path + config_path=kpms_dj_config_path, + build_indexes=True, ) execution_time = datetime.now(timezone.utc) # Format keypoint data data, metadata = format_data( - **kpms_dj_config_dict, coordinates=coordinates, confidences=confidences + coordinates=coordinates, + confidences=confidences, + use_bodyparts=use_bodyparts, ) - # Save data and metadata as pickle files - data_filename = "data.pkl" - metadata_filename = "metadata.pkl" - data_path = kpms_project_output_dir / data_filename - metadata_path = kpms_project_output_dir / metadata_filename + # Save data and metadata to pickle files + data_path = kpms_project_output_dir / "data.pkl" + metadata_path = kpms_project_output_dir / "metadata.pkl" with open(data_path, "wb") as f: pickle.dump(data, f) with open(metadata_path, "wb") as f: @@ -712,10 +975,8 @@ def make(self, key): pca = fit_pca(**data, **kpms_dj_config_dict) save_pca(pca, kpms_project_output_dir.as_posix()) - pca_filename = "pca.p" - pca_path = kpms_project_output_dir / pca_filename - - # Check for pca.p file + # Check if pca file exists + pca_path = kpms_project_output_dir / "pca.p" if not pca_path.exists(): raise FileNotFoundError( f"No pca file (`pca.p`) found in the project directory {kpms_project_output_dir}" @@ -723,13 +984,12 @@ def make(self, key): file_paths = [pca_path, data_path, metadata_path] completion_time = datetime.now(timezone.utc) + duration_seconds = ( + (completion_time - execution_time).total_seconds() + if task_mode == "trigger" + else None + ) - if task_mode == "trigger": - duration_seconds = (completion_time - execution_time).total_seconds() - else: - duration_seconds = None - - # Insert in the main table self.insert1( { **key, @@ -828,10 +1088,10 @@ def make(self, key): variance_percentage = VARIANCE_THRESHOLD * 100 latent_dim_desc = f">={VARIANCE_THRESHOLD*100}% of variance explained by {(cs>VARIANCE_THRESHOLD).nonzero()[0].min()+1} components." - # Load the configuration from database + # Load config file kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_path + config_path=kpms_dj_config_path, build_indexes=False ) # Generate scree plot @@ -860,13 +1120,11 @@ def make(self, key): f"No pcs xy file (`pcs-xy.pdf`) found in the project directory {kpms_project_output_dir}" ) - # Save plots to temporary directory + # Save plots tmpdir = tempfile.TemporaryDirectory() fname = f"{key['kpset_id']}_{key['bodyparts_id']}" - scree_path = Path(tmpdir.name) / f"{fname}_scree_plot.png" scree_fig.savefig(scree_path) - pcs_path = Path(tmpdir.name) / f"{fname}_pcs_plot.png" pcs_fig.savefig(pcs_path) @@ -913,9 +1171,9 @@ class PreFitTask(dj.Manual): pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). --- - model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Optional. Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('trigger','load') # 'load': load computed analysis results, 'trigger': trigger computation - pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task + pre_fit_desc='' : varchar(1000) # Optional.User-defined description of the pre-fitting task """ @@ -957,14 +1215,12 @@ class File(dj.Part): -> master file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl'). --- - file : filepath@moseq-train-processed # Path to the file in the processed data directory. + file_path : filepath@moseq-train-processed # Path to the file in the processed data directory. """ class Plots(dj.Part): """ - Store the fitting progress of the PreFit computation: - - Plots in PDF and PNG formats used for visualization. - - Checkpoint file used for resuming the fitting process (~500MB). + Store the fitting progress of the PreFit computation. """ definition = """ @@ -1003,7 +1259,6 @@ def make(self, key): get_kpms_processed_data_dir(), (PCATask & key).fetch1("kpms_project_output_dir"), ) - pre_latent_dim, pre_kappa, pre_num_iterations, task_mode, model_name = ( PreFitTask & key ).fetch1( @@ -1013,10 +1268,17 @@ def make(self, key): "task_mode", "model_name", ) + if task_mode == "trigger": + # Configure JAX precision + import jax + + jax.config.update("jax_enable_x64", True) from keypoint_moseq import estimate_sigmasq_loc - kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") + kpms_dj_config_abs_path = (PreProcessing.ConfigFile & key).fetch1( + "config_file" + ) pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path") pca = load_pca(Path(pca_path).parent.as_posix()) coordinates, confidences = (PreProcessing & key).fetch1( @@ -1031,17 +1293,25 @@ def make(self, key): average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate") # Update kpms_dj_config file in disk with new latent_dim and kappa values - kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_path + kpms_dj_config_dict_for_save = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_abs_path, build_indexes=False ) - kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( - config_dict=kpms_dj_config_dict, - latent_dim=pre_latent_dim, - kappa=pre_kappa, - sigmasq_loc=estimate_sigmasq_loc( - data["Y"], data["mask"], filter_size=average_frame_rate + # Update and save config to disk + kpms_dj_config_dict_for_save = kpms_reader.update_kpms_dj_config( + config_dict=kpms_dj_config_dict_for_save, + config_path=kpms_dj_config_abs_path, + latent_dim=int(pre_latent_dim), + kappa=float(pre_kappa), + sigmasq_loc=float( + estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=average_frame_rate + ) ), ) + # Load config with indexes for model initialization + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_abs_path, build_indexes=True + ) # Initialize the model model = init_model( @@ -1134,7 +1404,7 @@ def make(self, key): { **key, "file_name": file.name, - "file": file.as_posix(), + "file_path": file.as_posix(), } for file in file_paths ] @@ -1176,9 +1446,9 @@ class FullFitTask(dj.Manual): full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). --- - model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Optional. Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task - full_fit_desc='' : varchar(1000) # User-defined description of the model full fitting task + full_fit_desc='' : varchar(1000) # Optional.User-defined description of the model full fitting task """ @@ -1195,7 +1465,7 @@ class FullFit(dj.Computed): definition = """ -> FullFitTask # `FullFitTask` Key --- - model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name". + model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name". full_fit_time=NULL : datetime # datetime of the full fitting computation. full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation. """ From 8f0df905946a55feb4d062e01dd00046021c6875 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 04:57:03 +0100 Subject: [PATCH 32/41] refactor(moseq_train) --- element_moseq/moseq_train.py | 188 +++++++++++++++++++++++------------ 1 file changed, 126 insertions(+), 62 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index feb60bc..7ec744e 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -19,6 +19,11 @@ import yaml from element_interface.utils import find_full_path +# Configure JAX for better compatibility with DataJoint/DeepHash +os.environ["JAX_ENABLE_X64"] = "False" +os.environ["JAX_ARRAY"] = "False" # Use legacy array API for better compatibility +os.environ["JAX_DYNAMIC_SHAPES"] = "False" + from .plotting import viz_utils from .readers import kpms_reader @@ -1490,7 +1495,7 @@ class File(dj.Part): -> master file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl'). --- - file : filepath@moseq-train-processed # Path to the file in the processed data directory. + file_path : filepath@moseq-train-processed # Path to the file in the processed data directory. """ class Plots(dj.Part): @@ -1523,6 +1528,10 @@ def make(self, key): 5. Reindex syllable labels by frequency. 6. Calculate fitting duration and insert results. """ + import jax + + jax.config.update("jax_enable_x64", True) + from keypoint_moseq import ( estimate_sigmasq_loc, fit_model, @@ -1548,6 +1557,10 @@ def make(self, key): ) if task_mode == "trigger": + import pickle + + from keypoint_moseq import load_checkpoint + pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path") pca = load_pca(Path(pca_path).parent.as_posix()) coordinates, confidences = (PreProcessing & key).fetch1( @@ -1561,73 +1574,131 @@ def make(self, key): metadata = pickle.load(open(metadata_path, "rb")) average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate") - kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file") - kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_path + kpms_dj_config_abs_path = (PreProcessing.ConfigFile & key).fetch1( + "config_file" ) - # Update kpms_dj_config file in disk with new latent_dim and kappa values - kpms_dj_config_dict = kpms_reader.update_kpms_dj_config( - config_dict=kpms_dj_config_dict, - latent_dim=full_latent_dim, - kappa=full_kappa, - sigmasq_loc=estimate_sigmasq_loc( + + kpms_dj_config_dict_for_save = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_abs_path, build_indexes=False + ) + sigmasq_loc_val = float( + estimate_sigmasq_loc( data["Y"], data["mask"], filter_size=average_frame_rate - ), + ) ) - - # Initialize the model - model = init_model( - data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict + kpms_dj_config_dict_for_save = kpms_reader.update_kpms_dj_config( + config_dict=kpms_dj_config_dict_for_save, + config_path=kpms_dj_config_abs_path, + latent_dim=int(full_latent_dim), + kappa=float(full_kappa), + sigmasq_loc=sigmasq_loc_val, ) - # Update the model hyperparameters - model = update_hypparams( - model, - kappa=float(full_kappa.item()), - latent_dim=int(full_latent_dim.item()), + + kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=kpms_dj_config_abs_path, build_indexes=True ) + # Determine model directory name for outputs if model_name is None or not str(model_name).strip(): model_name = f"latent_dim_{full_latent_dim.item()}_kappa_{full_kappa.item()}_iters_{full_num_iterations.item()}" else: model_name = str(model_name) + # Try to load pre-fit model for the same latent_dim and kappa values + pre_model = None + # More optimal: check existence before fetching to avoid try/except + pre_model_key_query = ( + PreFitTask + & {"kpset_id": key["kpset_id"], "bodyparts_id": key["bodyparts_id"]} + & { + "pre_kappa": key["full_kappa"], + "pre_latent_dim": key["full_latent_dim"], + } + ) + if pre_model_key_query: + pre_model_key = pre_model_key_query.fetch1("KEY") + pre_model_file = ( + PreFit.File & pre_model_key & 'file_name="checkpoint.h5"' + ).fetch1("file_path") + with open(pre_model_file, "rb") as f: + pre_model = pickle.load(f) + logger.info( + f"Using PreFit model {pre_model_key_query} as warm start for FullFit" + ) + execution_time = datetime.now(timezone.utc) + + # Initialize model: Use PreFit if available, otherwise initialize fresh + try: + if pre_model is not None: + model_to_fit = pre_model + else: + # Only initialize fresh model if no PreFit available + model_to_fit = init_model( + data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict + ) + # Update the model hyperparameters + model_to_fit = update_hypparams( + model_to_fit, + kappa=float(full_kappa.item()), + latent_dim=int(full_latent_dim.item()), + ) + except Exception as e: + raise ValueError(f"Model initialization failed: {e}") + # Fit the model - model, model_name = fit_model( - model=model, - model_name=model_name, - data=data, - metadata=metadata, - project_dir=kpms_project_output_dir.as_posix(), - ar_only=False, - num_iters=full_num_iterations, - generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ - save_every_n_iters=10, - ) + try: - # Reindex the syllables in the checkpoint file - reindex_syllables_in_checkpoint( - project_dir=kpms_project_output_dir.as_posix(), - model_name=model_name, - ) + model, model_name = fit_model( + model=model_to_fit, + model_name=model_name, + data=data, + metadata=metadata, + project_dir=kpms_project_output_dir.as_posix(), + ar_only=False, + num_iters=full_num_iterations, + generate_progress_plots=True, + save_every_n_iters=1, # TODO: to change to a higher value + verbose=False, + ) # checkpoint will be saved at project_dir/model_name + except Exception as e: + raise ValueError(f"FullFit training failed: {e}") + + try: + # Reindex the syllables in the checkpoint file + reindex_syllables_in_checkpoint( + project_dir=kpms_project_output_dir.as_posix(), + model_name=model_name, + ) + except Exception as e: + raise ValueError( + f"Reindexing syllables failed due to FullFit training failure: {e}" + ) # Create a PNG version fo the PDF progress plot - png_path, pdf_path = viz_utils.copy_pdf_to_png( - kpms_project_output_dir, model_name - ) - # Define model_name_full_path for checkpoint file search model_name_full_path = find_full_path(kpms_project_output_dir, model_name) + pdf_path = model_name_full_path / "fitting_progress.pdf" + png_path = model_name_full_path / "fitting_progress.png" + + if pdf_path.exists(): + png_path, pdf_path = viz_utils.copy_pdf_to_png( + kpms_project_output_dir, model_name + ) + else: + logger.warning(f"No progress PDF found at {pdf_path}") else: # Load mode must specify a model_name if model_name is None or not str(model_name).strip(): - raise ValueError("model_name is required when task_mode='load'") + raise ValueError("`model_name` is required when task_mode='load'") + model_name_full_path = find_full_path(kpms_project_output_dir, model_name) pdf_path = model_name_full_path / "fitting_progress.pdf" png_path = model_name_full_path / "fitting_progress.png" # Get the path to the updated config file - kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir) - + kpms_dj_config_abs_path = kpms_reader._kpms_dj_config_path( + kpms_project_output_dir + ) if not pdf_path.exists(): raise FileNotFoundError(f"PreFit PDF progress plot not found at {pdf_path}") if not png_path.exists(): @@ -1644,14 +1715,7 @@ def make(self, key): f"No checkpoint files found in {model_name_full_path}" ) - completion_time = datetime.now(timezone.utc) - - if task_mode == "trigger": - duration_seconds = (completion_time - execution_time).total_seconds() - else: - duration_seconds = None - - # Save model dictionary as pickle file + # Save model dictionary as pickle file in the model directory model_data_filename = "model_data.pkl" model_data_file = model_name_full_path / model_data_filename with open(model_data_file, "wb") as f: @@ -1659,6 +1723,13 @@ def make(self, key): file_paths = [checkpoint_file, model_data_file] + completion_time = datetime.now(timezone.utc) + duration_seconds = ( + (completion_time - execution_time).total_seconds() + if task_mode == "trigger" + else None + ) + self.insert1( { **key, @@ -1676,21 +1747,19 @@ def make(self, key): { **key, "file_name": file.name, - "file": file.as_posix(), + "file_path": file.as_posix(), } for file in file_paths ] ) - # Insert config file self.ConfigFile.insert1( { **key, - "config_file": kpms_dj_config_path, + "config_file": kpms_dj_config_abs_path, } ) - # Insert plots self.Plots.insert1( { **key, @@ -1711,7 +1780,6 @@ class ModelScore(dj.Computed): -> FullFit --- score=NULL : float # Model score (MLL for single model) - std_error=NULL : float # Standard error of the model score """ def make(self, key): @@ -1722,28 +1790,24 @@ def make(self, key): # Get checkpoint file for this specific model checkpoint_file = ( FullFit.File & key & 'file_name LIKE "%checkpoint.h5"' - ).fetch1("file") + ).fetch1("file_path") # Load the checkpoint to get model data model, data, _, _ = load_checkpoint(path=checkpoint_file) - # Compute marginal log likelihood for this model + # Compute marginal log likelihood for single model mask = jnp.array(data["mask"]) x = jnp.array(model["states"]["x"]) Ab = jnp.array(model["params"]["Ab"]) Q = jnp.array(model["params"]["Q"]) pi = jnp.array(model["params"]["pi"]) - - # Compute marginal log likelihood - this is the correct metric for single models mll = marginal_log_likelihood(mask, x, Ab, Q, pi) score = float(mll) # Store as "score" - this is MLL - std_error = 0.0 # No standard error for single model MLL self.insert1( { **key, "score": score, - "std_error": std_error, } ) From bbf691f4e8cda092ecb7640d66930cc18c65cad3 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 05:06:04 +0100 Subject: [PATCH 33/41] black --- element_moseq/moseq_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 7ec744e..0e83862 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1335,6 +1335,7 @@ def make(self, key): model_name = str(model_name) execution_time = datetime.now(timezone.utc) + # Fit the model model, _ = fit_model( model=model, @@ -1648,7 +1649,6 @@ def make(self, key): # Fit the model try: - model, model_name = fit_model( model=model_to_fit, model_name=model_name, @@ -1660,7 +1660,7 @@ def make(self, key): generate_progress_plots=True, save_every_n_iters=1, # TODO: to change to a higher value verbose=False, - ) # checkpoint will be saved at project_dir/model_name + ) except Exception as e: raise ValueError(f"FullFit training failed: {e}") From 1510047a43935d7831fc96174a3525536efd3ae1 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 06:03:54 +0100 Subject: [PATCH 34/41] fix(inference) --- element_moseq/moseq_infer.py | 137 ++++++++++++++--------------------- 1 file changed, 56 insertions(+), 81 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index e7e9f10..bf83968 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -223,25 +223,16 @@ def make_fetch(self, key): ) # model dir relative to processed data directory model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY") - checkpoint_file_path = ( + fullfit_checkpoint_path = ( moseq_train.FullFit.File & model_key & 'file_name="checkpoint.h5"' ).fetch1("file_path") - kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( - "config_file" - ) - pca_file_path = ( + fullfit_kpms_dj_config_file = ( + moseq_train.FullFit.ConfigFile & model_key + ).fetch1("config_file") + fullfit_pca_file_path = ( moseq_train.PCAFit.File & model_key & 'file_name="pca.p"' ).fetch1("file_path") - data_file_path = ( - moseq_train.PCAFit.File & model_key & 'file_name="data.pkl"' - ).fetch1("file_path") - metadata_file_path = ( - moseq_train.PCAFit.File & model_key & 'file_name="metadata.pkl"' - ).fetch1("file_path") - coordinates, confidences = (moseq_train.PreProcessing & model_key).fetch( - "coordinates", "confidences" - ) return ( keypointset_dir, inference_output_dir, @@ -249,13 +240,9 @@ def make_fetch(self, key): task_mode, model_dir_rel, model_file, - checkpoint_file_path, - kpms_dj_config_file, - pca_file_path, - data_file_path, - metadata_file_path, - coordinates, - confidences, + fullfit_checkpoint_path, + fullfit_kpms_dj_config_file, + fullfit_pca_file_path, ) def make_compute( @@ -267,34 +254,12 @@ def make_compute( task_mode, model_dir_rel, model_file, - checkpoint_file_path, - kpms_dj_config_file, - pca_file_path, - data_file_path, - metadata_file_path, - coordinates, - confidences, + fullfit_checkpoint_path, + fullfit_kpms_dj_config_file, + fullfit_pca_file_path, ): """ Compute model inference results. - - Args: - key (dict): `InferenceTask` primary key. - keypointset_dir (str): Directory containing keypoint data. - inference_output_dir (str): Output directory for inference results. - num_iterations (int): Number of iterations for model fitting. - model_id (int): Model ID. - pose_estimation_method (str): Pose estimation method. - task_mode (str): Task mode ('trigger' or 'load'). - - Raises: - FileNotFoundError: If no pca model (`pca.p`) found in the parent model directory. - FileNotFoundError: If no model (`checkpoint.h5`) found in the model directory. - NotImplementedError: If the format method is not `deeplabcut`. - FileNotFoundError: If no valid `kpms_dj_config` found in the parent model directory. - - Returns: - tuple: Inference results including duration, results data, and sampled instances. """ from keypoint_moseq import ( apply_model, @@ -303,6 +268,7 @@ def make_compute( load_keypoints, load_pca, load_results, + outlier_removal, save_results_as_csv, ) @@ -311,51 +277,60 @@ def make_compute( start_time = datetime.now(timezone.utc) - # Get directories first + # Get directories for new recordings kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() # Construct the full path to the inference output directory inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir - - if task_mode == "trigger": - if not inference_output_dir.exists(): - inference_output_dir.mkdir(parents=True, exist_ok=True) - keypointset_dir = find_full_path(kpms_root, keypointset_dir) if task_mode == "trigger": - kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_file + # load saved model data + fullfit_kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( + config_path=fullfit_kpms_dj_config_file ) + fullfit_model, _, _, _ = load_checkpoint(path=fullfit_checkpoint_path) + # fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb")) - metadata = pickle.load(open(metadata_file_path, "rb")) - data = pickle.load(open(data_file_path, "rb")) - model_data = pickle.load(open(model_file, "rb")) - if task_mode == "trigger": - results = apply_model( - model_name=inference_output_dir.name, - model=model_data, - data=data, - metadata=metadata, - pca=pca_file_path, - project_dir=inference_output_dir.parent, - results_path=(inference_output_dir / "results.h5"), - return_model=False, - num_iters=num_iterations or DEFAULT_NUM_ITERS, - overwrite=True, - save_results=True, - **kpms_dj_config_dict, - ) + # Load new data + coordinates, confidences, bodyparts = load_keypoints( + filepath_pattern=keypointset_dir, format="deeplabcut" + ) + coordinates, confidences = outlier_removal( + coordinates, + confidences, + inference_output_dir, + overwrite=False, + **fullfit_kpms_dj_config_dict, + ) + data, metadata = format_data( + coordinates, confidences, **fullfit_kpms_dj_config_dict + ) - # Create results directory and save CSV files - save_results_as_csv( - results=results, - save_dir=(inference_output_dir / "results_as_csv").as_posix(), - ) + # # apply saved model to new data + results = apply_model( + model=fullfit_model, + data=data, + metadata=metadata, + project_dir=inference_output_dir, + model_name=inference_output_dir.name, + results_path=(inference_output_dir / "results.h5"), + return_model=False, + num_iters=num_iterations or DEFAULT_NUM_ITERS, + overwrite=True, + save_results=True, + **fullfit_kpms_dj_config_dict, + ) + + # Create results directory and save CSV files + save_results_as_csv( + results=results, + save_dir=(inference_output_dir / "results_as_csv").as_posix(), + ) - end_time = datetime.now(timezone.utc) - duration_seconds = (end_time - start_time).total_seconds() + end_time = datetime.now(timezone.utc) + duration_seconds = (end_time - start_time).total_seconds() else: duration_seconds = None @@ -406,7 +381,7 @@ class VideoSequence(dj.Part): latent_states : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model centroids : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model headings : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model - file_csv : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format) + file : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format) """ class SampledInstance(dj.Part): @@ -509,7 +484,7 @@ def make(self, key): "latent_states": latent_states[vid], "centroids": filtered_centroids[vid], "headings": filtered_headings[vid], - "file_csv": ( + "file": ( inference_output_dir / "results_as_csv" / f"{vid}.csv" ).as_posix(), } From 2208df5568e29c6994f59365d43f9e5280c10e1e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 06:31:16 +0100 Subject: [PATCH 35/41] refactor(moseq_report) --- element_moseq/moseq_report.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py index 3d9c8b1..7ba4515 100644 --- a/element_moseq/moseq_report.py +++ b/element_moseq/moseq_report.py @@ -102,13 +102,15 @@ def make(self, key): "KEY" ) coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates") - config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1("config_file") - kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(config_path=config_file) + use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts") + fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate") + plot_similarity_dendrogram( coordinates=coordinates, results=results, save_path=(inference_output_dir / "similarity_dendrogram").as_posix(), - **kpms_dj_config_dict, + use_bodyparts=use_bodyparts, + fps=fps, ) # Insert the record @@ -163,7 +165,7 @@ def make(self, key): model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( "KEY" ) - coordinates_dict = (moseq_train.PreProcessing & model_key).fetch1("coordinates") + coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates") kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( "config_file" ) @@ -171,6 +173,11 @@ def make(self, key): config_path=kpms_dj_config_file ) + # Get use_bodyparts from the BodyParts table + use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts") + + fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate") + # Construct output directory kpms_processed = moseq_train.get_kpms_processed_data_dir() output_dir = Path(model_dir) / inference_output_dir @@ -185,20 +192,27 @@ def make(self, key): # Load results results = h5py.File(results_file, "r") + logger.info(f"Generating trajectory plots for {key}") # Generate trajectory plots generate_trajectory_plots( - coordinates=coordinates_dict, + coordinates=coordinates, results=results, output_dir=trajectory_dir.as_posix(), - **kpms_dj_config_dict, + use_bodyparts=use_bodyparts, + fps=fps, + skeleton=kpms_dj_config_dict.get("skeleton", []), ) + logger.info(f"Generating grid movies for {key}") # Generate grid movies generate_grid_movies( - coordinates=coordinates_dict, + coordinates=coordinates, results=results, output_dir=grid_movies_dir.as_posix(), - **kpms_dj_config_dict, + use_bodyparts=use_bodyparts, + fps=fps, + overlay_keypoints=True, + skeleton=kpms_dj_config_dict.get("skeleton", []), ) # Calculate duration From 7c627e2be7922e76f45b05911e59c969e589a80e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 06:55:41 +0100 Subject: [PATCH 36/41] update Inference table definition --- element_moseq/moseq_infer.py | 50 +++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index bf83968..ba97f84 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -119,9 +119,9 @@ class File(dj.Part): definition = """ -> master - file_id: int # Unique ID for each file + file_id: int # Unique ID for each file --- - file_path: varchar(1000) # Filepath of each video, relative to root data directory. + file_path: filepath@moseq-infer-processed # Filepath of each video, relative to root data directory. """ @@ -196,8 +196,11 @@ class Inference(dj.Computed): definition = """ -> InferenceTask # `InferenceTask` key --- - syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings - inference_duration=NULL : float # Time duration (seconds) of the inference computation + coordinates : longblob # Cleaned coordinates dictionary after outlier removal. + confidences : longblob # Cleaned confidences dictionary after outlier removal. + average_frame_rate : int # Average frame rate of the videos for model training (used for kappa calculation). + syllable_segmentation_file : filepath@moseq-infer-processed # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings + inference_duration=NULL : float # Time duration (seconds) of the inference computation """ def make_fetch(self, key): @@ -261,6 +264,10 @@ def make_compute( """ Compute model inference results. """ + import glob + import os + + import cv2 from keypoint_moseq import ( apply_model, format_data, @@ -290,8 +297,30 @@ def make_compute( fullfit_kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( config_path=fullfit_kpms_dj_config_file ) + + # calculate average frame rate of all video in keypointset_dir + # Search for multiple common video extensions: mp4, avi, mov, mkv, etc. + video_extensions = [ + "*.mp4", + "*.avi", + "*.mov", + "*.mkv", + "*.wmv", + "*.mpeg", + "*.mpg", + ] + video_files = [] + for ext in video_extensions: + video_files.extend(glob.glob(os.path.join(keypointset_dir, ext))) + frame_rates = [] + for video_file in video_files: + cap = cv2.VideoCapture(video_file) + frame_rate = cap.get(cv2.CAP_PROP_FPS) + frame_rates.append(frame_rate) + average_frame_rate = np.mean(frame_rates) + + # Load fullfit model fullfit_model, _, _, _ = load_checkpoint(path=fullfit_checkpoint_path) - # fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb")) # Load new data coordinates, confidences, bodyparts = load_keypoints( @@ -340,6 +369,9 @@ def make_compute( return ( duration_seconds, results_filepath, + average_frame_rate, + coordinates, + confidences, ) def make_insert( @@ -347,6 +379,9 @@ def make_insert( key, duration_seconds, results_filepath, + average_frame_rate, + coordinates, + confidences, ): """ Insert inference results into the database. @@ -354,8 +389,11 @@ def make_insert( self.insert1( { **key, - "inference_duration": duration_seconds, "syllable_segmentation_file": results_filepath, + "coordinates": coordinates, + "confidences": confidences, + "average_frame_rate": average_frame_rate, + "inference_duration": duration_seconds, } ) From 6720dcae5f2071430add324cee23d69977ef2241 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 06:56:22 +0100 Subject: [PATCH 37/41] fix(trajectoryplots) --- element_moseq/moseq_report.py | 38 +++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py index 7ba4515..ae0a820 100644 --- a/element_moseq/moseq_report.py +++ b/element_moseq/moseq_report.py @@ -155,28 +155,35 @@ def make(self, key): start_time = datetime.now(timezone.utc) - results_file = (moseq_infer.Inference & key).fetch1( - "syllable_segmentation_file" - ) + kpms_processed = moseq_train.get_kpms_processed_data_dir() + kpms_root = moseq_train.get_kpms_root_data_dir() + + # From trained model model_dir = (moseq_infer.Model & key).fetch1("model_dir") - inference_output_dir = (moseq_infer.InferenceTask & key).fetch1( - "inference_output_dir" - ) model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1( "KEY" ) - coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates") - kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1( + use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts") + kpms_dj_config_path = (moseq_train.FullFit.ConfigFile & model_key).fetch1( "config_file" ) kpms_dj_config_dict = kpms_reader.load_kpms_dj_config( - config_path=kpms_dj_config_file + config_path=kpms_dj_config_path ) - # Get use_bodyparts from the BodyParts table - use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts") - - fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate") + # From new recordings + inference_output_dir = (moseq_infer.InferenceTask & key).fetch1( + "inference_output_dir" + ) + coordinates = (moseq_infer.Inference & key).fetch1("coordinates") + results_file = (moseq_infer.Inference & key).fetch1( + "syllable_segmentation_file" + ) + results = h5py.File(results_file, "r") + fps = (moseq_infer.Inference & key).fetch1("average_frame_rate") + kpset_dir = (moseq_infer.InferenceTask & key).fetch1("keypointset_dir") + kpset_dir = find_full_path(kpms_root, kpset_dir) + kpset_dir = Path(kpset_dir).as_posix() # Construct output directory kpms_processed = moseq_train.get_kpms_processed_data_dir() @@ -189,9 +196,6 @@ def make(self, key): trajectory_dir.mkdir(parents=True, exist_ok=True) grid_movies_dir.mkdir(parents=True, exist_ok=True) - # Load results - results = h5py.File(results_file, "r") - logger.info(f"Generating trajectory plots for {key}") # Generate trajectory plots generate_trajectory_plots( @@ -206,8 +210,8 @@ def make(self, key): logger.info(f"Generating grid movies for {key}") # Generate grid movies generate_grid_movies( - coordinates=coordinates, results=results, + video_path=kpset_dir, output_dir=grid_movies_dir.as_posix(), use_bodyparts=use_bodyparts, fps=fps, From 3df6afeba02497c497a2bcadb1bc4349d0ad3fcf Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 07:08:10 +0100 Subject: [PATCH 38/41] minor fix in FullFit --- element_moseq/moseq_train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 0e83862..efc77dc 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1621,10 +1621,9 @@ def make(self, key): pre_model_file = ( PreFit.File & pre_model_key & 'file_name="checkpoint.h5"' ).fetch1("file_path") - with open(pre_model_file, "rb") as f: - pre_model = pickle.load(f) + pre_model, data, metadata, _ = load_checkpoint(path=pre_model_file) logger.info( - f"Using PreFit model {pre_model_key_query} as warm start for FullFit" + f"Using PreFit model {pre_model_key} as warm start for FullFit" ) execution_time = datetime.now(timezone.utc) From 35b5c604670c3483eb2d86c8e92ca66b997ea24f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 07:20:27 +0100 Subject: [PATCH 39/41] minor revert --- element_moseq/moseq_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index ba97f84..dec6435 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -121,7 +121,7 @@ class File(dj.Part): -> master file_id: int # Unique ID for each file --- - file_path: filepath@moseq-infer-processed # Filepath of each video, relative to root data directory. + file_path: varchar(1000) # Filepath of each video, relative to root data directory. """ From 65b65c39efa0f8e8c95b04b175df276551d534b0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 14:33:58 +0100 Subject: [PATCH 40/41] minor update --- element_moseq/moseq_train.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index efc77dc..c4ef1b6 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -398,9 +398,15 @@ def make_compute( PCATask.update1( {**key, "kpms_project_output_dir": kpms_project_output_dir} ) - kpms_project_output_dir = ( - Path(get_kpms_processed_data_dir()) / kpms_project_output_dir - ) + + try: + kpms_project_output_dir = find_full_path( + get_kpms_processed_data_dir(), kpms_project_output_dir + ) + except FileNotFoundError: + kpms_project_output_dir = ( + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir + ) # Resolve kpset_dir to absolute and check if it exists kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) From e619ba69de6f899f9f0e1d6bc05c2c55c4245083 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 28 Oct 2025 15:00:36 +0100 Subject: [PATCH 41/41] remove fixed frames for preprocessing --- element_moseq/moseq_train.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index c4ef1b6..484ff33 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -822,16 +822,6 @@ def make_compute( # Generate overlay video for this specific recording (skip if already exists) if not overlay_video_path.exists(): - # Calculate frames for 1 minute of video - frame_rate = fps_lookup.get( - video_id, 30.0 - ) # TODO: Default to 30fps if not found - frames_for_dur = int(frame_rate * 6) - - logger.info( - f"Processing video {video_id}: {frame_rate}fps -> {frames_for_dur} frames for 1min" - ) - overlay_keypoints_on_video( video_path=video_file_path.as_posix(), coordinates=coordinates[ @@ -840,7 +830,6 @@ def make_compute( skeleton=kpms_dj_config_dict["skeleton"], bodyparts=list(use_bodyparts), output_path=overlay_video_path.as_posix(), - frames=range(frames_for_dur), ) logger.info(f"Generated overlay video: {overlay_video_path}") else: