Skip to content

Commit 5b134e8

Browse files
authored
Merge pull request #22 from MilagrosMarin/update_docs
fix: add new attribute to track pose estimation files, update logic in two schemas, refactor & code cleanup
2 parents 978021c + ef2e65a commit 5b134e8

File tree

7 files changed

+76
-133
lines changed

7 files changed

+76
-133
lines changed

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,21 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
## [1.1.0] - 2025-10-15
7+
8+
> **BREAKING CHANGES** - This version contains breaking changes due to the keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly.
9+
10+
### Breaking Changes
11+
+ **BREAKING**: Remove the `recording_name` attribute from `PreProcessingReport` in `moseq_report`; recording names are now derived from the new `pose_estimation_path` attribute of `KeypointSet.VideoFile` in `moseq_train`
12+
+ **BREAKING**: Add new `pose_estimation_path` attribute in `KeypointSet.VideoFile`
13+
14+
### New Features and Fixes
15+
+ Fix - Update `moseq_report` and `moseq_train` to use new `QA` dir name instead of `quality_assurance`
16+
+ Fix - Update logic in `PreProcessing` to check for outlier plots
17+
+ Fix - Add `poppler` as a system dependency in `conda_env.yml`
18+
+ Update - Code cleanup related to the `copy_pdf_to_png` function
19+
+ Update - Remove helper functions in the `viz_utils` module that are no longer used after these changes
20+
621
## [1.0.2] - 2025-10-07
722
+ Update - `kpms` as extra dependency (includes `keypoint-moseq`)
823
+ Fix - Version pin `jax<0.7.0`

conda_env.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ channels:
44
dependencies:
55
- pip
66
- python<3.11
7+
- poppler
78

89
name: element_moseq

