Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [1.0.3] - 2025-10-15

> **BREAKING CHANGES** - This version contains breaking changes due to keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly.
### Breaking Changes
+ **BREAKING**: Remove `recording_name` attribute from `moseq_report` since it will be added in `VideoFile` in `moseq_train`
+ **BREAKING**: Add new `pose_estimation_path` attribute in `KeypointSet.VideoFile`

### New Features and Fixes
+ Fix - Update `moseq_report` and `moseq_train` to use new `QA` dir name instead of `quality_assurance`
+ Fix - Update logic in `PreProcessing` to check for outlier plots
+ Fix - add `poppler` as system dependency in `conda_env.yml`
+ Update - code cleanup related to `copy_pdf_to_png` function
+ Update - remove unused helper functions in `viz_utils` module after these changes

## [1.0.2] - 2025-10-07
+ Update - `kpms` as extra dependency (includes `keypoint-moseq`)
+ Fix - Version pin `jax<0.7.0`
Expand Down
1 change: 1 addition & 0 deletions conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ channels:
dependencies:
- pip
- python<3.11
- poppler

name: element_moseq
54 changes: 25 additions & 29 deletions element_moseq/moseq_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,61 +69,57 @@ class PreProcessingReport(dj.Imported):
-> moseq_train.PreProcessing
video_id: int # ID of the matching video file
---
recording_name: varchar(255) # Name of the recording
outlier_plot: attach # A plot of the outlier keypoints
"""

def make(self, key):
# Resolve project dir
project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir")
kpms_project_output_dir = (
Path(moseq_train.get_kpms_processed_data_dir()) / project_rel
)
video_ids, pose_estimation_paths = (
moseq_train.KeypointSet.VideoFile & key
).fetch("video_id", "pose_estimation_path")
# Map pose estimation filename (without .h5 extension) to video id
valid_entries = [
(vid, p)
for vid, p in zip(video_ids, pose_estimation_paths)
if p is not None
]
if not valid_entries:
raise ValueError(
"No valid pose_estimation_paths found - all entries are NULL"
)

# Fetch video table info
video_paths, video_ids = (moseq_train.KeypointSet.VideoFile & key).fetch(
"video_path", "video_id"
)

# Get recording names from PreProcessing dict keys
coords = (moseq_train.PreProcessing & key).fetch1("coordinates")
recording_names = list(coords.keys())
posefile2vid = {Path(p).stem: vid for vid, p in valid_entries}
recording_names = list(posefile2vid.keys())

# Build mapping recording_name -> video_id
rec2vid = viz_utils.build_recording_to_video_id(
recording_names=recording_names,
video_paths=list(video_paths),
video_ids=list(video_ids),
)
if not recording_names:
raise ValueError(
"No recording names found after processing pose estimation paths"
)

# Insert one row per recording that matched a video_id
# Insert one row per recording
for rec in recording_names:
vid = rec2vid.get(rec)
if vid is None:
dj.logger.warning(
f"[PreProcessingReport] No video_id match for recording '{rec}'. Skipping."
)
continue
vid = posefile2vid[rec]

# Look for outlier plot in QA directory
plot_path = (
kpms_project_output_dir
/ "quality_assurance"
/ "QA"
/ "plots"
/ "keypoint_distance_outliers"
/ f"{rec}.png"
)

if not plot_path.exists():
dj.logger.warning(
f"[PreProcessingReport] Outlier plot not found at {plot_path}. Skipping."
raise FileNotFoundError(
f"Outlier plot not found for {rec} at {plot_path}"
)
continue

self.insert1(
{
**key,
"video_id": int(vid),
"recording_name": rec,
"outlier_plot": plot_path.as_posix(),
}
)
Expand Down
46 changes: 26 additions & 20 deletions element_moseq/moseq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from element_interface.utils import find_full_path

from .plotting import viz_utils
from .plotting.viz_utils import copy_pdf_to_png
from .readers import kpms_reader

schema = dj.schema()
Expand Down Expand Up @@ -161,9 +161,10 @@ class VideoFile(dj.Part):

definition = """
-> master
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
---
video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory
video_path : varchar(1000) # Filepath of each video (e.g., `.mp4`) from which the keypoints are derived, relative to root data directory
pose_estimation_path=null : varchar(1000) # Filepath of each pose estimation file (e.g., `.h5`) that contains the keypoints, relative to root data directory
"""


Expand Down Expand Up @@ -341,7 +342,6 @@ def make_compute(
kpms_project_output_dir = find_full_path(
get_kpms_processed_data_dir(), kpms_project_output_dir
)

except FileNotFoundError:
kpms_project_output_dir = (
Path(get_kpms_processed_data_dir()) / kpms_project_output_dir
Expand Down Expand Up @@ -454,22 +454,28 @@ def make_compute(

# Plot outliers
if formatted_bodyparts is not None:
try:
plot_medoid_distance_outliers(
project_dir=kpms_project_output_dir.as_posix(),
recording_name=recording_name,
original_coordinates=raw_coords,
interpolated_coordinates=cleaned_coords,
outlier_mask=outliers["mask"],
outlier_thresholds=outliers["thresholds"],
**kpms_config,
)
plot_medoid_distance_outliers(
project_dir=kpms_project_output_dir.as_posix(),
recording_name=recording_name,
original_coordinates=raw_coords,
interpolated_coordinates=cleaned_coords,
outlier_mask=outliers["mask"],
outlier_thresholds=outliers["thresholds"],
**kpms_config,
)

except Exception as e:
logger.warning(
f"Could not create outlier plot for {recording_name}: {e}"
# Check if outlier plot was created in QA directory
plot_path = (
kpms_project_output_dir
/ "QA"
/ "plots"
/ "keypoint_distance_outliers"
/ f"{recording_name}.png"
)
if not plot_path.exists():
raise FileNotFoundError(
f"Could not create outlier plot for {recording_name} at {plot_path}"
)

return (
cleaned_coordinates,
cleaned_confidences,
Expand Down Expand Up @@ -790,7 +796,7 @@ def make(self, key):
end_time = datetime.now(timezone.utc)

duration_seconds = (end_time - start_time).total_seconds()
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
copy_pdf_to_png(kpms_project_output_dir, model_name)

else:
duration_seconds = None
Expand Down Expand Up @@ -954,7 +960,7 @@ def make(self, key):
project_dir=kpms_project_output_dir.as_posix(),
model_name=Path(model_name).parts[-1],
)
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
copy_pdf_to_png(kpms_project_output_dir, model_name)

else:
duration_seconds = None
Expand Down
89 changes: 7 additions & 82 deletions element_moseq/plotting/viz_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,72 +13,6 @@

logger = dj.logger

_DLC_SUFFIX_RE = re.compile(
r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token
r"(?:\w+)*)" # optional extra blobs
r"(?:shuffle\d+)?" # shuffleN
r"(?:_\d+)?$" # _iter
)


def _normalize_name(name: str) -> str:
"""
Normalize a recording/video string for matching:
- lowercase, strip whitespace
- drop extension if present
- remove common DLC suffix blob (e.g., '...DLC_resnet50_...shuffle1_500000')
- collapse separators to single spaces
"""
s = name.lower().strip()
s = Path(s).stem
s = _DLC_SUFFIX_RE.sub("", s) # strip DLC tail if present
s = re.sub(r"[\s._-]+", " ", s).strip()
return s


def build_recording_to_video_id(
recording_names: List[str],
video_paths: List[str],
video_ids: List[int],
fuzzy_threshold: float = 0.80,
) -> Dict[str, Optional[int]]:
"""
Returns: {recording_name -> video_id or None if no good match}
Strategy: exact normalized stem match; if none, substring; then fuzzy.
"""
# candidate stems from videos
stems: List[Tuple[str, int]] = [
(_normalize_name(Path(p).name), vid) for p, vid in zip(video_paths, video_ids)
]

mapping: Dict[str, Optional[int]] = {}

for rec in recording_names:
nrec = _normalize_name(rec)

# 1) exact normalized match
exact = [vid for stem, vid in stems if stem == nrec]
if exact:
mapping[rec] = exact[0]
continue

# 2) substring either way (choose longest stem to disambiguate)
subs = [(stem, vid) for stem, vid in stems if nrec in stem or stem in nrec]
if subs:
subs.sort(key=lambda x: len(x[0]), reverse=True)
mapping[rec] = subs[0][1]
continue

# 3) fuzzy best match
best_vid, best_ratio = None, 0.0
for stem, vid in stems:
r = SequenceMatcher(None, nrec, stem).ratio()
if r > best_ratio:
best_ratio, best_vid = r, vid
mapping[rec] = best_vid if best_ratio >= fuzzy_threshold else None

return mapping


def plot_medoid_distance_outliers(
project_dir: str,
Expand Down Expand Up @@ -131,13 +65,15 @@ def plot_medoid_distance_outliers(
"""
from keypoint_moseq.util import get_distance_to_medoid, plot_keypoint_traces

