Skip to content

Commit 89dfe66

Browse files
committed
fix(kpms_readers): copy config.yml appropriately
1 parent bccc03d commit 89dfe66

File tree

1 file changed

+90
-33
lines changed

1 file changed

+90
-33
lines changed

element_moseq/readers/kpms_reader.py

Lines changed: 90 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Dict, Union
3+
from typing import Any, Dict, List, Union
44

55
import datajoint as dj
66
import yaml
@@ -82,9 +82,10 @@ def _check_config_validity(config: Dict[str, Any]) -> bool:
8282
f"ACTION REQUIRED: `posterior_bodyparts` contains {bp} "
8383
"which is not one of the options in `use_bodyparts`."
8484
)
85+
8586
if errors:
86-
for e in errors:
87-
print(e)
87+
for error in errors:
88+
logger.warning(error)
8889
return False
8990
return True
9091

@@ -122,19 +123,44 @@ def dj_generate_config(kpms_project_dir: str, **kwargs) -> tuple:
122123
else:
123124
if not Path(kpms_base_config_path).exists():
124125
raise FileNotFoundError(
125-
f"Missing KPMS base config at {kpms_base_config_path}. "
126-
f"Run keypoint_moseq's setup_project first. "
127-
f"Expected either config.yml or config.yaml in {kpms_project_dir}."
126+
f"Missing KPMS base config at {kpms_base_config_path}"
128127
)
129128
kpms_dj_config_dict = kpms_base_config_dict.copy()
130129

130+
# Update bodyparts if provided
131+
if "bodyparts" in kwargs:
132+
kpms_dj_config_dict["bodyparts"] = list(kwargs["bodyparts"])
133+
134+
if "use_bodyparts" in kwargs:
135+
use_bodyparts = list(kwargs["use_bodyparts"])
136+
kpms_dj_config_dict["use_bodyparts"] = use_bodyparts
137+
138+
# Filter anterior/posterior to be subsets of use_bodyparts
139+
if "anterior_bodyparts" in kwargs:
140+
anterior = [
141+
bp for bp in kwargs["anterior_bodyparts"] if bp in use_bodyparts
142+
]
143+
kwargs["anterior_bodyparts"] = anterior
144+
145+
if "posterior_bodyparts" in kwargs:
146+
posterior = [
147+
bp for bp in kwargs["posterior_bodyparts"] if bp in use_bodyparts
148+
]
149+
kwargs["posterior_bodyparts"] = posterior
150+
131151
kpms_dj_config_dict.update(kwargs)
132152

133-
if "skeleton" not in kpms_dj_config_dict or kpms_dj_config_dict["skeleton"] is None:
153+
if "skeleton" not in kpms_dj_config_dict:
134154
kpms_dj_config_dict["skeleton"] = []
135155

136156
with open(kpms_dj_config_path, "w") as f:
137-
yaml.safe_dump(kpms_dj_config_dict, f, sort_keys=False)
157+
yaml.safe_dump(
158+
kpms_dj_config_dict,
159+
f,
160+
sort_keys=False,
161+
default_flow_style=False,
162+
allow_unicode=True,
163+
)
138164

