
Commit 24fd7d2

Compact checkpoint_filter_fn a bit, make the learnable DCT out dim changeable, and define another model variant so an additional arch config can be tested.
1 parent b6eb61a commit 24fd7d2


timm/models/csatv2.py

Lines changed: 120 additions & 94 deletions
@@ -6,6 +6,7 @@
 """
 import math
 import warnings
+from functools import reduce
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -172,6 +173,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)
 
 
+def _split_out_chs(out_chs: int, ratio=(24, 4, 4)):
+    # reduce ratio to smallest integers (24,4,4) -> (6,1,1)
+    g = reduce(math.gcd, ratio)
+    r = tuple(x // g for x in ratio)
+    denom = sum(r)
+
+    assert out_chs % denom == 0 and out_chs >= denom, (
+        f"out_chs={out_chs} can't be split into Y/Cb/Cr with ratio {ratio} "
+        f"(reduced {r}); out_chs must be a multiple of {denom}."
+    )
+
+    unit = out_chs // denom
+    y, cb, cr = (ri * unit for ri in r)
+    assert y + cb + cr == out_chs and min(y, cb, cr) > 0
+    return y, cb, cr
+
+
 class LearnableDct2d(nn.Module):
     """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection."""
 
@@ -180,6 +198,7 @@ def __init__(
             kernel_size: int,
             kernel_type: int = 2,
             orthonormal: bool = True,
+            out_chs: int = 32,
             device=None,
             dtype=None,
     ) -> None:
@@ -189,9 +208,11 @@ def __init__(
         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
         self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
         self.permutation = _zigzag_permutation(kernel_size, kernel_size)
-        self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0, **dd)
-        self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)
-        self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)
+
+        y_ch, cb_ch, cr_ch = _split_out_chs(out_chs, ratio=(24, 4, 4))
+        self.conv_y = nn.Conv2d(kernel_size ** 2, y_ch, kernel_size=1, padding=0, **dd)
+        self.conv_cb = nn.Conv2d(kernel_size ** 2, cb_ch, kernel_size=1, padding=0, **dd)
+        self.conv_cr = nn.Conv2d(kernel_size ** 2, cr_ch, kernel_size=1, padding=0, **dd)
 
         self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False)
         self.register_buffer('var', torch.tensor(_DCT_VAR, device=device), persistent=False)
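For context, here is a minimal standalone sketch of the split performed by the new _split_out_chs helper (renamed and with the asserts dropped so it doesn't shadow the private function); the printed values are what the default (24, 4, 4) Y/Cb/Cr ratio yields for the two stem widths that appear in this file:

```python
import math
from functools import reduce

def split_out_chs(out_chs, ratio=(24, 4, 4)):
    # reduce the ratio to smallest integers: (24, 4, 4) -> (6, 1, 1)
    r = tuple(x // reduce(math.gcd, ratio) for x in ratio)
    unit = out_chs // sum(r)             # channels per ratio unit
    return tuple(ri * unit for ri in r)

print(split_out_chs(32))  # (24, 4, 4) -- same widths as the previously hard-coded convs
print(split_out_chs(48))  # (36, 6, 6) -- dims[0] of the experimental csatv2_21m config below
```

With the default out_chs=32 the split reproduces the old hard-coded 24/4/4 conv widths, so the existing 'csatv2' checkpoint should still match the stem shapes (assuming that model's dims[0] is 32, which this diff doesn't show).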
@@ -534,7 +555,7 @@ def __init__(
             dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
             dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]
 
-        self.stem_dct = LearnableDct2d(8, **dd)
+        self.stem_dct = LearnableDct2d(8, out_chs=dims[0], **dd)
 
         # Build stages dynamically
         dp_iter = iter(dp_rates)
@@ -671,112 +692,105 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_head(x)
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 512, 512),
+        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
+        'interpolation': 'bilinear', 'crop_pct': 1.0,
+        'classifier': 'head.fc', 'first_conv': [],
+        **kwargs,
+    }
+
+
 default_cfgs = generate_default_cfgs({
-    'csatv2': {
-        'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth',
-        'num_classes': 1000,
-        'input_size': (3, 512, 512),
-        'mean': (0.485, 0.456, 0.406),
-        'std': (0.229, 0.224, 0.225),
-        'interpolation': 'bilinear',
-        'crop_pct': 1.0,
-        'classifier': 'head.fc',
-        'first_conv': [],
-    },
+    'csatv2': _cfg(
+        url='https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth'
+    ),
+    'csatv2_21m': _cfg(),
 })
 
 
 def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict:
     """Remap original CSATv2 checkpoint to timm format.
 
     Handles two key structural changes:
-    1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3
-    2. Downsample position: moved from end of stage N to start of stage N+1
+    1) Stage naming: stages1/2/3/4 -> stages.0/1/2/3
+    2) Downsample position: moved from end of stage N to start of stage N+1
     """
-    if 'stages.0.0.grn.weight' in state_dict:
-        return state_dict  # Already in timm format
+    if "stages.0.0.grn.weight" in state_dict:
+        return state_dict  # already in timm format
 
     import re
 
-    # Downsample indices in original checkpoint (Conv2d at end of each stage)
-    # These move to index 0 of the next stage
-    downsample_idx = {1: 3, 2: 3, 3: 9}  # stage -> downsample index
+    # FIXME this downsample idx is wired to the original 'csatv2' model size
+    downsample_idx = {1: 3, 2: 3, 3: 9}  # original stage -> downsample index
+
+    dct_re = re.compile(r"^dct\.")
+    stage_re = re.compile(r"^stages([1-4])\.(\d+)\.(.*)$")
+    head_re = re.compile(r"^head\.")
+    norm_re = re.compile(r"^norm\.")
 
-    def remap_stage(m):
-        stage = int(m.group(1))
-        idx = int(m.group(2))
-        rest = m.group(3)
+    def remap_stage(m: re.Match) -> str:
+        stage, idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
         if stage in downsample_idx and idx == downsample_idx[stage]:
-            # Downsample moves to start of next stage
-            return f'stages.{stage}.0.{rest}'
-        elif stage == 1:
-            # Stage 1 -> stages.0, indices unchanged
-            return f'stages.0.{idx}.{rest}'
-        else:
-            # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample)
-            return f'stages.{stage - 1}.{idx + 1}.{rest}'
+            return f"stages.{stage}.0.{rest}"  # move downsample to next stage @0
+        if stage == 1:
+            return f"stages.0.{idx}.{rest}"  # stage1 -> stages.0
+        return f"stages.{stage - 1}.{idx + 1}.{rest}"  # stage2-4 -> stages.1-3, shift +1
 
-    out_dict = {}
+    out = {}
     for k, v in state_dict.items():
-        # Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr
-        k = re.sub(r'^dct\.', 'stem_dct.', k)
-        k = k.replace('.Y_Conv.', '.conv_y.')
-        k = k.replace('.Cb_Conv.', '.conv_cb.')
-        k = k.replace('.Cr_Conv.', '.conv_cr.')
-
-        # Remap stage names with index adjustments for downsample relocation
-        k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k)
-
-        # Remap GRN: gamma/beta -> weight/bias with reshape
-        if 'grn.gamma' in k:
-            k = k.replace('grn.gamma', 'grn.weight')
-            v = v.reshape(-1)
-        elif 'grn.beta' in k:
-            k = k.replace('grn.beta', 'grn.bias')
-            v = v.reshape(-1)
-
-        # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2
-        # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1
-        if '.ff.net.0.' in k:
-            k = k.replace('.ff.net.0.', '.mlp.fc1.')
-        elif '.ff.net.3.' in k:
-            k = k.replace('.ff.net.3.', '.mlp.fc2.')
-        elif '.ff_norm.' in k:
-            k = k.replace('.ff_norm.', '.norm2.')
-        elif '.attn_norm.' in k:
-            k = k.replace('.attn_norm.', '.norm1.')
-
-        # Block.attention -> Block.attn (SpatialAttention)
-        # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock)
-        # Handle nested .attention.attention. first, then remaining .attention.
-        if '.attention.attention.' in k:
-            # SpatialTransformerBlock inner attn: remap to_qkv -> qkv
-            k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
-            k = k.replace('.attention.attention.attn.', '.attn.attn.')
-            k = k.replace('.attention.attention.', '.attn.attn.')
-        elif '.attention.' in k:
-            # Block.attention -> Block.attn (catches SpatialAttention.conv etc)
-            k = k.replace('.attention.', '.attn.')
-
-        # TransformerBlock: remap attention layer names
-        # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed
-        # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.)
-        if '.attn.to_qkv.' in k:
-            k = k.replace('.attn.to_qkv.', '.attn.qkv.')
-        elif '.attn.to_out.0.' in k:
-            k = k.replace('.attn.to_out.0.', '.attn.proj.')
-
-        # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed)
-        if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
-            k = k.replace('.attn.pos_embed.', '.pos_embed.')
-
-        # Remap head -> head.fc, norm -> head.norm (order matters)
-        k = re.sub(r'^head\.', 'head.fc.', k)
-        k = re.sub(r'^norm\.', 'head.norm.', k)
-
-        out_dict[k] = v
-
-    return out_dict
+        # dct -> stem_dct, and Y/Cb/Cr conv names
+        k = dct_re.sub("stem_dct.", k)
+        k = (k.replace(".Y_Conv.", ".conv_y.")
+             .replace(".Cb_Conv.", ".conv_cb.")
+             .replace(".Cr_Conv.", ".conv_cr."))
+
+        # stage remap + downsample relocation
+        k = stage_re.sub(remap_stage, k)
+
+        # GRN: gamma/beta -> weight/bias (reshape)
+        if "grn.gamma" in k:
+            k, v = k.replace("grn.gamma", "grn.weight"), v.reshape(-1)
+        elif "grn.beta" in k:
+            k, v = k.replace("grn.beta", "grn.bias"), v.reshape(-1)
+
+        # FeedForward(nn.Sequential) -> Mlp + norm renames
+        if ".ff.net.0." in k:
+            k = k.replace(".ff.net.0.", ".mlp.fc1.")
+        elif ".ff.net.3." in k:
+            k = k.replace(".ff.net.3.", ".mlp.fc2.")
+        elif ".ff_norm." in k:
+            k = k.replace(".ff_norm.", ".norm2.")
+        elif ".attn_norm." in k:
+            k = k.replace(".attn_norm.", ".norm1.")
+
+        # attention -> attn (handle nested first)
+        if ".attention.attention." in k:
+            k = (k.replace(".attention.attention.attn.to_qkv.", ".attn.attn.qkv.")
+                 .replace(".attention.attention.attn.", ".attn.attn.")
+                 .replace(".attention.attention.", ".attn.attn."))
+        elif ".attention." in k:
+            k = k.replace(".attention.", ".attn.")
+
+        # TransformerBlock attention name remaps
+        if ".attn.to_qkv." in k:
+            k = k.replace(".attn.to_qkv.", ".attn.qkv.")
+        elif ".attn.to_out.0." in k:
+            k = k.replace(".attn.to_out.0.", ".attn.proj.")
+
+        # .attn.pos_embed -> .pos_embed (but not SpatialTransformerBlock's .attn.attn.pos_embed)
+        if ".attn.pos_embed." in k and ".attn.attn." not in k:
+            k = k.replace(".attn.pos_embed.", ".pos_embed.")
+
+        # head -> head.fc, norm -> head.norm (order matters)
+        k = head_re.sub("head.fc.", k)
+        k = norm_re.sub("head.norm.", k)
+
+        out[k] = v
+
+    return out
 
 
 def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
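As a quick sanity check of the key relocation described in the docstring, the sketch below mirrors remap_stage in isolation; the .dwconv.weight / .conv.weight suffixes are made-up illustrative key names, not taken from the actual checkpoint:

```python
import re

downsample_idx = {1: 3, 2: 3, 3: 9}
stage_re = re.compile(r"^stages([1-4])\.(\d+)\.(.*)$")

def remap(key):
    m = stage_re.match(key)
    if not m:
        return key
    stage, idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
    if stage in downsample_idx and idx == downsample_idx[stage]:
        return f"stages.{stage}.0.{rest}"          # downsample -> start of next stage
    if stage == 1:
        return f"stages.0.{idx}.{rest}"            # stage 1 -> stages.0, index unchanged
    return f"stages.{stage - 1}.{idx + 1}.{rest}"  # stages 2-4 shift down, indices +1

print(remap("stages1.0.dwconv.weight"))  # stages.0.0.dwconv.weight
print(remap("stages1.3.conv.weight"))    # stages.1.0.conv.weight   (relocated downsample)
print(remap("stages2.0.dwconv.weight"))  # stages.1.1.dwconv.weight (shifted past the downsample)
```

This matches the docstring's second point: the Conv2d that closed each original stage becomes index 0 of the following timm stage, and the remaining block indices in stages 2-4 shift by one to make room.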
@@ -795,3 +809,15 @@ def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
 @register_model
 def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
     return _create_csatv2('csatv2', pretrained, **kwargs)
+
+
+@register_model
+def csatv2_21m(pretrained: bool = False, **kwargs) -> CSATv2:
+    # experimental ~20-21M param larger model to validate flexible arch spec
+    model_args = dict(
+        dims = (48, 96, 224, 448),
+        depths = (3, 3, 10, 8),
+        transformer_depths = (0, 0, 4, 3)
+
+    )
+    return _create_csatv2('csatv2_21m', pretrained, **dict(model_args, **kwargs))
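For reference, a hedged usage sketch of the new entrypoint (assuming a timm install that includes this commit); creation goes through the standard registry, and with the default config's 512x512 input a 1000-class logit tensor is expected:

```python
import timm
import torch

model = timm.create_model('csatv2_21m', pretrained=False)
x = torch.randn(1, 3, 512, 512)   # input_size from the shared _cfg defaults
with torch.no_grad():
    print(model(x).shape)         # expected: torch.Size([1, 1000])
```

Note that model_args is merged as dict(model_args, **kwargs), so any dims/depths passed by the caller override the built-in csatv2_21m defaults.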
