Skip to content

Commit c50936d

Browse files
authored
chore: jax<0.7.0, other minor code cleanup
1 parent 20dd71b commit c50936d

File tree

6 files changed

+15
-8
lines changed

6 files changed

+15
-8
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
## [1.0.2] - 2025-10-07
7+
+ Update - `keypoint-moseq` as extra dependency
8+
+ Fix - Version pin `jax<0.7.0`
9+
610
## [1.0.1] - 2025-09-23
711
+ Feat - Add support to generate PNG version of fitting progress plots in `PreFit`, `FullFit`, and `moseq_report` schema
812
+ Fix - Update path handling to use `Path` objects and `dj.logger`

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ DataJoint Elements collectively standardize and automate data collection and ana
2525
+ Install with `pip`:
2626

2727
```bash
28-
pip install -e .
28+
pip install -e .[keypoint-moseq]
2929
```
3030

3131
+ [Interactive tutorial on GitHub Codespaces](https://github.com/datajoint/element-moseq#interactive-tutorial)

element_moseq/plotting/viz_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
import datajoint as dj
1111
import matplotlib.pyplot as plt
1212
import numpy as np
13-
from jax_moseq.models.keypoint_slds import center_embedding
14-
from keypoint_moseq.util import get_distance_to_medoid, get_edges, plot_keypoint_traces
15-
from keypoint_moseq.viz import plot_pcs_3D
1613

1714
logger = dj.logger
1815

@@ -132,6 +129,7 @@ def plot_medoid_distance_outliers(
132129
None
133130
The plot is saved to 'QA/plots/keypoint_distance_outliers/{recording_name}.png'.
134131
"""
132+
from keypoint_moseq.util import get_distance_to_medoid, plot_keypoint_traces
135133

136134
plot_path = os.path.join(
137135
project_dir,
@@ -241,6 +239,10 @@ def plot_pcs(
241239
interactive : bool, default=True
242240
For 3D data, whether to generate an interactive 3D plot.
243241
"""
242+
from jax_moseq.models.keypoint_slds import center_embedding
243+
from keypoint_moseq.util import get_edges
244+
from keypoint_moseq.viz import plot_pcs_3D
245+
244246
k = len(use_bodyparts)
245247
d = len(pca.mean_) // (k - 1)
246248

element_moseq/readers/kpms_reader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Dict, Union
44

55
import datajoint as dj
6-
import jax.numpy as jnp
76
import yaml
87

98
logger = dj.logger
@@ -103,6 +102,8 @@ def load_kpms_dj_config(
103102
- build_indexes -> adds jax arrays 'anterior_idxs' and 'posterior_idxs'
104103
indexing into 'use_bodyparts' by order.
105104
"""
105+
import jax.numpy as jnp
106+
106107
dj_cfg_path = _dj_config_path(project_dir)
107108
if not Path(dj_cfg_path).exists():
108109
raise FileNotFoundError(

element_moseq/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package metadata
33
"""
44

5-
__version__ = "1.0.1"
5+
__version__ = "1.0.2"

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "element-moseq"
7-
version = "1.0.1"
7+
version = "1.0.2"
88
description = "Keypoint-MoSeq DataJoint Element"
99
readme = "README.md"
1010
license = {text = "MIT"}
@@ -24,7 +24,6 @@ dependencies = [
2424
"ipykernel>=6.0.1",
2525
"ipywidgets",
2626
"opencv-python",
27-
"keypoint-moseq @ git+https://github.com/dattalab/keypoint-moseq/",
2827
"pdf2image",
2928
]
3029

@@ -35,6 +34,7 @@ elements = [
3534
"element-interface @ git+https://github.com/datajoint/element-interface.git",
3635
"element-animal @ git+https://github.com/datajoint/element-animal.git",
3736
]
37+
keypoint-moseq = ["jax<0.7.0", "keypoint-moseq @ git+https://github.com/dattalab/keypoint-moseq.git"]
3838
tests = [
3939
"pytest",
4040
"pytest-cov",

0 commit comments

Comments
 (0)