Skip to content

Commit 6827a6f

Browse files
committed
refactor: redesign image_size handling via pre_processor as per review
Signed-off-by: Abhay Kumar Das <dasabhay.jsr@gmail.com>
1 parent 37c51c4 commit 6827a6f

1 file changed

Lines changed: 114 additions & 0 deletions

File tree

src/anomalib/cli/cli.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
The CLI supports configuration via both command line arguments and configuration files (.yaml or .json).
88
"""
99

10+
import ast
11+
import importlib
1012
import logging
13+
import sys
1114
from collections.abc import Callable, Sequence
1215
from functools import partial
1316
from pathlib import Path
@@ -73,7 +76,9 @@ def __init__(self, args: Sequence[str] | None = None, run: bool = True) -> None:
7376
self.parser = self.init_parser()
7477
self.subcommand_parsers: dict[str, ArgumentParser] = {}
7578
self.subcommand_method_arguments: dict[str, list[str]] = {}
79+
self._pre_processor_kwargs: dict[str, Any] = {}
7680
self.add_subcommands()
81+
args = self._extract_pre_processor_args(args)
7782
self.config = self.parser.parse_args(args=args)
7883
self.subcommand = self.config["subcommand"]
7984
if _LIGHTNING_AVAILABLE:
@@ -299,6 +304,15 @@ def instantiate_classes(self) -> None:
299304
But for subcommands we do not want to instantiate any trainer specific classes such as datamodule, model, etc
300305
This is because the subcommand is responsible for instantiating and executing code based on the passed config
301306
"""
307+
patch_info = self._patch_configure_pre_processor() if self._pre_processor_kwargs else None
308+
try:
309+
self._instantiate_classes()
310+
finally:
311+
if patch_info is not None:
312+
self._unpatch_configure_pre_processor(patch_info)
313+
314+
def _instantiate_classes(self) -> None:
315+
"""Internal instantiation logic, called within the pre-processor patch context."""
302316
if self.config["subcommand"] in {*self.subcommands(), "predict"}: # trainer commands
303317
# since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object.
304318
# the minor change here is that engine is instantiated instead of trainer
@@ -470,6 +484,106 @@ def _parser(self, subcommand: str | None) -> ArgumentParser:
470484
# return the subcommand parser for the subcommand passed
471485
return self.subcommand_parsers[subcommand]
472486

487+
def _extract_pre_processor_args(self, args: Sequence[str] | None) -> list[str]:
488+
"""Extract ``--model.pre_processor.*`` arguments before jsonargparse parsing.
489+
490+
jsonargparse rejects ``--model.pre_processor.image_size`` because
491+
``pre_processor`` is typed as ``nn.Module | bool``. This method strips
492+
those arguments from the raw CLI list and stores them in
493+
``self._pre_processor_kwargs`` so they can be forwarded to
494+
``configure_pre_processor`` later.
495+
496+
Args:
497+
args: Raw CLI arguments, or ``None`` to read from ``sys.argv``.
498+
499+
Returns:
500+
list[str]: CLI arguments with ``--model.pre_processor.*`` entries removed.
501+
"""
502+
if args is None:
503+
args = sys.argv[1:]
504+
args = list(args)
505+
506+
prefix = "--model.pre_processor."
507+
indices_to_remove: list[int] = []
508+
for i, arg in enumerate(args):
509+
if not arg.startswith(prefix):
510+
continue
511+
if "=" in arg:
512+
# --model.pre_processor.image_size=512
513+
key, val = arg[len(prefix) :].split("=", 1)
514+
self._pre_processor_kwargs[key] = self._parse_cli_value(val)
515+
indices_to_remove.append(i)
516+
elif i + 1 < len(args):
517+
# --model.pre_processor.image_size 512
518+
key = arg[len(prefix) :]
519+
self._pre_processor_kwargs[key] = self._parse_cli_value(args[i + 1])
520+
indices_to_remove.extend([i, i + 1])
521+
522+
# Convert int image_size to (h, w) tuple for consistency
523+
if "image_size" in self._pre_processor_kwargs:
524+
val = self._pre_processor_kwargs["image_size"]
525+
if isinstance(val, int):
526+
self._pre_processor_kwargs["image_size"] = (val, val)
527+
elif isinstance(val, list):
528+
self._pre_processor_kwargs["image_size"] = tuple(val)
529+
530+
for idx in reversed(indices_to_remove):
531+
args.pop(idx)
532+
return args
533+
534+
@staticmethod
535+
def _parse_cli_value(val: str) -> Any: # noqa: ANN401
536+
"""Parse a CLI string value into a Python object using literal evaluation."""
537+
try:
538+
return ast.literal_eval(val)
539+
except (ValueError, SyntaxError):
540+
return val
541+
542+
def _patch_configure_pre_processor(self) -> tuple[type, bool, Any] | None:
543+
"""Patch the model class's ``configure_pre_processor`` with extracted CLI kwargs.
544+
545+
Uses ``functools.partial`` to bind the extracted pre-processor arguments
546+
(e.g. ``image_size``) to the model's ``configure_pre_processor`` method
547+
before the model is instantiated. This follows the same pattern as
548+
Lightning CLI's ``configure_optimizers`` override.
549+
550+
Returns:
551+
Tuple of (model_class, had_own_method, original_descriptor) for
552+
restoration, or ``None`` if patching was not applicable.
553+
"""
554+
subcommand = self.config["subcommand"]
555+
model_cfg = self.config.get(subcommand, self.config).get("model")
556+
if model_cfg is None:
557+
return None
558+
559+
class_path = model_cfg.get("class_path") if hasattr(model_cfg, "get") else getattr(model_cfg, "class_path", None)
560+
if class_path is None:
561+
return None
562+
563+
# Resolve the model class
564+
module_path, class_name = class_path.rsplit(".", 1)
565+
model_class = getattr(importlib.import_module(module_path), class_name)
566+
567+
# Save the original descriptor so we can restore it after instantiation
568+
has_own = "configure_pre_processor" in model_class.__dict__
569+
original_descriptor = model_class.__dict__.get("configure_pre_processor")
570+
571+
# Get the resolved callable (handles both @staticmethod and @classmethod)
572+
callable_method = model_class.configure_pre_processor
573+
patched = partial(callable_method, **self._pre_processor_kwargs)
574+
model_class.configure_pre_processor = staticmethod(patched)
575+
576+
return (model_class, has_own, original_descriptor)
577+
578+
@staticmethod
579+
def _unpatch_configure_pre_processor(patch_info: tuple[type, bool, Any]) -> None:
580+
"""Restore the original ``configure_pre_processor`` after model instantiation."""
581+
model_class, has_own, original_descriptor = patch_info
582+
if has_own:
583+
model_class.configure_pre_processor = original_descriptor
584+
else:
585+
delattr(model_class, "configure_pre_processor")
586+
473587
def _configure_optimizers_method_to_model(self) -> None:
474588
from lightning.pytorch.cli import LightningCLI, instantiate_class
475589

0 commit comments

Comments
 (0)