@@ -85,14 +85,15 @@ class Pad(InvertibleTransform):
85
85
in which case `np.pad` will be used.
86
86
87
87
Args:
88
- to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
88
+ to_pad: the amount to pad in each dimension (including the channel) [(low_H, high_H), (low_W, high_W), ...].
89
89
if None, must provide in the `__call__` at runtime.
90
- mode: available modes for numpy array: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
90
+ mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
91
91
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
92
- available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
92
+ ( PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
93
93
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
94
94
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
95
95
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
96
+ requires pytorch >= 1.10 for best compatibility.
96
97
kwargs: other arguments for the `np.pad` or `torch.pad` function.
97
98
note that `np.pad` treats channel dimension as the first dimension.
98
99
@@ -122,6 +123,9 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int
122
123
def _np_pad (img : torch .Tensor , pad_width , mode , ** kwargs ) -> torch .Tensor :
123
124
img_np = img .detach ().cpu ().numpy () if isinstance (img , torch .Tensor ) else img
124
125
mode = convert_pad_mode (dst = img_np , mode = mode ).value
126
+ if mode == "constant" and "value" in kwargs :
127
+ val = kwargs .pop ("value" )
128
+ kwargs ["constant_values" ] = val
125
129
out = torch .as_tensor (np .pad (img , pad_width , mode = mode , ** kwargs ))
126
130
if isinstance (img , MetaTensor ):
127
131
out = convert_to_dst_type (out , dst = img )[0 ]
@@ -141,9 +145,9 @@ def __call__( # type: ignore
141
145
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
142
146
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
143
147
default to `self.to_pad`.
144
- mode: available modes for numpy array: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
148
+ mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
145
149
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
146
- available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
150
+ ( PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
147
151
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
148
152
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
149
153
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
@@ -163,16 +167,26 @@ def __call__( # type: ignore
163
167
164
168
# all zeros, skip padding
165
169
if np .asarray (to_pad_ ).any ():
166
- if mode in ["linear_ramp" , "maximum" , "mean" , "median" , "minimum" , "symmetric" , "empty" ]:
170
+ to_pad_ = list (to_pad_ )
171
+ if len (to_pad_ ) < len (img_t .shape ):
172
+ to_pad_ = list (to_pad_ ) + [(0 , 0 )] * (len (img_t .shape ) - len (to_pad_ ))
173
+ if mode_ in {"linear_ramp" , "maximum" , "mean" , "median" , "minimum" , "symmetric" , "empty" }:
167
174
out = self ._np_pad (img_t , pad_width = to_pad_ , mode = mode_ , ** kwargs_ )
168
175
else :
176
+ mode_ = convert_pad_mode (dst = img_t , mode = mode_ ).value
169
177
try :
170
- mode_ = convert_pad_mode (dst = img_t , mode = mode_ ).value
171
- out = self ._pt_pad (img_t , pad_width = to_pad_ , mode = mode_ , ** kwargs_ )
172
- # but if mode or args don't exist in pytorch, use numpy instead
173
- except (ValueError , TypeError ) as err :
174
- if "Unsupported option" in str (err ) or "unexpected keyword" in str (err ):
178
+ _pad = (
179
+ self ._pt_pad
180
+ if mode_ in {"reflect" , "replicate" }
181
+ and img_t .dtype not in {torch .int16 , torch .int64 , torch .bool , torch .uint8 }
182
+ else self ._np_pad
183
+ )
184
+ out = _pad (img_t , pad_width = to_pad_ , mode = mode_ , ** kwargs_ )
185
+ except (ValueError , TypeError , RuntimeError ) as err :
186
+ if "supported" in str (err ) or "unexpected keyword" in str (err ) or "implemented" in str (err ):
175
187
out = self ._np_pad (img_t , pad_width = to_pad_ , mode = mode_ , ** kwargs_ )
188
+ else :
189
+ raise ValueError (f"{ mode_ } , { kwargs_ } , { img_t .dtype } , { img_t .device } " ) from err
176
190
else :
177
191
out = img_t
178
192
if get_track_meta ():
0 commit comments