139165
return (
140166
kpms_dj_config_path,
@@ -173,13 +199,10 @@ def load_kpms_dj_config(
173199
"""
174200
import jax.numpy as jnp
175201

176-
# Validate input parameters
177202
if kpms_project_dir is None and config_path is None:
178-
raise ValueError("Either 'kpms_project_dir' or 'config_path' must be provided.")
203+
raise ValueError("Either 'kpms_project_dir' or 'config_path' must be provided")
179204
if kpms_project_dir is not None and config_path is not None:
180-
raise ValueError(
181-
"Cannot provide both 'kpms_project_dir' and 'config_path'. Choose one."
182-
)
205+
raise ValueError("Cannot provide both 'kpms_project_dir' and 'config_path'")
183206

184207
# Determine the config file path
185208
if config_path is not None:
@@ -188,9 +211,7 @@ def load_kpms_dj_config(
188211
kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir)
189212

190213
if not Path(kpms_dj_cfg_path).exists():
191-
raise FileNotFoundError(
192-
f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()."
193-
)
214+
raise FileNotFoundError(f"Missing DJ config at {kpms_dj_cfg_path}")
194215

195216
with open(kpms_dj_cfg_path, "r") as f:
196217
cfg_dict = yaml.safe_load(f) or {}
@@ -202,17 +223,28 @@ def load_kpms_dj_config(
202223
anterior = cfg_dict.get("anterior_bodyparts", [])
203224
posterior = cfg_dict.get("posterior_bodyparts", [])
204225
use_bps = cfg_dict.get("use_bodyparts", [])
205-
cfg_dict["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior])
206-
cfg_dict["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior])
207226

208-
if "skeleton" not in cfg_dict or cfg_dict["skeleton"] is None:
227+
valid_anterior = [bp for bp in anterior if bp in use_bps]
228+
valid_posterior = [bp for bp in posterior if bp in use_bps]
229+
230+
cfg_dict["anterior_idxs"] = jnp.array(
231+
[use_bps.index(bp) for bp in valid_anterior]
232+
)
233+
cfg_dict["posterior_idxs"] = jnp.array(
234+
[use_bps.index(bp) for bp in valid_posterior]
235+
)
236+
237+
if "skeleton" not in cfg_dict:
209238
cfg_dict["skeleton"] = []
210239

211240
return cfg_dict
212241

213242

214243
def update_kpms_dj_config(
215-
kpms_project_dir: str = None, config_dict: Dict[str, Any] = None, **kwargs
244+
kpms_project_dir: str = None,
245+
config_dict: Dict[str, Any] = None,
246+
config_path: str = None,
247+
**kwargs,
216248
) -> Dict[str, Any]:
217249
"""
218250
Update kpms_dj_config with provided kwargs.
@@ -233,32 +265,57 @@ def update_kpms_dj_config(
233265
If kpms_project_dir is provided, loads the config from file, updates it, saves it back, and returns it.
234266
If config_dict is provided, updates it directly and returns it (no file I/O).
235267
"""
236-
# Validate input parameters
268+
237269
if kpms_project_dir is None and config_dict is None:
238-
raise ValueError("Either 'kpms_project_dir' or 'config_dict' must be provided.")
239-
if kpms_project_dir is not None and config_dict is not None:
240-
raise ValueError(
241-
"Cannot provide both 'kpms_project_dir' and 'config_dict'. Choose one."
242-
)
270+
raise ValueError("Either 'kpms_project_dir' or 'config_dict' must be provided")
243271

244-
# Load from file if kpms_project_dir is provided
245272
if kpms_project_dir is not None:
246273
kpms_dj_cfg_path = _kpms_dj_config_path(kpms_project_dir)
247274
if not Path(kpms_dj_cfg_path).exists():
248-
raise FileNotFoundError(
249-
f"Missing DJ config at {kpms_dj_cfg_path}. Create it with dj_generate_config()."
250-
)
275+
raise FileNotFoundError(f"Missing DJ config at {kpms_dj_cfg_path}")
251276

252277
with open(kpms_dj_cfg_path, "r") as f:
253278
cfg_dict = yaml.safe_load(f) or {}
254279

280+
if "bodyparts" in kwargs:
281+
cfg_dict["bodyparts"] = list(kwargs.get("bodyparts"))
282+
283+
if "use_bodyparts" in kwargs:
284+
use_bodyparts = list(kwargs.get("use_bodyparts"))
285+
cfg_dict["use_bodyparts"] = use_bodyparts
286+
# NOTE: skeleton is NOT modified - it remains from the base config
287+
255288
cfg_dict.update(kwargs)
256289

257290
with open(kpms_dj_cfg_path, "w") as f:
258-
yaml.safe_dump(cfg_dict, f, sort_keys=False)
291+
yaml.safe_dump(
292+
cfg_dict,
293+
f,
294+
sort_keys=False,
295+
default_flow_style=False,
296+
allow_unicode=True,
297+
)
259298
else:
260-
# Update the provided dict directly (no file I/O)
261-
cfg_dict = config_dict.copy() # Make a copy to avoid mutating the input
299+
cfg_dict = config_dict.copy()
300+
301+
if "bodyparts" in kwargs:
302+
cfg_dict["bodyparts"] = list(kwargs.get("bodyparts"))
303+
304+
if "use_bodyparts" in kwargs:
305+
use_bodyparts = list(kwargs.get("use_bodyparts"))
306+
cfg_dict["use_bodyparts"] = use_bodyparts
307+
# NOTE: skeleton is NOT modified - it remains from the base config
308+
262309
cfg_dict.update(kwargs)
263310

311+
if config_path is not None:
312+
with open(config_path, "w") as f:
313+
yaml.safe_dump(
314+
cfg_dict,
315+
f,
316+
sort_keys=False,
317+
default_flow_style=False,
318+
allow_unicode=True,
319+
)
320+
264321
return cfg_dict

0 commit comments

Comments
 (0)