Skip to content

Commit 20dd71b

Browse files
authored
Merge pull request #18 from MilagrosMarin/update_docs
feat: Add support to generate PNG version of fitting progress plots & refactor
2 parents 5755af2 + 1c23286 commit 20dd71b

File tree

7 files changed

+97
-21
lines changed

7 files changed

+97
-21
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
## [1.0.1] - 2025-09-23
7+
+ Feat - Add support to generate PNG version of fitting progress plots in `PreFit`, `FullFit`, and `moseq_report` schema
8+
+ Fix - Update path handling to use `Path` objects and `dj.logger`
9+
610
## [1.0.0] - 2025-09-10
711

812
> **BREAKING CHANGES** - This version contains breaking changes due to keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly.

element_moseq/moseq_report.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,43 +198,67 @@ class PreFitReport(dj.Imported):
198198
-> moseq_train.PreFit
199199
---
200200
fitting_progress_pdf: attach # fitting_progress.pdf
201+
fitting_progress_png: attach # fitting_progress.png
201202
"""
202203

203204
def make(self, key):
204205
prefit_model_name = (moseq_train.PreFit & key).fetch1("model_name")
205206
prefit_model_dir = find_full_path(
206207
moseq_train.get_kpms_processed_data_dir(), prefit_model_name
207208
)
208-
prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf"
209-
if prefit_output_dir.exists():
210-
self.insert1({**key, "fitting_progress_pdf": prefit_output_dir})
211-
else:
209+
210+
pdf_path = prefit_model_dir / "fitting_progress.pdf"
211+
png_path = prefit_model_dir / "fitting_progress.png"
212+
213+
if not pdf_path.exists():
214+
raise FileNotFoundError(
215+
f"PreFit PDF progress plot not found at {pdf_path}. "
216+
)
217+
218+
if not png_path.exists():
212219
raise FileNotFoundError(
213-
f"PreFit fitting_progress.pdf not found at {prefit_output_dir}"
220+
f"PreFit PNG progress plot not found at {png_path}. "
214221
)
215222

223+
# Both files exist, insert them
224+
self.insert1(
225+
{**key, "fitting_progress_pdf": pdf_path, "fitting_progress_png": png_path}
226+
)
227+
216228

217229
@schema
218230
class FullFitReport(dj.Imported):
219231
definition = """
220232
-> moseq_train.FullFit
221233
---
222234
fitting_progress_pdf: attach # fitting_progress.pdf
235+
fitting_progress_png: attach # fitting_progress.png
223236
"""
224237

225238
def make(self, key):
226239
fullfit_model_name = (moseq_train.FullFit & key).fetch1("model_name")
227240
fullfit_model_dir = find_full_path(
228241
moseq_train.get_kpms_processed_data_dir(), fullfit_model_name
229242
)
230-
fullfit_output_file = Path(fullfit_model_dir) / "fitting_progress.pdf"
231-
if fullfit_output_file.exists():
232-
self.insert1({**key, "fitting_progress_pdf": fullfit_output_file})
233-
else:
243+
244+
pdf_path = fullfit_model_dir / "fitting_progress.pdf"
245+
png_path = fullfit_model_dir / "fitting_progress.png"
246+
247+
if not pdf_path.exists():
248+
raise FileNotFoundError(
249+
f"FullFit PDF progress plot not found at {pdf_path}. "
250+
)
251+
252+
if not png_path.exists():
234253
raise FileNotFoundError(
235-
f"FullFit fitting_progress.pdf not found at {fullfit_output_file}"
254+
f"FullFit PNG progress plot not found at {png_path}. "
236255
)
237256

257+
# Both files exist, insert them
258+
self.insert1(
259+
{**key, "fitting_progress_pdf": pdf_path, "fitting_progress_png": png_path}
260+
)
261+
238262

239263
@schema
240264
class InferenceReport(dj.Imported):

element_moseq/moseq_train.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import importlib
77
import inspect
8-
import os
98
from datetime import datetime, timezone
109
from pathlib import Path
1110
from typing import Optional
@@ -15,6 +14,7 @@
1514
import numpy as np
1615
from element_interface.utils import find_full_path
1716

17+
from .plotting import viz_utils
1818
from .readers import kpms_reader
1919

2020
schema = dj.schema()
@@ -790,6 +790,8 @@ def make(self, key):
790790
end_time = datetime.now(timezone.utc)
791791

792792
duration_seconds = (end_time - start_time).total_seconds()
793+
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
794+
793795
else:
794796
duration_seconds = None
795797

@@ -942,6 +944,8 @@ def make(self, key):
942944
project_dir=kpms_project_output_dir.as_posix(),
943945
ar_only=False,
944946
num_iters=full_num_iterations,
947+
generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/
948+
save_every_n_iters=25,
945949
)
946950
end_time = datetime.now(timezone.utc)
947951
duration_seconds = (end_time - start_time).total_seconds()
@@ -950,6 +954,7 @@ def make(self, key):
950954
project_dir=kpms_project_output_dir.as_posix(),
951955
model_name=Path(model_name).parts[-1],
952956
)
957+
viz_utils.copy_pdf_to_png(kpms_project_output_dir, model_name)
953958

954959
else:
955960
duration_seconds = None

element_moseq/plotting/viz_utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
from textwrap import fill
88
from typing import Dict, List, Optional, Tuple
99

10+
import datajoint as dj
1011
import matplotlib.pyplot as plt
1112
import numpy as np
1213
from jax_moseq.models.keypoint_slds import center_embedding
1314
from keypoint_moseq.util import get_distance_to_medoid, get_edges, plot_keypoint_traces
1415
from keypoint_moseq.viz import plot_pcs_3D
1516

17+
logger = dj.logger
18+
1619
_DLC_SUFFIX_RE = re.compile(
1720
r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token
1821
r"(?:\w+)*)" # optional extra blobs
@@ -158,7 +161,9 @@ def plot_medoid_distance_outliers(
158161
fig.savefig(plot_path, dpi=300)
159162

160163
plt.close()
161-
print(f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}.")
164+
logger.info(
165+
f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}."
166+
)
162167
return fig
163168

164169

@@ -328,3 +333,41 @@ def plot_pcs(
328333
line_width * 2,
329334
)
330335
return fig
336+
337+
338+
def copy_pdf_to_png(project_dir, model_name):
339+
"""
340+
Convert PDF progress plot to PNG format using pdf2image.
341+
The fit_model function generates a single fitting_progress.pdf file.
342+
This function should always succeed if the PDF exists.
343+
344+
Args:
345+
project_dir (Path): Project directory path
346+
model_name (str): Model name directory
347+
348+
Returns:
349+
bool: True if conversion was successful, False otherwise
350+
351+
Raises:
352+
FileNotFoundError: If the PDF file doesn't exist
353+
RuntimeError: If conversion fails
354+
"""
355+
from pdf2image import convert_from_path
356+
357+
# Construct paths for PDF and PNG files
358+
model_dir = Path(project_dir) / model_name
359+
pdf_path = model_dir / "fitting_progress.pdf"
360+
png_path = model_dir / "fitting_progress.png"
361+
362+
# Check if PDF exists
363+
if not pdf_path.exists():
364+
raise FileNotFoundError(f"PDF progress plot not found at {pdf_path}")
365+
366+
# Convert PDF to PNG
367+
images = convert_from_path(pdf_path, dpi=300)
368+
if not images:
369+
raise ValueError(f"No PDF file found at {pdf_path}")
370+
371+
images[0].save(png_path, "PNG")
372+
logger.info(f"Generated PNG progress plot at {png_path}")
373+
return True

element_moseq/readers/kpms_reader.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import logging
21
import os
32
from pathlib import Path
43
from typing import Any, Dict, Union
54

5+
import datajoint as dj
66
import jax.numpy as jnp
77
import yaml
88

9-
logger = logging.getLogger("datajoint")
10-
9+
logger = dj.logger
1110

1211
DJ_CONFIG = "kpms_dj_config.yml"
1312
BASE_CONFIG = "config.yml"
@@ -72,11 +71,11 @@ def dj_generate_config(project_dir: str, **kwargs) -> str:
7271
base_cfg_path = _base_config_path(project_dir)
7372
dj_cfg_path = _dj_config_path(project_dir)
7473

75-
if os.path.exists(dj_cfg_path):
74+
if Path(dj_cfg_path).exists():
7675
with open(dj_cfg_path, "r") as f:
7776
cfg = yaml.safe_load(f) or {}
7877
else:
79-
if not os.path.exists(base_cfg_path):
78+
if not Path(base_cfg_path).exists():
8079
raise FileNotFoundError(
8180
f"Missing base config at {base_cfg_path}. Run upstream setup_project first. "
8281
f"Expected either config.yml or config.yaml in {project_dir}."
@@ -105,7 +104,7 @@ def load_kpms_dj_config(
105104
indexing into 'use_bodyparts' by order.
106105
"""
107106
dj_cfg_path = _dj_config_path(project_dir)
108-
if not os.path.exists(dj_cfg_path):
107+
if not Path(dj_cfg_path).exists():
109108
raise FileNotFoundError(
110109
f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()."
111110
)
@@ -135,7 +134,7 @@ def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]:
135134
keypoint_moseq.io.update_config), then rewrite the file and return the dict.
136135
"""
137136
dj_cfg_path = _dj_config_path(project_dir)
138-
if not os.path.exists(dj_cfg_path):
137+
if not Path(dj_cfg_path).exists():
139138
raise FileNotFoundError(
140139
f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()."
141140
)

element_moseq/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package metadata
33
"""
44

5-
__version__ = "1.0.0"
5+
__version__ = "1.0.1"

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "element-moseq"
7-
version = "1.0.0"
7+
version = "1.0.1"
88
description = "Keypoint-MoSeq DataJoint Element"
99
readme = "README.md"
1010
license = {text = "MIT"}
@@ -25,6 +25,7 @@ dependencies = [
2525
"ipywidgets",
2626
"opencv-python",
2727
"keypoint-moseq @ git+https://github.com/dattalab/keypoint-moseq/",
28+
"pdf2image",
2829
]
2930

3031
[project.optional-dependencies]

0 commit comments

Comments
 (0)