# Use QA directory for outlier plots
plot_path = os.path.join(
project_dir,
"quality_assurance",
"QA",
"plots",
"keypoint_distance_outliers",
f"{recording_name}.png",
)
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(plot_path), exist_ok=True)

original_distances = get_distance_to_medoid(
Expand Down Expand Up @@ -318,9 +254,8 @@ def plot_pcs(
plt.tight_layout()

if savefig:
assert project_dir is not None, fill(
"The `savefig` option requires a `project_dir`"
)
if project_dir is None:
raise ValueError(fill("The `savefig` option requires a `project_dir`"))
plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf"))
plt.show()

Expand All @@ -341,18 +276,11 @@ def copy_pdf_to_png(project_dir, model_name):
"""
Convert PDF progress plot to PNG format using pdf2image.
The fit_model function generates a single fitting_progress.pdf file.
This function should always succeed if the PDF exists.

Args:
project_dir (Path): Project directory path
model_name (str): Model name directory

Returns:
bool: True if conversion was successful, False otherwise

Raises:
FileNotFoundError: If the PDF file doesn't exist
RuntimeError: If conversion fails
"""
from pdf2image import convert_from_path

Expand All @@ -361,15 +289,12 @@ def copy_pdf_to_png(project_dir, model_name):
pdf_path = model_dir / "fitting_progress.pdf"
png_path = model_dir / "fitting_progress.png"

# Check if PDF exists
if not pdf_path.exists():
raise FileNotFoundError(f"PDF progress plot not found at {pdf_path}")

# Convert PDF to PNG
images = convert_from_path(pdf_path, dpi=300)
images = convert_from_path(str(pdf_path), dpi=300)
if not images:
raise ValueError(f"No PDF file found at {pdf_path}")
raise ValueError(f"Could not convert PDF at {pdf_path} (no images returned)")

images[0].save(png_path, "PNG")
logger.info(f"Generated PNG progress plot at {png_path}")
return True
2 changes: 1 addition & 1 deletion element_moseq/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Package metadata
"""

__version__ = "1.0.2"
__version__ = "1.0.3"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "element-moseq"
version = "1.0.2"
version = "1.0.3"
description = "Keypoint-MoSeq DataJoint Element"
readme = "README.md"
license = {text = "MIT"}
Expand Down