66"""
77import math
88import warnings
9+ from functools import reduce
910from typing import List , Optional , Tuple , Union
1011
1112import numpy as np
@@ -172,6 +173,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)


+def _split_out_chs(out_chs: int, ratio=(24, 4, 4)):
+    # reduce ratio to smallest integers (24, 4, 4) -> (6, 1, 1)
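+    # illustrative example: out_chs=32 -> unit = 32 // 8 = 4 -> (24, 4, 4); out_chs=48 -> (36, 6, 6)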
+    g = reduce(math.gcd, ratio)
+    r = tuple(x // g for x in ratio)
+    denom = sum(r)
+
+    assert out_chs % denom == 0 and out_chs >= denom, (
+        f"out_chs={out_chs} can't be split into Y/Cb/Cr with ratio {ratio} "
+        f"(reduced {r}); out_chs must be a multiple of {denom}."
+    )
+
+    unit = out_chs // denom
+    y, cb, cr = (ri * unit for ri in r)
+    assert y + cb + cr == out_chs and min(y, cb, cr) > 0
+    return y, cb, cr
+
+
class LearnableDct2d(nn.Module):
    """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection."""

@@ -180,6 +198,7 @@ def __init__(
            kernel_size: int,
            kernel_type: int = 2,
            orthonormal: bool = True,
+            out_chs: int = 32,
            device=None,
            dtype=None,
    ) -> None:
@@ -189,9 +208,11 @@ def __init__(
        self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)
-        self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0, **dd)
-        self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)
-        self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)
+
+        y_ch, cb_ch, cr_ch = _split_out_chs(out_chs, ratio=(24, 4, 4))
+        self.conv_y = nn.Conv2d(kernel_size ** 2, y_ch, kernel_size=1, padding=0, **dd)
+        self.conv_cb = nn.Conv2d(kernel_size ** 2, cb_ch, kernel_size=1, padding=0, **dd)
+        self.conv_cr = nn.Conv2d(kernel_size ** 2, cr_ch, kernel_size=1, padding=0, **dd)
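+        # note: the default out_chs=32 reproduces the original fixed 24/4/4 Y/Cb/Cr split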

        self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False)
        self.register_buffer('var', torch.tensor(_DCT_VAR, device=device), persistent=False)
@@ -534,7 +555,7 @@ def __init__(
        dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
        dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]

-        self.stem_dct = LearnableDct2d(8, **dd)
+        self.stem_dct = LearnableDct2d(8, out_chs=dims[0], **dd)

        # Build stages dynamically
        dp_iter = iter(dp_rates)
@@ -671,112 +692,105 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_head(x)


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 512, 512),
+        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
+        'interpolation': 'bilinear', 'crop_pct': 1.0,
+        'classifier': 'head.fc', 'first_conv': [],
+        **kwargs,
+    }
+
+
default_cfgs = generate_default_cfgs({
-    'csatv2': {
-        'url': 'https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth',
-        'num_classes': 1000,
-        'input_size': (3, 512, 512),
-        'mean': (0.485, 0.456, 0.406),
-        'std': (0.229, 0.224, 0.225),
-        'interpolation': 'bilinear',
-        'crop_pct': 1.0,
-        'classifier': 'head.fc',
-        'first_conv': [],
-    },
+    'csatv2': _cfg(
+        url='https://huggingface.co/Hyunil/CSATv2/resolve/main/CSATv2_ImageNet_timm.pth'
+    ),
+    'csatv2_21m': _cfg(),
})


def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict:
    """Remap original CSATv2 checkpoint to timm format.

    Handles two key structural changes:
-    1. Stage naming: stages1/2/3/4 -> stages.0/1/2/3
-    2. Downsample position: moved from end of stage N to start of stage N+1
+    1) Stage naming: stages1/2/3/4 -> stages.0/1/2/3
+    2) Downsample position: moved from end of stage N to start of stage N+1
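+
+    Example (illustrative key names, not taken from the checkpoint): 'stages1.3.weight'
+    (the stage-1 downsample conv) maps to 'stages.1.0.weight', and 'stages2.0.grn.gamma'
+    maps to 'stages.1.1.grn.weight' (reshaped to 1D).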
695720 """
696- if ' stages.0.0.grn.weight' in state_dict :
697- return state_dict # Already in timm format
721+ if " stages.0.0.grn.weight" in state_dict :
722+ return state_dict # already in timm format
698723
699724 import re
700725
701- # Downsample indices in original checkpoint (Conv2d at end of each stage)
702- # These move to index 0 of the next stage
703- downsample_idx = {1 : 3 , 2 : 3 , 3 : 9 } # stage -> downsample index
726+ # FIXME this downsample idx is wired to the original 'csatv2' model size
727+ downsample_idx = {1 : 3 , 2 : 3 , 3 : 9 } # original stage -> downsample index
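+    # i.e. in the original checkpoint the stage-1 and stage-2 downsample convs sit at
+    # block index 3 and stage-3's at index 9; these offsets follow the 'csatv2' depths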
+
+    dct_re = re.compile(r"^dct\.")
+    stage_re = re.compile(r"^stages([1-4])\.(\d+)\.(.*)$")
+    head_re = re.compile(r"^head\.")
+    norm_re = re.compile(r"^norm\.")

-    def remap_stage(m):
-        stage = int(m.group(1))
-        idx = int(m.group(2))
-        rest = m.group(3)
+    def remap_stage(m: re.Match) -> str:
+        stage, idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
        if stage in downsample_idx and idx == downsample_idx[stage]:
-            # Downsample moves to start of next stage
-            return f'stages.{stage}.0.{rest}'
-        elif stage == 1:
-            # Stage 1 -> stages.0, indices unchanged
-            return f'stages.0.{idx}.{rest}'
-        else:
-            # Stages 2-4 -> stages.1-3, indices shift +1 (after downsample)
-            return f'stages.{stage - 1}.{idx + 1}.{rest}'
+            return f"stages.{stage}.0.{rest}"  # move downsample to next stage @0
+        if stage == 1:
+            return f"stages.0.{idx}.{rest}"  # stage1 -> stages.0
+        return f"stages.{stage - 1}.{idx + 1}.{rest}"  # stage2-4 -> stages.1-3, shift +1

-    out_dict = {}
+    out = {}
    for k, v in state_dict.items():
-        # Remap dct -> stem_dct, and Y_Conv/Cb_Conv/Cr_Conv -> conv_y/conv_cb/conv_cr
-        k = re.sub(r'^dct\.', 'stem_dct.', k)
-        k = k.replace('.Y_Conv.', '.conv_y.')
-        k = k.replace('.Cb_Conv.', '.conv_cb.')
-        k = k.replace('.Cr_Conv.', '.conv_cr.')
-
-        # Remap stage names with index adjustments for downsample relocation
-        k = re.sub(r'^stages([1-4])\.(\d+)\.(.*)$', remap_stage, k)
-
-        # Remap GRN: gamma/beta -> weight/bias with reshape
-        if 'grn.gamma' in k:
-            k = k.replace('grn.gamma', 'grn.weight')
-            v = v.reshape(-1)
-        elif 'grn.beta' in k:
-            k = k.replace('grn.beta', 'grn.bias')
-            v = v.reshape(-1)
-
-        # Remap FeedForward (nn.Sequential) to Mlp: net.0 -> fc1, net.3 -> fc2
-        # Also rename ff -> mlp, ff_norm -> norm2, attn_norm -> norm1
-        if '.ff.net.0.' in k:
-            k = k.replace('.ff.net.0.', '.mlp.fc1.')
-        elif '.ff.net.3.' in k:
-            k = k.replace('.ff.net.3.', '.mlp.fc2.')
-        elif '.ff_norm.' in k:
-            k = k.replace('.ff_norm.', '.norm2.')
-        elif '.attn_norm.' in k:
-            k = k.replace('.attn_norm.', '.norm1.')
-
-        # Block.attention -> Block.attn (SpatialAttention)
-        # SpatialAttention.attention -> SpatialAttention.attn (SpatialTransformerBlock)
-        # Handle nested .attention.attention. first, then remaining .attention.
-        if '.attention.attention.' in k:
-            # SpatialTransformerBlock inner attn: remap to_qkv -> qkv
-            k = k.replace('.attention.attention.attn.to_qkv.', '.attn.attn.qkv.')
-            k = k.replace('.attention.attention.attn.', '.attn.attn.')
-            k = k.replace('.attention.attention.', '.attn.attn.')
-        elif '.attention.' in k:
-            # Block.attention -> Block.attn (catches SpatialAttention.conv etc)
-            k = k.replace('.attention.', '.attn.')
-
-        # TransformerBlock: remap attention layer names
-        # to_qkv -> qkv, to_out.0 -> proj, attn.pos_embed -> pos_embed
-        # Note: only for TransformerBlock, not SpatialTransformerBlock (which has .attn.attn.)
-        if '.attn.to_qkv.' in k:
-            k = k.replace('.attn.to_qkv.', '.attn.qkv.')
-        elif '.attn.to_out.0.' in k:
-            k = k.replace('.attn.to_out.0.', '.attn.proj.')
-
-        # TransformerBlock: .attn.pos_embed -> .pos_embed (but not .attn.attn.pos_embed)
-        if '.attn.pos_embed.' in k and '.attn.attn.' not in k:
-            k = k.replace('.attn.pos_embed.', '.pos_embed.')
-
-        # Remap head -> head.fc, norm -> head.norm (order matters)
-        k = re.sub(r'^head\.', 'head.fc.', k)
-        k = re.sub(r'^norm\.', 'head.norm.', k)
-
-        out_dict[k] = v
-
-    return out_dict
+        # dct -> stem_dct, and Y/Cb/Cr conv names
+        k = dct_re.sub("stem_dct.", k)
+        k = (k.replace(".Y_Conv.", ".conv_y.")
+             .replace(".Cb_Conv.", ".conv_cb.")
+             .replace(".Cr_Conv.", ".conv_cr."))
+
+        # stage remap + downsample relocation
+        k = stage_re.sub(remap_stage, k)
+
+        # GRN: gamma/beta -> weight/bias (reshape)
+        if "grn.gamma" in k:
+            k, v = k.replace("grn.gamma", "grn.weight"), v.reshape(-1)
+        elif "grn.beta" in k:
+            k, v = k.replace("grn.beta", "grn.bias"), v.reshape(-1)
+
+        # FeedForward (nn.Sequential) -> Mlp + norm renames
+        if ".ff.net.0." in k:
+            k = k.replace(".ff.net.0.", ".mlp.fc1.")
+        elif ".ff.net.3." in k:
+            k = k.replace(".ff.net.3.", ".mlp.fc2.")
+        elif ".ff_norm." in k:
+            k = k.replace(".ff_norm.", ".norm2.")
+        elif ".attn_norm." in k:
+            k = k.replace(".attn_norm.", ".norm1.")
+
+        # attention -> attn (handle nested first)
+        if ".attention.attention." in k:
+            k = (k.replace(".attention.attention.attn.to_qkv.", ".attn.attn.qkv.")
+                 .replace(".attention.attention.attn.", ".attn.attn.")
+                 .replace(".attention.attention.", ".attn.attn."))
+        elif ".attention." in k:
+            k = k.replace(".attention.", ".attn.")
+
+        # TransformerBlock attention name remaps
+        if ".attn.to_qkv." in k:
+            k = k.replace(".attn.to_qkv.", ".attn.qkv.")
+        elif ".attn.to_out.0." in k:
+            k = k.replace(".attn.to_out.0.", ".attn.proj.")
+
+        # .attn.pos_embed -> .pos_embed (but not SpatialTransformerBlock's .attn.attn.pos_embed)
+        if ".attn.pos_embed." in k and ".attn.attn." not in k:
+            k = k.replace(".attn.pos_embed.", ".pos_embed.")
+
+        # head -> head.fc, norm -> head.norm (order matters)
+        k = head_re.sub("head.fc.", k)
+        k = norm_re.sub("head.norm.", k)
+
+        out[k] = v
+
+    return out


def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
@@ -795,3 +809,15 @@ def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
@register_model
def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
    return _create_csatv2('csatv2', pretrained, **kwargs)
+
+
+@register_model
+def csatv2_21m(pretrained: bool = False, **kwargs) -> CSATv2:
+    # experimental ~21M-param variant to validate the flexible architecture spec
+    model_args = dict(
+        dims=(48, 96, 224, 448),
+        depths=(3, 3, 10, 8),
+        transformer_depths=(0, 0, 4, 3),
+    )
+    return _create_csatv2('csatv2_21m', pretrained, **dict(model_args, **kwargs))
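+
+
+# usage sketch (illustrative, not part of the commit):
+#   model = timm.create_model('csatv2_21m', pretrained=False)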