element_moseq/moseq_report.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -69,61 +69,57 @@ class PreProcessingReport(dj.Imported):
6969
-> moseq_train.PreProcessing
7070
video_id: int # ID of the matching video file
7171
---
72-
recording_name: varchar(255) # Name of the recording
7372
outlier_plot: attach # A plot of the outlier keypoints
7473
"""
7574

7675
def make(self, key):
77-
# Resolve project dir
7876
project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir")
7977
kpms_project_output_dir = (
8078
Path(moseq_train.get_kpms_processed_data_dir()) / project_rel
8179
)
80+
video_ids, pose_estimation_paths = (
81+
moseq_train.KeypointSet.VideoFile & key
82+
).fetch("video_id", "pose_estimation_path")
83+
# Map pose estimation filename (without .h5 extension) to video id
84+
valid_entries = [
85+
(vid, p)
86+
for vid, p in zip(video_ids, pose_estimation_paths)
87+
if p and p.strip() # Check for non-empty strings
88+
]
89+
if not valid_entries:
90+
raise ValueError(
91+
"No valid pose_estimation_paths found - all entries are empty"
92+
)
8293

83-
# Fetch video table info
84-
video_paths, video_ids = (moseq_train.KeypointSet.VideoFile & key).fetch(
85-
"video_path", "video_id"
86-
)
87-
88-
# Get recording names from PreProcessing dict keys
89-
coords = (moseq_train.PreProcessing & key).fetch1("coordinates")
90-
recording_names = list(coords.keys())
94+
posefile2vid = {Path(p).stem: vid for vid, p in valid_entries}
95+
recording_names = list(posefile2vid.keys())
9196

92-
# Build mapping recording_name -> video_id
93-
rec2vid = viz_utils.build_recording_to_video_id(
94-
recording_names=recording_names,
95-
video_paths=list(video_paths),
96-
video_ids=list(video_ids),
97-
)
97+
if not recording_names:
98+
raise ValueError(
99+
"No recording names found after processing pose estimation paths"
100+
)
98101

99-
# Insert one row per recording that matched a video_id
102+
# Insert one row per recording
100103
for rec in recording_names:
101-
vid = rec2vid.get(rec)
102-
if vid is None:
103-
dj.logger.warning(
104-
f"[PreProcessingReport] No video_id match for recording '{rec}'. Skipping."
105-
)
106-
continue
104+
vid = posefile2vid[rec]
107105

106+
# Look for outlier plot in QA directory
108107
plot_path = (
109108
kpms_project_output_dir
110-
/ "quality_assurance"
109+
/ "QA"
111110
/ "plots"
112111
/ "keypoint_distance_outliers"
113112
/ f"{rec}.png"
114113
)
115-
116114
if not plot_path.exists():
117-
dj.logger.warning(
118-
f"[PreProcessingReport] Outlier plot not found at {plot_path}. Skipping."
115+
raise FileNotFoundError(
116+
f"Outlier plot not found for {rec} at {plot_path}"
119117
)
120-
continue
121118

122119
self.insert1(
123120
{
124121
**key,
125122
"video_id": int(vid),
126-
"recording_name": rec,
127123
"outlier_plot": plot_path.as_posix(),
128124
}
129125
)

element_moseq/moseq_train.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
from element_interface.utils import find_full_path
1616

17-
from .plotting import viz_utils
17+
from .plotting.viz_utils import copy_pdf_to_png
1818
from .readers import kpms_reader
1919

2020
schema = dj.schema()
@@ -161,9 +161,10 @@ class VideoFile(dj.Part):
161161

162162
definition = """
163163
-> master
164-
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
164+
video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
165165
---
166-
video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory
166+
video_path : varchar(1000) # Filepath of each video (e.g., `.mp4`) from which the keypoints are derived, relative to root data directory
167+
pose_estimation_path='' : varchar(1000) # Filepath of each pose estimation file (e.g., `.h5`) that contains the keypoints, relative to root data directory
167168
"""
168169

169170

@@ -341,7 +342,6 @@ def make_compute(
341342
kpms_project_output_dir = find_full_path(
342343
get_kpms_processed_data_dir(), kpms_project_output_dir
343344
)
344-
345345
except FileNotFoundError:
346346
kpms_project_output_dir = (
347347
Path(get_kpms_processed_data_dir()) / kpms_project_output_dir
@@ -454,22 +454,28 @@ def make_compute(
454454

455455
# Plot outliers
456456
if formatted_bodyparts is not None:
457-
try:
458-
plot_medoid_distance_outliers(
459-
project_dir=kpms_project_output_dir.as_posix(),
460-
recording_name=recording_name,
461-
original_coordinates=raw_coords,
462-
interpolated_coordinates=cleaned_coords,
463-
outlier_mask=outliers["mask"],
464-
outlier_thresholds=outliers["thresholds"],
465-
**kpms_config,
466-
)
457+
plot_medoid_distance_outliers(
458+
project_dir=kpms_project_output_dir.as_posix(),
459+
recording_name=recording_name,
460+
original_coordinates=raw_coords,
461+
interpolated_coordinates=cleaned_coords,
462+
outlier_mask=outliers["mask"],
463+
outlier_thresholds=outliers["thresholds"],
464+
**kpms_config,
465+
)
467466

468-
except Exception as e:
469-
logger.warning(
470-
f"Could not create outlier plot for {recording_name}: {e}"
467+
# Check if outlier plot was created in QA directory
468+
plot_path = (
469+
kpms_project_output_dir
470+
/ "QA"
471+
/ "plots"
472+
/ "keypoint_distance_outliers"
473+
/ f"{recording_name}.png"
474+
)
475+
if not plot_path.exists():
476+
raise FileNotFoundError(
477+
f"Could not create outlier plot for {recording_name} at {plot_path}"
471478
)
472-
473479
return (
474480
cleaned_coordinates,
475481
cleaned_confidences,
@@ -790,7 +796,7 @@ def make(self, key):
790796
end_time = datetime.now(timezone.utc)
791797

792798
duration_seconds = (end_time - start_time).total_seconds()
793-
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
799+
copy_pdf_to_png(kpms_project_output_dir, model_name)
794800

795801
else:
796802
duration_seconds = None
@@ -954,7 +960,7 @@ def make(self, key):
954960
project_dir=kpms_project_output_dir.as_posix(),
955961
model_name=Path(model_name).parts[-1],
956962
)
957-
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
963+
copy_pdf_to_png(kpms_project_output_dir, model_name)
958964

959965
else:
960966
duration_seconds = None

element_moseq/plotting/viz_utils.py

Lines changed: 7 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,72 +13,6 @@
1313

1414
logger = dj.logger
1515

16-
_DLC_SUFFIX_RE = re.compile(
17-
r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token
18-
r"(?:\w+)*)" # optional extra blobs
19-
r"(?:shuffle\d+)?" # shuffleN
20-
r"(?:_\d+)?$" # _iter
21-
)
22-
23-
24-
def _normalize_name(name: str) -> str:
25-
"""
26-
Normalize a recording/video string for matching:
27-
- lowercase, strip whitespace
28-
- drop extension if present
29-
- remove common DLC suffix blob (e.g., '...DLC_resnet50_...shuffle1_500000')
30-
- collapse separators to single spaces
31-
"""
32-
s = name.lower().strip()
33-
s = Path(s).stem
34-
s = _DLC_SUFFIX_RE.sub("", s) # strip DLC tail if present
35-
s = re.sub(r"[\s._-]+", " ", s).strip()
36-
return s
37-
38-
39-
def build_recording_to_video_id(
40-
recording_names: List[str],
41-
video_paths: List[str],
42-
video_ids: List[int],
43-
fuzzy_threshold: float = 0.80,
44-
) -> Dict[str, Optional[int]]:
45-
"""
46-
Returns: {recording_name -> video_id or None if no good match}
47-
Strategy: exact normalized stem match; if none, substring; then fuzzy.
48-
"""
49-
# candidate stems from videos
50-
stems: List[Tuple[str, int]] = [
51-
(_normalize_name(Path(p).name), vid) for p, vid in zip(video_paths, video_ids)
52-
]
53-
54-
mapping: Dict[str, Optional[int]] = {}
55-
56-
for rec in recording_names:
57-
nrec = _normalize_name(rec)
58-
59-
# 1) exact normalized match
60-
exact = [vid for stem, vid in stems if stem == nrec]
61-
if exact:
62-
mapping[rec] = exact[0]
63-
continue
64-
65-
# 2) substring either way (choose longest stem to disambiguate)
66-
subs = [(stem, vid) for stem, vid in stems if nrec in stem or stem in nrec]
67-
if subs:
68-
subs.sort(key=lambda x: len(x[0]), reverse=True)
69-
mapping[rec] = subs[0][1]
70-
continue
71-
72-
# 3) fuzzy best match
73-
best_vid, best_ratio = None, 0.0
74-
for stem, vid in stems:
75-
r = SequenceMatcher(None, nrec, stem).ratio()
76-
if r > best_ratio:
77-
best_ratio, best_vid = r, vid
78-
mapping[rec] = best_vid if best_ratio >= fuzzy_threshold else None
79-
80-
return mapping
81-
8216

8317
def plot_medoid_distance_outliers(
8418
project_dir: str,
@@ -131,13 +65,15 @@ def plot_medoid_distance_outliers(
13165
"""
13266
from keypoint_moseq.util import get_distance_to_medoid, plot_keypoint_traces
13367

