Skip to content

Commit d6e4e83

Browse files
committed
Bump version to 1.6.0 and refine tensor ops
1 parent 3c98266 commit d6e4e83

File tree

2 files changed

+15
-20
lines changed

2 files changed

+15
-20
lines changed

torchattack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from torchattack.vmifgsm import VMIFGSM
3737
from torchattack.vnifgsm import VNIFGSM
3838

39-
__version__ = '1.5.1'
39+
__version__ = '1.6.0'
4040

4141
__all__ = [
4242
# Helper functions

torchattack/mumodig.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def random_rotate(img: torch.Tensor) -> torch.Tensor:
155155
random_rotate,
156156
]
157157
choice = torch.randint(len(transforms), (1,)).item()
158-
return transforms[choice](img)
158+
return transforms[int(choice)](img)
159159

160160
def random_resize_and_pad(img: torch.Tensor, dim: int = 245) -> torch.Tensor:
161161
"""
@@ -168,21 +168,24 @@ def random_resize_and_pad(img: torch.Tensor, dim: int = 245) -> torch.Tensor:
168168
img, size=(target, target), mode='bilinear', align_corners=False
169169
)
170170

171-
pad_total = dim - target
172-
pad_top = torch.randint(0, pad_total, (1,)).item() # type: ignore[arg-type]
173-
pad_bottom = pad_total - pad_top
174-
pad_left = torch.randint(0, pad_total, (1,)).item() # type: ignore[arg-type]
175-
pad_right = pad_total - pad_left
171+
pad_total = int(dim - target)
172+
pad_top = int(torch.randint(0, pad_total, (1,)).item())
173+
pad_bottom = int(pad_total - pad_top)
174+
pad_left = int(torch.randint(0, pad_total, (1,)).item())
175+
pad_right = int(pad_total - pad_left)
176176

177-
padded = f.pad(resized, [pad_left, pad_right, pad_top, pad_bottom], value=0) # type: ignore[list-item]
178-
return f.interpolate(
177+
padded: torch.Tensor = f.pad(
178+
resized, [pad_left, pad_right, pad_top, pad_bottom], value=0
179+
)
180+
padded = f.interpolate(
179181
padded, size=(orig, orig), mode='bilinear', align_corners=False
180182
)
183+
return padded
181184

182185
# Choose one augmentation at random
183186
transforms = [random_affine, random_resize_and_pad]
184187
idx = torch.randint(len(transforms), (1,)).item()
185-
aug_x: torch.Tensor = transforms[idx](x)
188+
aug_x: torch.Tensor = transforms[int(idx)](x) # type: ignore[operator]
186189
return aug_x
187190

188191
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -235,15 +238,7 @@ def __init__(self, region_num: int, is_channels_first: bool = False) -> None:
235238
def get_params(
236239
self, x: torch.Tensor
237240
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
238-
"""
239-
Compute per-channel min, max and number of splits.
240-
241-
Args:
242-
x: Tensor of shape (C,H,W)
243-
244-
Returns:
245-
min_val: (C,), max_val: (C,), counts: (C,) = region_num - 1 splits
246-
"""
241+
"""Compute per-channel min, max and number of splits."""
247242

248243
c, _, _ = x.size()
249244
flat = x.view(c, -1)
@@ -265,7 +260,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
265260

266261
# sample random percentiles for splits
267262
total_splits = counts.sum().item()
268-
rand_perc = torch.rand(total_splits, device=x.device)
263+
rand_perc = torch.rand(int(total_splits), device=x.device)
269264
splits = rand_perc.view(-1, self.region_num - 1)
270265

271266
# compute split positions: in [min_val, max_val) per channel

0 commit comments

Comments
 (0)