|
16 | 16 | # License along with C-PAC. If not, see <https://www.gnu.org/licenses/>. |
17 | 17 | """C-PAC Configuration class and related functions.""" |
18 | 18 |
|
| 19 | +from collections.abc import Iterable |
19 | 20 | import os |
20 | 21 | import re |
21 | 22 | from typing import Any, cast, Literal, Optional, overload |
|
29 | 30 | from .diff import dct_diff |
30 | 31 |
|
31 | 32 | CONFIG_KEY_TYPE = str | list[str] |
| 33 | +_DICT = dict |
32 | 34 | SPECIAL_REPLACEMENT_STRINGS = {r"${resolution_for_anat}", r"${func_resolution}"} |
33 | 35 |
|
34 | 36 |
|
@@ -187,27 +189,42 @@ def __init__( |
187 | 189 | os.environ["CPAC_WORKDIR"] = self["pipeline_setup", "working_directory", "path"] |
188 | 190 |
|
189 | 191 | def __str__(self): |
| 192 | + """Return string representation of a Configuration instance.""" |
190 | 193 | return f"C-PAC Configuration ('{self['pipeline_setup', 'pipeline_name']}')" |
191 | 194 |
|
192 | 195 | def __repr__(self): |
193 | 196 | """Show Configuration as a dict when accessed directly.""" |
194 | 197 | return str(self.dict()) |
195 | 198 |
|
| 199 | + def __contains__(self, item: str | list[Any]) -> bool: |
| 200 | + """Check if an item is in the Configuration.""" |
| 201 | + if isinstance(item, str): |
| 202 | + return item in self.keys() |
| 203 | + try: |
| 204 | + self.get_nested(self, item) |
| 205 | + return True |
| 206 | + except KeyError: |
| 207 | + return False |
| 208 | + |
196 | 209 | def __copy__(self): |
197 | 210 | newone = type(self)({}) |
198 | 211 | newone.__dict__.update(self.__dict__) |
199 | 212 | newone._update_attr() |
200 | 213 | return newone |
201 | 214 |
|
202 | | - def __getitem__(self, key): |
| 215 | + def __getitem__(self, key: Iterable) -> Any: |
| 216 | + """Get an item from a Configuration.""" |
| 217 | + self._check_keys(key) |
203 | 218 | if isinstance(key, str): |
204 | 219 | return getattr(self, key) |
205 | 220 | if isinstance(key, (list, tuple)): |
206 | 221 | return self.get_nested(self, key) |
207 | 222 | self.key_type_error(key) |
208 | 223 | return None |
209 | 224 |
|
210 | | - def __setitem__(self, key, value): |
| 225 | + def __setitem__(self, key: Iterable, value: Any) -> None: |
| 226 | + """Set an item in a Configuration.""" |
| 227 | + self._check_keys(key) |
211 | 228 | if isinstance(key, str): |
212 | 229 | setattr(self, key, value) |
213 | 230 | elif isinstance(key, (list, tuple)): |
@@ -432,24 +449,50 @@ def update(self, key, val=ConfigurationDictUpdateConflation()): |
432 | 449 | raise val |
433 | 450 | setattr(self, key, val) |
434 | 451 |
|
435 | | - def get_nested(self, _d, keys): |
| 452 | + @staticmethod |
| 453 | + def _check_keys(keys: Iterable) -> None: |
| 454 | + """Check that keys are iterable and at least 1 key is provided.""" |
| 455 | + if not keys: |
| 456 | + if isinstance(keys, Iterable): |
| 457 | + error = KeyError |
| 458 | + msg = "No keys provided to `set_nested`." |
| 459 | + else: |
| 460 | + error = TypeError |
| 461 | + msg = f"`set_nested` keys must be iterable, got {type(keys)}." |
| 462 | + raise error(msg) |
| 463 | + |
| 464 | + def get_nested(self, _d: "Configuration | _DICT", keys: Iterable) -> Any: |
| 465 | + """Get a value from a Configuration dictionary given a nested key.""" |
| 466 | + self._check_keys(keys) |
436 | 467 | if _d is None: |
437 | 468 | _d = {} |
438 | 469 | if isinstance(keys, str): |
439 | 470 | return _d[keys] |
440 | 471 | if isinstance(keys, (list, tuple)): |
441 | 472 | if len(keys) > 1: |
442 | 473 | return self.get_nested(_d[keys[0]], keys[1:]) |
| 474 | + assert len(keys) == 1 |
443 | 475 | return _d[keys[0]] |
444 | 476 | return _d |
445 | 477 |
|
446 | | - def set_nested(self, d, keys, value): # pylint: disable=invalid-name |
| 478 | + @overload |
| 479 | + def set_nested( |
| 480 | + self, d: "Configuration", keys: Iterable, value: Any |
| 481 | + ) -> "Configuration": ... |
| 482 | + @overload |
| 483 | + def set_nested(self, d: _DICT, keys: Iterable, value: Any) -> _DICT: ... |
| 484 | + def set_nested( |
| 485 | + self, d: "Configuration | _DICT", keys: Iterable, value: Any |
| 486 | + ) -> "Configuration | _DICT": |
| 487 | + """Set a nested key in a Configuration dictionary.""" |
| 488 | + self._check_keys(keys) |
447 | 489 | if isinstance(keys, str): |
448 | 490 | d[keys] = value |
449 | 491 | elif isinstance(keys, (list, tuple)): |
450 | 492 | if len(keys) > 1: |
451 | 493 | d[keys[0]] = self.set_nested(d[keys[0]], keys[1:], value) |
452 | 494 | else: |
| 495 | + assert len(keys) == 1 |
453 | 496 | d[keys[0]] = value |
454 | 497 | return d |
455 | 498 |
|
|
0 commit comments