11import os
22from pathlib import Path
3- from typing import Any , Dict , Union
3+ from typing import Any , Dict , List , Union
44
55import datajoint as dj
66import yaml
@@ -82,9 +82,10 @@ def _check_config_validity(config: Dict[str, Any]) -> bool:
8282 f"ACTION REQUIRED: `posterior_bodyparts` contains { bp } "
8383 "which is not one of the options in `use_bodyparts`."
8484 )
85+
8586 if errors :
86- for e in errors :
87- print ( e )
87+ for error in errors :
88+ logger . warning ( error )
8889 return False
8990 return True
9091
@@ -122,19 +123,44 @@ def dj_generate_config(kpms_project_dir: str, **kwargs) -> tuple:
122123 else :
123124 if not Path (kpms_base_config_path ).exists ():
124125 raise FileNotFoundError (
125- f"Missing KPMS base config at { kpms_base_config_path } . "
126- f"Run keypoint_moseq's setup_project first. "
127- f"Expected either config.yml or config.yaml in { kpms_project_dir } ."
126+ f"Missing KPMS base config at { kpms_base_config_path } "
128127 )
129128 kpms_dj_config_dict = kpms_base_config_dict .copy ()
130129
130+ # Update bodyparts if provided
131+ if "bodyparts" in kwargs :
132+ kpms_dj_config_dict ["bodyparts" ] = list (kwargs ["bodyparts" ])
133+
134+ if "use_bodyparts" in kwargs :
135+ use_bodyparts = list (kwargs ["use_bodyparts" ])
136+ kpms_dj_config_dict ["use_bodyparts" ] = use_bodyparts
137+
138+ # Filter anterior/posterior to be subsets of use_bodyparts
139+ if "anterior_bodyparts" in kwargs :
140+ anterior = [
141+ bp for bp in kwargs ["anterior_bodyparts" ] if bp in use_bodyparts
142+ ]
143+ kwargs ["anterior_bodyparts" ] = anterior
144+
145+ if "posterior_bodyparts" in kwargs :
146+ posterior = [
147+ bp for bp in kwargs ["posterior_bodyparts" ] if bp in use_bodyparts
148+ ]
149+ kwargs ["posterior_bodyparts" ] = posterior
150+
131151 kpms_dj_config_dict .update (kwargs )
132152
133- if "skeleton" not in kpms_dj_config_dict or kpms_dj_config_dict [ "skeleton" ] is None :
153+ if "skeleton" not in kpms_dj_config_dict :
134154 kpms_dj_config_dict ["skeleton" ] = []
135155
136156 with open (kpms_dj_config_path , "w" ) as f :
137- yaml .safe_dump (kpms_dj_config_dict , f , sort_keys = False )
157+ yaml .safe_dump (
158+ kpms_dj_config_dict ,
159+ f ,
160+ sort_keys = False ,
161+ default_flow_style = False ,
162+ allow_unicode = True ,
163+ )
138164
139165 return (
140166 kpms_dj_config_path ,
@@ -173,13 +199,10 @@ def load_kpms_dj_config(
173199 """
174200 import jax .numpy as jnp
175201
176- # Validate input parameters
177202 if kpms_project_dir is None and config_path is None :
178- raise ValueError ("Either 'kpms_project_dir' or 'config_path' must be provided. " )
203+ raise ValueError ("Either 'kpms_project_dir' or 'config_path' must be provided" )
179204 if kpms_project_dir is not None and config_path is not None :
180- raise ValueError (
181- "Cannot provide both 'kpms_project_dir' and 'config_path'. Choose one."
182- )
205+ raise ValueError ("Cannot provide both 'kpms_project_dir' and 'config_path'" )
183206
184207 # Determine the config file path
185208 if config_path is not None :
@@ -188,9 +211,7 @@ def load_kpms_dj_config(
188211 kpms_dj_cfg_path = _kpms_dj_config_path (kpms_project_dir )
189212
190213 if not Path (kpms_dj_cfg_path ).exists ():
191- raise FileNotFoundError (
192- f"Missing DJ config at { kpms_dj_cfg_path } . Create it with dj_generate_config()."
193- )
214+ raise FileNotFoundError (f"Missing DJ config at { kpms_dj_cfg_path } " )
194215
195216 with open (kpms_dj_cfg_path , "r" ) as f :
196217 cfg_dict = yaml .safe_load (f ) or {}
@@ -202,17 +223,28 @@ def load_kpms_dj_config(
202223 anterior = cfg_dict .get ("anterior_bodyparts" , [])
203224 posterior = cfg_dict .get ("posterior_bodyparts" , [])
204225 use_bps = cfg_dict .get ("use_bodyparts" , [])
205- cfg_dict ["anterior_idxs" ] = jnp .array ([use_bps .index (bp ) for bp in anterior ])
206- cfg_dict ["posterior_idxs" ] = jnp .array ([use_bps .index (bp ) for bp in posterior ])
207226
208- if "skeleton" not in cfg_dict or cfg_dict ["skeleton" ] is None :
227+ valid_anterior = [bp for bp in anterior if bp in use_bps ]
228+ valid_posterior = [bp for bp in posterior if bp in use_bps ]
229+
230+ cfg_dict ["anterior_idxs" ] = jnp .array (
231+ [use_bps .index (bp ) for bp in valid_anterior ]
232+ )
233+ cfg_dict ["posterior_idxs" ] = jnp .array (
234+ [use_bps .index (bp ) for bp in valid_posterior ]
235+ )
236+
237+ if "skeleton" not in cfg_dict :
209238 cfg_dict ["skeleton" ] = []
210239
211240 return cfg_dict
212241
213242
214243def update_kpms_dj_config (
215- kpms_project_dir : str = None , config_dict : Dict [str , Any ] = None , ** kwargs
244+ kpms_project_dir : str = None ,
245+ config_dict : Dict [str , Any ] = None ,
246+ config_path : str = None ,
247+ ** kwargs ,
216248) -> Dict [str , Any ]:
217249 """
218250 Update kpms_dj_config with provided kwargs.
@@ -233,32 +265,57 @@ def update_kpms_dj_config(
233265 If kpms_project_dir is provided, loads the config from file, updates it, saves it back, and returns it.
234266 If config_dict is provided, updates it directly and returns it (no file I/O).
235267 """
236- # Validate input parameters
268+
237269 if kpms_project_dir is None and config_dict is None :
238- raise ValueError ("Either 'kpms_project_dir' or 'config_dict' must be provided." )
239- if kpms_project_dir is not None and config_dict is not None :
240- raise ValueError (
241- "Cannot provide both 'kpms_project_dir' and 'config_dict'. Choose one."
242- )
270+ raise ValueError ("Either 'kpms_project_dir' or 'config_dict' must be provided" )
243271
244- # Load from file if kpms_project_dir is provided
245272 if kpms_project_dir is not None :
246273 kpms_dj_cfg_path = _kpms_dj_config_path (kpms_project_dir )
247274 if not Path (kpms_dj_cfg_path ).exists ():
248- raise FileNotFoundError (
249- f"Missing DJ config at { kpms_dj_cfg_path } . Create it with dj_generate_config()."
250- )
275+ raise FileNotFoundError (f"Missing DJ config at { kpms_dj_cfg_path } " )
251276
252277 with open (kpms_dj_cfg_path , "r" ) as f :
253278 cfg_dict = yaml .safe_load (f ) or {}
254279
280+ if "bodyparts" in kwargs :
281+ cfg_dict ["bodyparts" ] = list (kwargs .get ("bodyparts" ))
282+
283+ if "use_bodyparts" in kwargs :
284+ use_bodyparts = list (kwargs .get ("use_bodyparts" ))
285+ cfg_dict ["use_bodyparts" ] = use_bodyparts
286+ # NOTE: skeleton is NOT modified - it remains from the base config
287+
255288 cfg_dict .update (kwargs )
256289
257290 with open (kpms_dj_cfg_path , "w" ) as f :
258- yaml .safe_dump (cfg_dict , f , sort_keys = False )
291+ yaml .safe_dump (
292+ cfg_dict ,
293+ f ,
294+ sort_keys = False ,
295+ default_flow_style = False ,
296+ allow_unicode = True ,
297+ )
259298 else :
260- # Update the provided dict directly (no file I/O)
261- cfg_dict = config_dict .copy () # Make a copy to avoid mutating the input
299+ cfg_dict = config_dict .copy ()
300+
301+ if "bodyparts" in kwargs :
302+ cfg_dict ["bodyparts" ] = list (kwargs .get ("bodyparts" ))
303+
304+ if "use_bodyparts" in kwargs :
305+ use_bodyparts = list (kwargs .get ("use_bodyparts" ))
306+ cfg_dict ["use_bodyparts" ] = use_bodyparts
307+ # NOTE: skeleton is NOT modified - it remains from the base config
308+
262309 cfg_dict .update (kwargs )
263310
311+ if config_path is not None :
312+ with open (config_path , "w" ) as f :
313+ yaml .safe_dump (
314+ cfg_dict ,
315+ f ,
316+ sort_keys = False ,
317+ default_flow_style = False ,
318+ allow_unicode = True ,
319+ )
320+
264321 return cfg_dict
0 commit comments