Skip to content

Commit c8410cc

Browse files
committed
🐛 Fix Configuration.__contains__
1 parent 72dc0f2 commit c8410cc

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

CPAC/utils/configuration/configuration.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""C-PAC Configuration class and related functions."""
1818

19+
from collections.abc import Iterable
1920
import os
2021
import re
2122
from typing import Any, cast, Literal, Optional, overload
@@ -29,6 +30,7 @@
2930
from .diff import dct_diff
3031

3132
CONFIG_KEY_TYPE = str | list[str]
33+
_DICT = dict
3234
SPECIAL_REPLACEMENT_STRINGS = {r"${resolution_for_anat}", r"${func_resolution}"}
3335

3436

@@ -187,27 +189,42 @@ def __init__(
187189
os.environ["CPAC_WORKDIR"] = self["pipeline_setup", "working_directory", "path"]
188190

189191
def __str__(self):
192+
"""Return string representation of a Configuration instance."""
190193
return f"C-PAC Configuration ('{self['pipeline_setup', 'pipeline_name']}')"
191194

192195
def __repr__(self):
193196
"""Show Configuration as a dict when accessed directly."""
194197
return str(self.dict())
195198

199+
def __contains__(self, item: str | list[Any]) -> bool:
200+
"""Check if an item is in the Configuration."""
201+
if isinstance(item, str):
202+
return item in self.keys()
203+
try:
204+
self.get_nested(self, item)
205+
return True
206+
except KeyError:
207+
return False
208+
196209
def __copy__(self):
197210
newone = type(self)({})
198211
newone.__dict__.update(self.__dict__)
199212
newone._update_attr()
200213
return newone
201214

202-
def __getitem__(self, key):
215+
def __getitem__(self, key: Iterable) -> Any:
216+
"""Get an item from a Configuration."""
217+
self._check_keys(key)
203218
if isinstance(key, str):
204219
return getattr(self, key)
205220
if isinstance(key, (list, tuple)):
206221
return self.get_nested(self, key)
207222
self.key_type_error(key)
208223
return None
209224

210-
def __setitem__(self, key, value):
225+
def __setitem__(self, key: Iterable, value: Any) -> None:
226+
"""Set an item in a Configuration."""
227+
self._check_keys(key)
211228
if isinstance(key, str):
212229
setattr(self, key, value)
213230
elif isinstance(key, (list, tuple)):
@@ -432,24 +449,50 @@ def update(self, key, val=ConfigurationDictUpdateConflation()):
432449
raise val
433450
setattr(self, key, val)
434451

435-
def get_nested(self, _d, keys):
452+
@staticmethod
453+
def _check_keys(keys: Iterable) -> None:
454+
"""Check that keys are iterable and at least 1 key is provided."""
455+
if not keys:
456+
if isinstance(keys, Iterable):
457+
error = KeyError
458+
msg = "No keys provided to `set_nested`."
459+
else:
460+
error = TypeError
461+
msg = f"`set_nested` keys must be iterable, got {type(keys)}."
462+
raise error(msg)
463+
464+
def get_nested(self, _d: "Configuration | _DICT", keys: Iterable) -> Any:
465+
"""Get a value from a Configuration dictionary given a nested key."""
466+
self._check_keys(keys)
436467
if _d is None:
437468
_d = {}
438469
if isinstance(keys, str):
439470
return _d[keys]
440471
if isinstance(keys, (list, tuple)):
441472
if len(keys) > 1:
442473
return self.get_nested(_d[keys[0]], keys[1:])
474+
assert len(keys) == 1
443475
return _d[keys[0]]
444476
return _d
445477

446-
def set_nested(self, d, keys, value): # pylint: disable=invalid-name
478+
@overload
479+
def set_nested(
480+
self, d: "Configuration", keys: Iterable, value: Any
481+
) -> "Configuration": ...
482+
@overload
483+
def set_nested(self, d: _DICT, keys: Iterable, value: Any) -> _DICT: ...
484+
def set_nested(
485+
self, d: "Configuration | _DICT", keys: Iterable, value: Any
486+
) -> "Configuration | _DICT":
487+
"""Set a nested key in a Configuration dictionary."""
488+
self._check_keys(keys)
447489
if isinstance(keys, str):
448490
d[keys] = value
449491
elif isinstance(keys, (list, tuple)):
450492
if len(keys) > 1:
451493
d[keys[0]] = self.set_nested(d[keys[0]], keys[1:], value)
452494
else:
495+
assert len(keys) == 1
453496
d[keys[0]] = value
454497
return d
455498

CPAC/utils/monitoring/custom_logging.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,15 @@ def init_loggers(
329329

330330
set_up_logger(
331331
f"{subject_id}_expectedOutputs",
332-
filename=f'{bidsier_prefix(cpac_config["subject_id"])}_' 'expectedOutputs.yml',
332+
filename=f"{bidsier_prefix(cpac_config['subject_id'])}_expectedOutputs.yml",
333333
level="info",
334334
log_dir=log_dir,
335335
mock=mock,
336336
overwrite_existing=( # don't overwrite if we have a longitudinal template
337337
longitudinal or not cpac_config["longitudinal_template_generation", "run"]
338338
),
339339
)
340+
340341
if cpac_config["pipeline_setup", "Debugging", "verbose"]:
341342
set_up_logger("CPAC.engine", level="debug", log_dir=log_dir, mock=True)
342343

@@ -361,5 +362,7 @@ def init_loggers(
361362
},
362363
}
363364
)
365+
364366
nipype_config.enable_resource_monitor()
367+
365368
nipype_logging.update_logging(nipype_config)

0 commit comments

Comments
 (0)