Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [1.0.3] - 2025-10-15
+ Fix - Add new `pose_estimation_path` attribute in `KeypointSet.VideoFile`
+ Fix - Update `moseq_report` and `moseq_train` to use new `QA` and `quality_assurance` directory names
+ Fix - Update `moseq_train` to add new attribute `pose_estimation_path` attribute in `KeypointSet.VideoFile`
+ Fix - Update logic in `PreProcessing` to check for outlier plots
+ Fix - add `poppler` as system dependency in `conda_env.yml`
+ Update - code cleanup related to `copy_pdf_to_png` function
+ Update - remove unused helper functions in `viz_utils` module after these changes

## [1.0.2] - 2025-10-07
+ Update - `kpms` as extra dependency (includes `keypoint-moseq`)
+ Fix - Version pin `jax<0.7.0`
Expand Down
1 change: 1 addition & 0 deletions conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ channels:
dependencies:
- pip
- python<3.11
- poppler

name: element_moseq
83 changes: 47 additions & 36 deletions element_moseq/moseq_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,61 +69,72 @@ class PreProcessingReport(dj.Imported):
-> moseq_train.PreProcessing
video_id: int # ID of the matching video file
---
recording_name: varchar(255) # Name of the recording
outlier_plot: attach # A plot of the outlier keypoints
"""

def make(self, key):
# Resolve project dir
project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir")
kpms_project_output_dir = (
Path(moseq_train.get_kpms_processed_data_dir()) / project_rel
)
video_ids, pose_estimation_paths = (
moseq_train.KeypointSet.VideoFile & key
).fetch("video_id", "pose_estimation_path")
# Map pose estimation filename (without .h5 extension) to video id
valid_entries = [
(vid, p)
for vid, p in zip(video_ids, pose_estimation_paths)
if p is not None
]
if not valid_entries:
raise ValueError(
"No valid pose_estimation_paths found - all entries are NULL"
)

# Fetch video table info
video_paths, video_ids = (moseq_train.KeypointSet.VideoFile & key).fetch(
"video_path", "video_id"
)

# Get recording names from PreProcessing dict keys
coords = (moseq_train.PreProcessing & key).fetch1("coordinates")
recording_names = list(coords.keys())
posefile2vid = {Path(p).stem: vid for vid, p in valid_entries}
recording_names = list(posefile2vid.keys())

# Build mapping recording_name -> video_id
rec2vid = viz_utils.build_recording_to_video_id(
recording_names=recording_names,
video_paths=list(video_paths),
video_ids=list(video_ids),
)
if not recording_names:
raise ValueError(
"No recording names found after processing pose estimation paths"
)

# Insert one row per recording that matched a video_id
# Insert one row per recording
for rec in recording_names:
vid = rec2vid.get(rec)
if vid is None:
dj.logger.warning(
f"[PreProcessingReport] No video_id match for recording '{rec}'. Skipping."
vid = posefile2vid[rec]

qa_dirs = ["QA", "quality_assurance"]
plot_path = None

for qa_dir in qa_dirs:
potential_path = (
kpms_project_output_dir
/ qa_dir
/ "plots"
/ "keypoint_distance_outliers"
/ f"{rec}.png"
)
continue

plot_path = (
kpms_project_output_dir
/ "quality_assurance"
/ "plots"
/ "keypoint_distance_outliers"
/ f"{rec}.png"
)

if not plot_path.exists():
dj.logger.warning(
f"[PreProcessingReport] Outlier plot not found at {plot_path}. Skipping."
if potential_path.exists():
plot_path = potential_path
break

if plot_path is None:
checked_paths = [
kpms_project_output_dir
/ qa_dir
/ "plots"
/ "keypoint_distance_outliers"
/ f"{rec}.png"
for qa_dir in qa_dirs
]
raise FileNotFoundError(
f"Outlier plot not found for {rec}. Checked paths: {[str(p) for p in checked_paths]}"
)
continue

self.insert1(
{
**key,
"video_id": int(vid),
"recording_name": rec,
"outlier_plot": plot_path.as_posix(),
}
)
Expand Down
61 changes: 41 additions & 20 deletions element_moseq/moseq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from element_interface.utils import find_full_path

from .plotting import viz_utils
from .plotting.viz_utils import copy_pdf_to_png
from .readers import kpms_reader

schema = dj.schema()
Expand Down Expand Up @@ -161,9 +161,10 @@ class VideoFile(dj.Part):

definition = """
-> master
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
---
video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory
video_path : varchar(1000) # Filepath of each video (e.g., `.mp4`) from which the keypoints are derived, relative to root data directory
pose_estimation_path=null : varchar(1000) # Filepath of each pose estimation file (e.g., `.h5`) that contains the keypoints, relative to root data directory
"""


Expand Down Expand Up @@ -341,7 +342,6 @@ def make_compute(
kpms_project_output_dir = find_full_path(
get_kpms_processed_data_dir(), kpms_project_output_dir
)

except FileNotFoundError:
kpms_project_output_dir = (
Path(get_kpms_processed_data_dir()) / kpms_project_output_dir
Expand Down Expand Up @@ -454,22 +454,43 @@ def make_compute(

# Plot outliers
if formatted_bodyparts is not None:
try:
plot_medoid_distance_outliers(
project_dir=kpms_project_output_dir.as_posix(),
recording_name=recording_name,
original_coordinates=raw_coords,
interpolated_coordinates=cleaned_coords,
outlier_mask=outliers["mask"],
outlier_thresholds=outliers["thresholds"],
**kpms_config,
)
plot_medoid_distance_outliers(
project_dir=kpms_project_output_dir.as_posix(),
recording_name=recording_name,
original_coordinates=raw_coords,
interpolated_coordinates=cleaned_coords,
outlier_mask=outliers["mask"],
outlier_thresholds=outliers["thresholds"],
**kpms_config,
)

except Exception as e:
logger.warning(
f"Could not create outlier plot for {recording_name}: {e}"
)
qa_dirs = ["QA", "quality_assurance"]
plot_path = None

for qa_dir in qa_dirs:
potential_path = (
kpms_project_output_dir
/ qa_dir
/ "plots"
/ "keypoint_distance_outliers"
/ f"{recording_name}.png"
)
if potential_path.exists():
plot_path = potential_path
break

if plot_path is None:
checked_paths = [
kpms_project_output_dir
/ qa_dir
/ "plots"
/ "keypoint_distance_outliers"
/ f"{recording_name}.png"
for qa_dir in qa_dirs
]
raise FileNotFoundError(
f"Could not create outlier plot for {recording_name}. Checked paths: {[str(p) for p in checked_paths]}"
)
return (
cleaned_coordinates,
cleaned_confidences,
Expand Down Expand Up @@ -790,7 +811,7 @@ def make(self, key):
end_time = datetime.now(timezone.utc)

duration_seconds = (end_time - start_time).total_seconds()
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
copy_pdf_to_png(kpms_project_output_dir, model_name)

else:
duration_seconds = None
Expand Down Expand Up @@ -954,7 +975,7 @@ def make(self, key):
project_dir=kpms_project_output_dir.as_posix(),
model_name=Path(model_name).parts[-1],
)
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
copy_pdf_to_png(kpms_project_output_dir, model_name)

else:
duration_seconds = None
Expand Down
113 changes: 25 additions & 88 deletions element_moseq/plotting/viz_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,72 +13,6 @@

logger = dj.logger

_DLC_SUFFIX_RE = re.compile(
r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token
r"(?:\w+)*)" # optional extra blobs
r"(?:shuffle\d+)?" # shuffleN
r"(?:_\d+)?$" # _iter
)


def _normalize_name(name: str) -> str:
"""
Normalize a recording/video string for matching:
- lowercase, strip whitespace
- drop extension if present
- remove common DLC suffix blob (e.g., '...DLC_resnet50_...shuffle1_500000')
- collapse separators to single spaces
"""
s = name.lower().strip()
s = Path(s).stem
s = _DLC_SUFFIX_RE.sub("", s) # strip DLC tail if present
s = re.sub(r"[\s._-]+", " ", s).strip()
return s


def build_recording_to_video_id(
recording_names: List[str],
video_paths: List[str],
video_ids: List[int],
fuzzy_threshold: float = 0.80,
) -> Dict[str, Optional[int]]:
"""
Returns: {recording_name -> video_id or None if no good match}
Strategy: exact normalized stem match; if none, substring; then fuzzy.
"""
# candidate stems from videos
stems: List[Tuple[str, int]] = [
(_normalize_name(Path(p).name), vid) for p, vid in zip(video_paths, video_ids)
]

mapping: Dict[str, Optional[int]] = {}

for rec in recording_names:
nrec = _normalize_name(rec)

# 1) exact normalized match
exact = [vid for stem, vid in stems if stem == nrec]
if exact:
mapping[rec] = exact[0]
continue

# 2) substring either way (choose longest stem to disambiguate)
subs = [(stem, vid) for stem, vid in stems if nrec in stem or stem in nrec]
if subs:
subs.sort(key=lambda x: len(x[0]), reverse=True)
mapping[rec] = subs[0][1]
continue

# 3) fuzzy best match
best_vid, best_ratio = None, 0.0
for stem, vid in stems:
r = SequenceMatcher(None, nrec, stem).ratio()
if r > best_ratio:
best_ratio, best_vid = r, vid
mapping[rec] = best_vid if best_ratio >= fuzzy_threshold else None

return mapping


def plot_medoid_distance_outliers(
project_dir: str,
Expand Down Expand Up @@ -131,14 +65,26 @@ def plot_medoid_distance_outliers(
"""
from keypoint_moseq.util import get_distance_to_medoid, plot_keypoint_traces

