Skip to content

Commit 98aa3d9

Browse files
committed
Fix timm models interpolation unresolved
1 parent 0088545 commit 98aa3d9

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

torchattack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from torchattack.vmifgsm import VMIFGSM
3434
from torchattack.vnifgsm import VNIFGSM
3535

36-
__version__ = '1.5.0'
36+
__version__ = '1.5.1'
3737

3838
__all__ = [
3939
# Helper functions

torchattack/attack_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class AttackModelMeta: # type: ignore[no-any-unimported]
102102

103103
@classmethod
104104
def from_timm_pretrained_cfg(cls, cfg: dict) -> Self:
105+
from timm.data.transforms import str_to_interp_mode
106+
105107
# Reference:
106108
# create_transform::https://github.com/huggingface/pytorch-image-models/blob/a49b02/timm/data/transforms_factory.py#L334
107109
# transforms_imagenet_eval::https://github.com/huggingface/pytorch-image-models/blob/a49b02/timm/data/transforms_factory.py#L247
@@ -114,7 +116,7 @@ def from_timm_pretrained_cfg(cls, cfg: dict) -> Self:
114116
return cls(
115117
resize_size=resize_size,
116118
crop_size=crop_size,
117-
interpolation=cfg['interpolation'],
119+
interpolation=str_to_interp_mode(cfg['interpolation']),
118120
mean=cfg['mean'],
119121
std=cfg['std'],
120122
)

0 commit comments

Comments
 (0)