|
7 | 7 | The CLI supports configuration via both command line arguments and configuration files (.yaml or .json). |
8 | 8 | """ |
9 | 9 |
|
| 10 | +import ast |
| 11 | +import importlib |
10 | 12 | import logging |
| 13 | +import sys |
11 | 14 | from collections.abc import Callable, Sequence |
12 | 15 | from functools import partial |
13 | 16 | from pathlib import Path |
@@ -73,7 +76,9 @@ def __init__(self, args: Sequence[str] | None = None, run: bool = True) -> None: |
73 | 76 | self.parser = self.init_parser() |
74 | 77 | self.subcommand_parsers: dict[str, ArgumentParser] = {} |
75 | 78 | self.subcommand_method_arguments: dict[str, list[str]] = {} |
| 79 | + self._pre_processor_kwargs: dict[str, Any] = {} |
76 | 80 | self.add_subcommands() |
| 81 | + args = self._extract_pre_processor_args(args) |
77 | 82 | self.config = self.parser.parse_args(args=args) |
78 | 83 | self.subcommand = self.config["subcommand"] |
79 | 84 | if _LIGHTNING_AVAILABLE: |
@@ -299,6 +304,15 @@ def instantiate_classes(self) -> None: |
299 | 304 | But for subcommands we do not want to instantiate any trainer specific classes such as datamodule, model, etc |
300 | 305 | This is because the subcommand is responsible for instantiating and executing code based on the passed config |
301 | 306 | """ |
| 307 | + patch_info = self._patch_configure_pre_processor() if self._pre_processor_kwargs else None |
| 308 | + try: |
| 309 | + self._instantiate_classes() |
| 310 | + finally: |
| 311 | + if patch_info is not None: |
| 312 | + self._unpatch_configure_pre_processor(patch_info) |
| 313 | + |
| 314 | + def _instantiate_classes(self) -> None: |
| 315 | + """Internal instantiation logic, called within the pre-processor patch context.""" |
302 | 316 | if self.config["subcommand"] in {*self.subcommands(), "predict"}: # trainer commands |
303 | 317 | # since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object. |
304 | 318 | # the minor change here is that engine is instantiated instead of trainer |
@@ -470,6 +484,106 @@ def _parser(self, subcommand: str | None) -> ArgumentParser: |
470 | 484 | # return the subcommand parser for the subcommand passed |
471 | 485 | return self.subcommand_parsers[subcommand] |
472 | 486 |
|
| 487 | + def _extract_pre_processor_args(self, args: Sequence[str] | None) -> list[str]: |
| 488 | + """Extract ``--model.pre_processor.*`` arguments before jsonargparse parsing. |
| 489 | +
|
| 490 | + jsonargparse rejects ``--model.pre_processor.image_size`` because |
| 491 | + ``pre_processor`` is typed as ``nn.Module | bool``. This method strips |
| 492 | + those arguments from the raw CLI list and stores them in |
| 493 | + ``self._pre_processor_kwargs`` so they can be forwarded to |
| 494 | + ``configure_pre_processor`` later. |
| 495 | +
|
| 496 | + Args: |
| 497 | + args: Raw CLI arguments, or ``None`` to read from ``sys.argv``. |
| 498 | +
|
| 499 | + Returns: |
| 500 | + list[str]: CLI arguments with ``--model.pre_processor.*`` entries removed. |
| 501 | + """ |
| 502 | + if args is None: |
| 503 | + args = sys.argv[1:] |
| 504 | + args = list(args) |
| 505 | + |
| 506 | + prefix = "--model.pre_processor." |
| 507 | + indices_to_remove: list[int] = [] |
| 508 | + for i, arg in enumerate(args): |
| 509 | + if not arg.startswith(prefix): |
| 510 | + continue |
| 511 | + if "=" in arg: |
| 512 | + # --model.pre_processor.image_size=512 |
| 513 | + key, val = arg[len(prefix) :].split("=", 1) |
| 514 | + self._pre_processor_kwargs[key] = self._parse_cli_value(val) |
| 515 | + indices_to_remove.append(i) |
| 516 | + elif i + 1 < len(args): |
| 517 | + # --model.pre_processor.image_size 512 |
| 518 | + key = arg[len(prefix) :] |
| 519 | + self._pre_processor_kwargs[key] = self._parse_cli_value(args[i + 1]) |
| 520 | + indices_to_remove.extend([i, i + 1]) |
| 521 | + |
| 522 | + # Convert int image_size to (h, w) tuple for consistency |
| 523 | + if "image_size" in self._pre_processor_kwargs: |
| 524 | + val = self._pre_processor_kwargs["image_size"] |
| 525 | + if isinstance(val, int): |
| 526 | + self._pre_processor_kwargs["image_size"] = (val, val) |
| 527 | + elif isinstance(val, list): |
| 528 | + self._pre_processor_kwargs["image_size"] = tuple(val) |
| 529 | + |
| 530 | + for idx in reversed(indices_to_remove): |
| 531 | + args.pop(idx) |
| 532 | + return args |
| 533 | + |
| 534 | + @staticmethod |
| 535 | + def _parse_cli_value(val: str) -> Any: # noqa: ANN401 |
| 536 | + """Parse a CLI string value into a Python object using literal evaluation.""" |
| 537 | + try: |
| 538 | + return ast.literal_eval(val) |
| 539 | + except (ValueError, SyntaxError): |
| 540 | + return val |
| 541 | + |
| 542 | + def _patch_configure_pre_processor(self) -> tuple[type, bool, Any] | None: |
| 543 | + """Patch the model class's ``configure_pre_processor`` with extracted CLI kwargs. |
| 544 | +
|
| 545 | + Uses ``functools.partial`` to bind the extracted pre-processor arguments |
| 546 | + (e.g. ``image_size``) to the model's ``configure_pre_processor`` method |
| 547 | + before the model is instantiated. This follows the same pattern as |
| 548 | + Lightning CLI's ``configure_optimizers`` override. |
| 549 | +
|
| 550 | + Returns: |
| 551 | + Tuple of (model_class, had_own_method, original_descriptor) for |
| 552 | + restoration, or ``None`` if patching was not applicable. |
| 553 | + """ |
| 554 | + subcommand = self.config["subcommand"] |
| 555 | + model_cfg = self.config.get(subcommand, self.config).get("model") |
| 556 | + if model_cfg is None: |
| 557 | + return None |
| 558 | + |
| 559 | + class_path = model_cfg.get("class_path") if hasattr(model_cfg, "get") else getattr(model_cfg, "class_path", None) |
| 560 | + if class_path is None: |
| 561 | + return None |
| 562 | + |
| 563 | + # Resolve the model class |
| 564 | + module_path, class_name = class_path.rsplit(".", 1) |
| 565 | + model_class = getattr(importlib.import_module(module_path), class_name) |
| 566 | + |
| 567 | + # Save the original descriptor so we can restore it after instantiation |
| 568 | + has_own = "configure_pre_processor" in model_class.__dict__ |
| 569 | + original_descriptor = model_class.__dict__.get("configure_pre_processor") |
| 570 | + |
| 571 | + # Get the resolved callable (handles both @staticmethod and @classmethod) |
| 572 | + callable_method = model_class.configure_pre_processor |
| 573 | + patched = partial(callable_method, **self._pre_processor_kwargs) |
| 574 | + model_class.configure_pre_processor = staticmethod(patched) |
| 575 | + |
| 576 | + return (model_class, has_own, original_descriptor) |
| 577 | + |
| 578 | + @staticmethod |
| 579 | + def _unpatch_configure_pre_processor(patch_info: tuple[type, bool, Any]) -> None: |
| 580 | + """Restore the original ``configure_pre_processor`` after model instantiation.""" |
| 581 | + model_class, has_own, original_descriptor = patch_info |
| 582 | + if has_own: |
| 583 | + model_class.configure_pre_processor = original_descriptor |
| 584 | + else: |
| 585 | + delattr(model_class, "configure_pre_processor") |
| 586 | + |
473 | 587 | def _configure_optimizers_method_to_model(self) -> None: |
474 | 588 | from lightning.pytorch.cli import LightningCLI, instantiate_class |
475 | 589 |
|
|
0 commit comments