Skip to content

Commit 1f4ced7

Browse files
committed
Fix CLI limitation for setting custom image size via model.image_size (#3460)
Fixes issue #3460 by enabling CLI-based configuration of input image size across all models without requiring manual pre-processor reconstruction.

After the pre-processor refactor, the CLI does not provide a simple way to modify image size. Users are forced to either:
1. Recreate the entire preprocessing pipeline via CLI arguments, or
2. Define a full YAML configuration just to change resize dimensions.

This significantly reduces usability for a common operation.

This PR introduces `image_size` as a first-class parameter in `AnomalibModule` and propagates it through all model constructors. The resizing is applied by patching the auto-created pre-processing pipeline using `_apply_image_size`, which:
- Updates the first `Resize` transform if present
- Inserts a `Resize` transform if missing
- Preserves existing transform attributes (e.g., interpolation, antialias)
- Updates `export_transform` for compatibility

Users can now directly set image size via CLI:

anomalib train --model Padim --model.image_size 512 --data MVTecAD --fast_dev_run True
anomalib train --model Padim --model.image_size "[512, 512]" --data MVTecAD --data.category transistor
anomalib train --model Stfpm --model.image_size 384 --data MVTecAD --data.category bottle --fast_dev_run True

jsonargparse exposes CLI arguments based on subclass `__init__` signatures.
Therefore, `image_size` is added to each model constructor and forwarded to `AnomalibModule`, following the existing design pattern used for:
- pre_processor
- post_processor
- evaluator
- visualizer

This ensures:
- No CLI-level argument injection
- Consistent API across all models
- Full compatibility with config/CLI system

- Added `image_size` parameter to `AnomalibModule`
- Implemented `_apply_image_size` and `_rebuild_resize`
- Updated all model constructors to accept and forward `image_size`
- Removed the need for CLI or YAML preprocessing overrides

- Supports both int (square) and tuple (H, W)
- Applied only when `pre_processor=True` (default behavior)
- Does not override custom `PreProcessor` instances
- Fully backward compatible (default = None)

- Simplifies CLI usage for a common configuration task
- Eliminates need for verbose preprocessing definitions
- Aligns CLI behavior with Python API capabilities

Signed-off-by: Abhay Kumar Das <dasabhay.jsr@gmail.com>
1 parent 37c51c4 commit 1f4ced7

25 files changed

Lines changed: 104 additions & 1 deletion

src/anomalib/models/components/base/anomalib_module.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from anomalib.metrics.evaluator import Evaluator
6060
from anomalib.post_processing import PostProcessor
6161
from anomalib.pre_processing import PreProcessor
62+
from anomalib.pre_processing.utils.transform import get_exportable_transform
6263
from anomalib.utils import deprecate
6364
from anomalib.visualization import ImageVisualizer, Visualizer
6465

@@ -83,6 +84,11 @@ class AnomalibModule(ExportMixin, pl.LightningModule, ABC):
8384
default. Defaults to ``True``.
8485
visualizer (Visualizer | bool, optional): Visualizer instance or flag to
8586
use default. Defaults to ``True``.
87+
image_size (tuple[int, int] | int | None, optional): Target image size
88+
for resizing in the pre-processor. Only applied when
89+
``pre_processor=True`` (auto-created). Ignored when an explicit
90+
``PreProcessor`` instance is provided. An int value is interpreted
91+
as square ``(n, n)``. Defaults to ``None`` (use model default).
8692
8793
Attributes:
8894
model (nn.Module): PyTorch model to be trained
@@ -126,16 +132,22 @@ def __init__(
126132
post_processor: nn.Module | bool = True,
127133
evaluator: Evaluator | bool = True,
128134
visualizer: Visualizer | bool = True,
135+
image_size: tuple[int, int] | int | None = None,
129136
) -> None:
130137
super().__init__()
131138
logger.info("Initializing %s model.", self.__class__.__name__)
132139

140+
if isinstance(image_size, int):
141+
image_size = (image_size, image_size)
142+
133143
self.save_hyperparameters()
134144
self.model: nn.Module
135145
self.loss: nn.Module
136146
self.callbacks: list[Callback]
137147

138148
self.pre_processor = self._resolve_component(pre_processor, nn.Module, self.configure_pre_processor)
149+
if image_size is not None and isinstance(pre_processor, bool) and pre_processor:
150+
self._apply_image_size(image_size)
139151
self.post_processor = self._resolve_component(post_processor, nn.Module, self.configure_post_processor)
140152
self.evaluator = self._resolve_component(evaluator, Evaluator, self.configure_evaluator)
141153
self.visualizer = self._resolve_component(visualizer, Visualizer, self.configure_visualizer)
@@ -306,6 +318,51 @@ def _resolve_component(
306318
msg = f"Passed object should be {component_type} or bool, got: {type(component)}"
307319
raise TypeError(msg)
308320

321+
def _apply_image_size(self, image_size: tuple[int, int]) -> None:
    """Patch the auto-created pre-processor's transform to use the given image size.

    Replaces the first top-level ``Resize`` found in the pre-processor's
    transform pipeline with one that targets ``image_size``. If the pipeline
    contains no ``Resize``, one is prepended. ``export_transform`` is then
    regenerated from the patched transform so both stay in sync.

    Nothing is modified, and a warning is logged, when the pre-processor is
    not a ``PreProcessor`` instance or when it has no transform.

    Args:
        image_size (tuple[int, int]): Target ``(height, width)`` for resizing.
    """
    pre_processor = self.pre_processor
    if not isinstance(pre_processor, PreProcessor):
        # Warn instead of returning silently so a user-supplied image_size is
        # never dropped without notice (mirrors the "no transform" case below).
        logger.warning(
            "image_size=%s was specified but the model's pre-processor is not a PreProcessor (got %s). "
            "The image_size setting will be ignored.",
            image_size,
            type(pre_processor).__name__,
        )
        return

    transform = pre_processor.transform
    if transform is None:
        logger.warning(
            "image_size=%s was specified but the model's pre-processor has no transform. "
            "The image_size setting will be ignored.",
            image_size,
        )
        return

    if isinstance(transform, Compose):
        # Edit a list copy and assign it back: the Compose may have been built
        # from a tuple (or another immutable sequence), which supports neither
        # item assignment nor ``insert``.
        steps = list(transform.transforms)
        for index, step in enumerate(steps):
            if isinstance(step, Resize):
                steps[index] = self._rebuild_resize(step, image_size)
                break
        else:
            # No Resize in the pipeline: resize first, then run the existing steps.
            steps.insert(0, Resize(image_size, antialias=True))
        transform.transforms = steps
    elif isinstance(transform, Resize):
        pre_processor.transform = self._rebuild_resize(transform, image_size)
    else:
        # A single non-Resize transform: wrap it so resizing happens first.
        pre_processor.transform = Compose([Resize(image_size, antialias=True), transform])

    pre_processor.export_transform = get_exportable_transform(pre_processor.transform)
356+
357+
@staticmethod
def _rebuild_resize(original: Resize, image_size: tuple[int, int]) -> Resize:
    """Build a new ``Resize`` targeting ``image_size``, keeping compatible settings.

    ``interpolation`` and ``antialias`` are copied from ``original`` when it
    exposes them. ``max_size`` is deliberately not copied: torchvision only
    accepts ``max_size`` when ``size`` is a single value (the shorter-edge
    length), so carrying it over to an explicit ``(height, width)`` target
    would make the rebuilt transform raise when it is applied.

    Args:
        original (Resize): Existing transform whose settings are reused.
        image_size (tuple[int, int]): New target ``(height, width)``.

    Returns:
        Resize: A fresh transform; ``original`` is left untouched.
    """
    kwargs: dict[str, Any] = {
        attr: getattr(original, attr) for attr in ("interpolation", "antialias") if hasattr(original, attr)
    }
    return Resize(size=image_size, **kwargs)
365+
309366
@staticmethod
310367
def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor:
311368
"""Configure the default pre-processor.

src/anomalib/models/image/anomaly_dino/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,14 @@ def __init__(
159159
post_processor: nn.Module | bool = True,
160160
evaluator: Evaluator | bool = True,
161161
visualizer: Visualizer | bool = True,
162+
image_size: tuple[int, int] | int | None = None,
162163
) -> None:
163164
super().__init__(
164165
pre_processor=pre_processor,
165166
post_processor=post_processor,
166167
evaluator=evaluator,
167168
visualizer=visualizer,
169+
image_size=image_size,
168170
)
169171
self.model: AnomalyDINOModel = AnomalyDINOModel(
170172
num_neighbours=num_neighbours,

src/anomalib/models/image/cfa/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,14 @@ def __init__(
101101
post_processor: PostProcessor | bool = True,
102102
evaluator: Evaluator | bool = True,
103103
visualizer: Visualizer | bool = True,
104+
image_size: tuple[int, int] | int | None = None,
104105
) -> None:
105106
super().__init__(
106107
pre_processor=pre_processor,
107108
post_processor=post_processor,
108109
evaluator=evaluator,
109110
visualizer=visualizer,
111+
image_size=image_size,
110112
)
111113
self.model: CfaModel = CfaModel(
112114
backbone=backbone,

src/anomalib/models/image/cflow/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,14 @@ def __init__(
9393
post_processor: PostProcessor | bool = True,
9494
evaluator: Evaluator | bool = True,
9595
visualizer: Visualizer | bool = True,
96+
image_size: tuple[int, int] | int | None = None,
9697
) -> None:
9798
super().__init__(
9899
pre_processor=pre_processor,
99100
post_processor=post_processor,
100101
evaluator=evaluator,
101102
visualizer=visualizer,
103+
image_size=image_size,
102104
)
103105

104106
self.model: CflowModel = CflowModel(

src/anomalib/models/image/csflow/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ def __init__(
8080
post_processor: PostProcessor | bool = True,
8181
evaluator: Evaluator | bool = True,
8282
visualizer: Visualizer | bool = True,
83+
image_size: tuple[int, int] | int | None = None,
8384
) -> None:
8485
super().__init__(
8586
pre_processor=pre_processor,
8687
post_processor=post_processor,
8788
evaluator=evaluator,
8889
visualizer=visualizer,
90+
image_size=image_size,
8991
)
9092
if self.input_size is None:
9193
msg = "CsFlow needs input size to build torch model."

src/anomalib/models/image/dfkde/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ def __init__(
9696
post_processor: PostProcessor | bool = True,
9797
evaluator: Evaluator | bool = True,
9898
visualizer: Visualizer | bool = True,
99+
image_size: tuple[int, int] | int | None = None,
99100
) -> None:
100101
super().__init__(
101102
pre_processor=pre_processor,
102103
post_processor=post_processor,
103104
evaluator=evaluator,
104105
visualizer=visualizer,
106+
image_size=image_size,
105107
)
106108

107109
self.model: DfkdeModel = DfkdeModel(

src/anomalib/models/image/dfm/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ def __init__(
9696
post_processor: PostProcessor | bool = True,
9797
evaluator: Evaluator | bool = True,
9898
visualizer: Visualizer | bool = True,
99+
image_size: tuple[int, int] | int | None = None,
99100
) -> None:
100101
super().__init__(
101102
pre_processor=pre_processor,
102103
post_processor=post_processor,
103104
evaluator=evaluator,
104105
visualizer=visualizer,
106+
image_size=image_size,
105107
)
106108

107109
self.model: DFMModel = DFMModel(

src/anomalib/models/image/dinomaly/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,14 @@ def __init__(
161161
post_processor: PostProcessor | bool = True,
162162
evaluator: Evaluator | bool = True,
163163
visualizer: Visualizer | bool = True,
164+
image_size: tuple[int, int] | int | None = None,
164165
) -> None:
165166
super().__init__(
166167
pre_processor=pre_processor,
167168
post_processor=post_processor,
168169
evaluator=evaluator,
169170
visualizer=visualizer,
171+
image_size=image_size,
170172
)
171173

172174
self.model: DinomalyModel = DinomalyModel(

src/anomalib/models/image/draem/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,14 @@ def __init__(
9292
post_processor: PostProcessor | bool = True,
9393
evaluator: Evaluator | bool = True,
9494
visualizer: Visualizer | bool = True,
95+
image_size: tuple[int, int] | int | None = None,
9596
) -> None:
9697
super().__init__(
9798
pre_processor=pre_processor,
9899
post_processor=post_processor,
99100
evaluator=evaluator,
100101
visualizer=visualizer,
102+
image_size=image_size,
101103
)
102104
dtd_dir = Path(dtd_dir)
103105
if not dtd_dir.is_dir():

src/anomalib/models/image/dsr/lightning_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,14 @@ def __init__(
107107
post_processor: PostProcessor | bool = True,
108108
evaluator: Evaluator | bool = True,
109109
visualizer: Visualizer | bool = True,
110+
image_size: tuple[int, int] | int | None = None,
110111
) -> None:
111112
super().__init__(
112113
pre_processor=pre_processor,
113114
post_processor=post_processor,
114115
evaluator=evaluator,
115116
visualizer=visualizer,
117+
image_size=image_size,
116118
)
117119

118120
self.automatic_optimization = False

0 commit comments

Comments
 (0)