Skip to content

Commit 7a75178

Browse files
committed
Upload csatv2 weights to hub, fix non-contiguous dct weight
1 parent 24fd7d2 commit 7a75178

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/csatv2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
super().__init__()
148148
kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
149149
dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T
150-
self.register_buffer('weights', dct_weights)
150+
self.register_buffer('weights', dct_weights.contiguous())
151151
self.register_parameter('bias', None)
152152

153153
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -705,7 +705,7 @@ def _cfg(url='', **kwargs):
705705

706706
default_cfgs = generate_default_cfgs({
707707
'csatv2': _cfg(
708-
url='https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth'
708+
hf_hub_id='timm/',
709709
),
710710
'csatv2_21m': _cfg(),
711711
})

0 commit comments

Comments
 (0)