 import torch.nn as nn
 import torch.nn.functional as F

-from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead
+from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
 from timm.layers.grn import GlobalResponseNorm
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
@@ -322,6 +322,7 @@ def __init__(
             self,
             dim: int,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -333,6 +334,7 @@ def __init__(
         self.act = nn.GELU()
         self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
         self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
+        self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.attn = SpatialAttention(**dd)
@@ -350,6 +352,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn = self.attn(x)
         attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
         x = x * attn
+        x = self.ls(x)

         return shortcut + self.drop_path(x)

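For context on the new `self.ls` module: timm's LayerScale2d is a learnable per-channel scale for NCHW feature maps, so the conv Block's residual branch can start near zero when a small init value is used. A minimal stand-in sketch of the behaviour assumed here (the real class lives in timm.layers; this is illustrative, not the library code):

import torch
import torch.nn as nn

class LayerScale2dSketch(nn.Module):
    # Rough equivalent of timm.layers.LayerScale2d: one learnable scalar per
    # channel, initialised to init_values, multiplied into an NCHW tensor.
    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gamma = self.gamma.view(1, -1, 1, 1)
        return x.mul_(gamma) if self.inplace else x * gamma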
@@ -442,6 +445,7 @@ def __init__(
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -470,10 +474,12 @@ def __init__(
             proj_drop=proj_drop,
             **dd,
         )
+        self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = nn.LayerNorm(oup, **dd)
         self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
+        self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -484,7 +490,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x_t.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)
         else:
@@ -493,15 +499,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)

         # MLP block
         B, C, H, W = x.shape
         shortcut = x
         x_t = x.flatten(2).transpose(1, 2)
-        x_t = self.mlp(self.norm2(x_t))
+        x_t = self.ls2(self.mlp(self.norm2(x_t)))
         x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
         x = shortcut + self.drop_path2(x_t)

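Schematically, the transformer block now follows the usual LayerScale residual pattern; this is a summary of the structure above (omitting the NCHW/NLC reshapes and the positional embedding), not the exact forward code:

# attention branch: scale the branch output before drop-path and the shortcut
#   x = x + drop_path1(ls1(attn(norm1(x))))
# MLP branch
#   x = x + drop_path2(ls2(mlp(norm2(x))))
# With the default ls_init_value=None, ls1/ls2 are nn.Identity, so checkpoints
# trained before this change should still load without remapping.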
@@ -545,6 +551,7 @@ def __init__(
             transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
             drop_path_rate: float = 0.0,
             transformer_drop_path: bool = False,
+            ls_init_value: Optional[float] = None,
             global_pool: str = 'avg',
             device=None,
             dtype=None,
@@ -590,9 +597,9 @@ def __init__(
                 # Downsample at start of stage (except first stage)
                 ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
                 # Conv blocks
-                [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
+                [Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
                 # Transformer blocks at end of stage
-                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
+                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
                 # Trailing LayerNorm (except last stage)
                 ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
             )
@@ -726,7 +733,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 512, 512),
+        'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
         'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
         'interpolation': 'bilinear', 'crop_pct': 1.0,
         'classifier': 'head.fc', 'first_conv': [],
@@ -735,10 +742,19 @@ def _cfg(url='', **kwargs):


 default_cfgs = generate_default_cfgs({
-    'csatv2': _cfg(
+    'csatv2.r512_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'csatv2_21m.sw_r640_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 640, 640),
+        interpolation='bicubic',
+    ),
+    'csatv2_21m.sw_r512_in1k': _cfg(
         hf_hub_id='timm/',
+        pool_size=(10, 10),
+        interpolation='bicubic',
     ),
-    'csatv2_21m': _cfg(),
 })


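A usage sketch for the new pretrained tags (assuming the referenced weights are actually published under the timm/ hub namespace; extra kwargs such as ls_init_value are forwarded by timm.create_model to the model constructor):

import timm
import torch

# Pretrained tag from the default_cfgs above (hypothetical until weights are on the hub).
model = timm.create_model('csatv2_21m.sw_r512_in1k', pretrained=True).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 512, 512))
print(logits.shape)  # torch.Size([1, 1000])

# LayerScale stays off by default; enable it on a freshly initialised model via the new argument.
scratch = timm.create_model('csatv2', pretrained=False, ls_init_value=1e-5)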