Skip to content

Commit 398466c

Browse files
authored
4747 compatible dtypes for padding (#4749)
* fixes #4747 padding Signed-off-by: Wenqi Li <[email protected]>
1 parent 5e36e23 commit 398466c

File tree

4 files changed: +66 −15 lines changed

Diff for: monai/transforms/croppad/array.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ class Pad(InvertibleTransform):
8585
in which case `np.pad` will be used.
8686
8787
Args:
88-
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
88+
to_pad: the amount to pad in each dimension (including the channel) [(low_H, high_H), (low_W, high_W), ...].
8989
if None, must provide in the `__call__` at runtime.
90-
mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
90+
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
9191
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
92-
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
92+
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
9393
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
9494
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
9595
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
96+
requires pytorch >= 1.10 for best compatibility.
9697
kwargs: other arguments for the `np.pad` or `torch.pad` function.
9798
note that `np.pad` treats channel dimension as the first dimension.
9899
@@ -122,6 +123,9 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int
122123
def _np_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor:
123124
img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
124125
mode = convert_pad_mode(dst=img_np, mode=mode).value
126+
if mode == "constant" and "value" in kwargs:
127+
val = kwargs.pop("value")
128+
kwargs["constant_values"] = val
125129
out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs))
126130
if isinstance(img, MetaTensor):
127131
out = convert_to_dst_type(out, dst=img)[0]
@@ -141,9 +145,9 @@ def __call__( # type: ignore
141145
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
142146
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
143147
default to `self.to_pad`.
144-
mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
148+
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
145149
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
146-
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
150+
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
147151
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
148152
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
149153
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
@@ -163,16 +167,26 @@ def __call__( # type: ignore
163167

164168
# all zeros, skip padding
165169
if np.asarray(to_pad_).any():
166-
if mode in ["linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"]:
170+
to_pad_ = list(to_pad_)
171+
if len(to_pad_) < len(img_t.shape):
172+
to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_))
173+
if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
167174
out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
168175
else:
176+
mode_ = convert_pad_mode(dst=img_t, mode=mode_).value
169177
try:
170-
mode_ = convert_pad_mode(dst=img_t, mode=mode_).value
171-
out = self._pt_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
172-
# but if mode or args don't exist in pytorch, use numpy instead
173-
except (ValueError, TypeError) as err:
174-
if "Unsupported option" in str(err) or "unexpected keyword" in str(err):
178+
_pad = (
179+
self._pt_pad
180+
if mode_ in {"reflect", "replicate"}
181+
and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
182+
else self._np_pad
183+
)
184+
out = _pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
185+
except (ValueError, TypeError, RuntimeError) as err:
186+
if "supported" in str(err) or "unexpected keyword" in str(err) or "implemented" in str(err):
175187
out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
188+
else:
189+
raise ValueError(f"{mode_}, {kwargs_}, {img_t.dtype}, {img_t.device}") from err
176190
else:
177191
out = img_t
178192
if get_track_meta():

Diff for: monai/transforms/spatial/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ class Zoom(InvertibleTransform):
10751075
padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
10761076
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
10771077
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
1078-
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
1078+
One of the listed string values or a user supplied function. Defaults to ``"edge"``.
10791079
The mode to pad data after zooming.
10801080
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
10811081
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
@@ -1123,7 +1123,7 @@ def __call__(
11231123
padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
11241124
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
11251125
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
1126-
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
1126+
One of the listed string values or a user supplied function. Defaults to ``"edge"``.
11271127
The mode to pad data after zooming.
11281128
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
11291129
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

Diff for: monai/transforms/spatial/dictionary.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1448,7 +1448,7 @@ class Zoomd(MapTransform, InvertibleTransform):
14481448
padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
14491449
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
14501450
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
1451-
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
1451+
One of the listed string values or a user supplied function. Defaults to ``"edge"``.
14521452
The mode to pad data after zooming.
14531453
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
14541454
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
@@ -1521,7 +1521,7 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform):
15211521
padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
15221522
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
15231523
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
1524-
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
1524+
One of the listed string values or a user supplied function. Defaults to ``"edge"``.
15251525
The mode to pad data after zooming.
15261526
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
15271527
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

Diff for: tests/test_pad_mode.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
17+
from monai.transforms import CastToType, Pad
18+
from monai.utils import NumpyPadMode, PytorchPadMode
19+
from tests.utils import SkipIfBeforePyTorchVersion
20+
21+
22+
@SkipIfBeforePyTorchVersion((1, 10, 1))
23+
class TestPadMode(unittest.TestCase):
24+
def test_pad(self):
25+
expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
26+
for t in (float, int, np.uint8, np.int16, np.float32, bool):
27+
for d in ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",):
28+
for s in ((1, 10, 10), (1, 5, 6, 7)):
29+
for m in list(PytorchPadMode) + list(NumpyPadMode):
30+
a = torch.rand(s)
31+
to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
32+
out = Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d))
33+
self.assertEqual(out.shape, expected_shapes[len(s)])
34+
35+
36+
if __name__ == "__main__":
37+
unittest.main()

0 commit comments

Comments
 (0)