diff --git a/timm/models/csatv2.py b/timm/models/csatv2.py
index 48e00fc18f..a45d6e9151 100644
--- a/timm/models/csatv2.py
+++ b/timm/models/csatv2.py
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead
+from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
 from timm.layers.grn import GlobalResponseNorm
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
@@ -322,6 +322,7 @@ def __init__(
             self,
             dim: int,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -333,6 +334,7 @@
         self.act = nn.GELU()
         self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
         self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
+        self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.attn = SpatialAttention(**dd)
@@ -350,6 +352,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn = self.attn(x)
         attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
         x = x * attn
+        x = self.ls(x)
         return shortcut + self.drop_path(x)
@@ -442,6 +445,7 @@ def __init__(
             self,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -470,10 +474,12 @@ def __init__(
             proj_drop=proj_drop,
             **dd,
         )
+        self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = nn.LayerNorm(oup, **dd)
         self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
+        self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -484,7 +490,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x_t.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)
         else:
@@ -493,7 +499,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)
@@ -501,7 +507,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
         shortcut = x
         x_t = x.flatten(2).transpose(1, 2)
-        x_t = self.mlp(self.norm2(x_t))
+        x_t = self.ls2(self.mlp(self.norm2(x_t)))
         x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
         x = shortcut + self.drop_path2(x_t)
@@ -545,6 +551,7 @@ def __init__(
             self,
             transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
             drop_path_rate: float = 0.0,
             transformer_drop_path: bool = False,
+            ls_init_value: Optional[float] = None,
             global_pool: str = 'avg',
             device=None,
             dtype=None,
@@ -590,9 +597,9 @@ def __init__(
                 # Downsample at start of stage (except first stage)
                 ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
                 # Conv blocks
-                [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
+                [Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
                 # Transformer blocks at end of stage
-                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
+                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
                 # Trailing LayerNorm (except last stage)
                 ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
             )
@@ -726,7 +733,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 512, 512),
+        'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
         'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
         'interpolation': 'bilinear', 'crop_pct': 1.0,
         'classifier': 'head.fc', 'first_conv': [],
@@ -735,10 +742,19 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs({
-    'csatv2': _cfg(
+    'csatv2.r512_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'csatv2_21m.sw_r640_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 640, 640),
+        interpolation='bicubic',
+    ),
+    'csatv2_21m.sw_r512_in1k': _cfg(
         hf_hub_id='timm/',
+        pool_size=(10, 10),
+        interpolation='bicubic',
     ),
-    'csatv2_21m': _cfg(),
 })
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index dd7dee62ce..51fb97c5a2 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -2791,6 +2791,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_dlittle_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
@@ -4324,6 +4327,17 @@ def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model
 
 
+@register_model
+def vit_dlittle_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
+    )
+    model = _create_vision_transformer(
+        'vit_dlittle_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
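
For reference, a minimal sketch of how the new ls_init_value plumbing in csatv2.py could be exercised. It assumes a 'csatv2' entrypoint is registered in the same file (the @register_model functions are outside this diff) and that it forwards extra kwargs such as ls_init_value to the CSAtV2 constructor, as timm entrypoints typically do:

import timm
import torch
import torch.nn as nn

# Assumed entrypoint name and kwarg pass-through; neither is shown in this diff.
model = timm.create_model('csatv2', pretrained=False, ls_init_value=1e-6)

# With ls_init_value set, conv Blocks gain a LayerScale2d (.ls) and
# TransformerBlocks gain LayerScale modules (.ls1/.ls2); with the default
# None they all remain nn.Identity, so prior behaviour is unchanged.
ls_mods = [m for n, m in model.named_modules() if n.split('.')[-1] in ('ls', 'ls1', 'ls2')]
assert ls_mods and not any(isinstance(m, nn.Identity) for m in ls_mods)

with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 1000]) with the default classifier head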
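
Similarly, a small sketch exercising the ViT variant registered above. It relies only on the entrypoint added in this diff and keeps pretrained=False so nothing is pulled from the hub; the sbb_nadamuon_in1k tag in the default cfg is where weights would come from:

import timm
import torch

# Entrypoint added in this diff: differential attention ('diff'),
# embed_dim=320, depth=14, one register token, global average pooling.
model = timm.create_model('vit_dlittle_patch16_reg1_gap_256', pretrained=False)

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000]); 256x256 input per the default cfg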