68+
# Use QA directory for outlier plots
13469
plot_path = os.path.join(
13570
project_dir,
136-
"quality_assurance",
71+
"QA",
13772
"plots",
13873
"keypoint_distance_outliers",
13974
f"{recording_name}.png",
14075
)
76+
# Create directory if it doesn't exist
14177
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
14278

14379
original_distances = get_distance_to_medoid(
@@ -318,9 +254,8 @@ def plot_pcs(
318254
plt.tight_layout()
319255

320256
if savefig:
321-
assert project_dir is not None, fill(
322-
"The `savefig` option requires a `project_dir`"
323-
)
257+
if project_dir is None:
258+
raise ValueError(fill("The `savefig` option requires a `project_dir`"))
324259
plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf"))
325260
plt.show()
326261

@@ -341,18 +276,11 @@ def copy_pdf_to_png(project_dir, model_name):
341276
"""
342277
Convert PDF progress plot to PNG format using pdf2image.
343278
The fit_model function generates a single fitting_progress.pdf file.
344-
This function should always succeed if the PDF exists.
345279
346280
Args:
347281
project_dir (Path): Project directory path
348282
model_name (str): Model name directory
349283
350-
Returns:
351-
bool: True if conversion was successful, False otherwise
352-
353-
Raises:
354-
FileNotFoundError: If the PDF file doesn't exist
355-
RuntimeError: If conversion fails
356284
"""
357285
from pdf2image import convert_from_path
358286

@@ -361,15 +289,12 @@ def copy_pdf_to_png(project_dir, model_name):
361289
pdf_path = model_dir / "fitting_progress.pdf"
362290
png_path = model_dir / "fitting_progress.png"
363291

364-
# Check if PDF exists
365292
if not pdf_path.exists():
366293
raise FileNotFoundError(f"PDF progress plot not found at {pdf_path}")
367294

368-
# Convert PDF to PNG
369-
images = convert_from_path(pdf_path, dpi=300)
295+
images = convert_from_path(str(pdf_path), dpi=300)
370296
if not images:
371-
raise ValueError(f"No PDF file found at {pdf_path}")
297+
raise ValueError(f"Could not convert PDF at {pdf_path} (no images returned)")
372298

373299
images[0].save(png_path, "PNG")
374300
logger.info(f"Generated PNG progress plot at {png_path}")
375-
return True

element_moseq/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package metadata
33
"""
44

5-
__version__ = "1.0.2"
5+
__version__ = "1.1.0"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "element-moseq"
7-
version = "1.0.2"
7+
version = "1.1.0"
88
description = "Keypoint-MoSeq DataJoint Element"
99
readme = "README.md"
1010
license = {text = "MIT"}

0 commit comments

Comments
 (0)