plot_path = os.path.join(
project_dir,
"quality_assurance",
"plots",
"keypoint_distance_outliers",
f"{recording_name}.png",
)
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
qa_dirs = ["QA", "quality_assurance"]
plot_path = None

for qa_dir in qa_dirs:
potential_path = os.path.join(
project_dir,
qa_dir,
"plots",
"keypoint_distance_outliers",
f"{recording_name}.png",
)
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(potential_path), exist_ok=True)
plot_path = potential_path
break # Use first available directory

if plot_path is None:
raise FileNotFoundError(
f"Could not determine plot directory for {recording_name}"
)

original_distances = get_distance_to_medoid(
original_coordinates
Expand Down Expand Up @@ -318,9 +264,8 @@ def plot_pcs(
plt.tight_layout()

if savefig:
assert project_dir is not None, fill(
"The `savefig` option requires a `project_dir`"
)
if project_dir is None:
raise ValueError(fill("The `savefig` option requires a `project_dir`"))
plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf"))
plt.show()

Expand All @@ -341,18 +286,13 @@ def copy_pdf_to_png(project_dir, model_name):
"""
Convert PDF progress plot to PNG format using pdf2image.
The fit_model function generates a single fitting_progress.pdf file.
This function should always succeed if the PDF exists.

Args:
project_dir (Path): Project directory path
model_name (str): Model name directory

Returns:
bool: True if conversion was successful, False otherwise

Raises:
FileNotFoundError: If the PDF file doesn't exist
RuntimeError: If conversion fails
None: The function raises errors instead of returning boolean values
"""
from pdf2image import convert_from_path

Expand All @@ -361,15 +301,12 @@ def copy_pdf_to_png(project_dir, model_name):
pdf_path = model_dir / "fitting_progress.pdf"
png_path = model_dir / "fitting_progress.png"

# Check if PDF exists
if not pdf_path.exists():
raise FileNotFoundError(f"PDF progress plot not found at {pdf_path}")

# Convert PDF to PNG
images = convert_from_path(pdf_path, dpi=300)
images = convert_from_path(str(pdf_path), dpi=300)
if not images:
raise ValueError(f"No PDF file found at {pdf_path}")
raise ValueError(f"Could not convert PDF at {pdf_path} (no images returned)")

images[0].save(png_path, "PNG")
logger.info(f"Generated PNG progress plot at {png_path}")
return True
Loading