
Commit 8674fa4

Add csatv2_21m weights at 512 & 640 img size. Add layer-scale support to csatv2 but not used yet.
1 parent 940edf4 commit 8674fa4
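
For context on the second half of the commit message: "layer scale" here is the CaiT-style trick of multiplying each residual branch by a small learnable per-channel weight, so a branch starts close to an identity mapping. The diff below threads an ls_init_value argument through the conv and transformer blocks and gates timm's LayerScale / LayerScale2d modules on it. A rough standalone sketch of the pattern, for illustration only (the class below is not timm's implementation):

import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    """Illustrative layer scale: scale each channel of a residual branch by a learnable gamma."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # Small init keeps the branch contribution near zero at the start of training.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token layout (B, N, dim); a channels-first variant (cf. LayerScale2d in the
        # diff) would broadcast gamma as a (1, dim, 1, 1) view instead.
        return x * self.gamma


# Typical residual wiring, mirroring the blocks touched in the diff:
#   x = shortcut + drop_path(ls(branch(x)))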

File tree

1 file changed (+25, -9 lines)


timm/models/csatv2.py

Lines changed: 25 additions & 9 deletions
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead
+from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
 from timm.layers.grn import GlobalResponseNorm
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
@@ -322,6 +322,7 @@ def __init__(
             self,
             dim: int,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -333,6 +334,7 @@ def __init__(
         self.act = nn.GELU()
         self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
         self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
+        self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.attn = SpatialAttention(**dd)
 
@@ -350,6 +352,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn = self.attn(x)
         attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
         x = x * attn
+        x = self.ls(x)
 
         return shortcut + self.drop_path(x)
 
@@ -442,6 +445,7 @@ def __init__(
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             drop_path: float = 0.,
+            ls_init_value: Optional[float] = None,
             device=None,
             dtype=None,
     ) -> None:
@@ -470,10 +474,12 @@ def __init__(
             proj_drop=proj_drop,
             **dd,
         )
+        self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = nn.LayerNorm(oup, **dd)
         self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
+        self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -484,7 +490,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x_t.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)
         else:
@@ -493,15 +499,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x_t = x.flatten(2).transpose(1, 2)
             x_t = self.norm1(x_t)
             x_t = self.pos_embed(x_t, (H, W))
-            x_t = self.attn(x_t)
+            x_t = self.ls1(self.attn(x_t))
             x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
             x = shortcut + self.drop_path1(x_t)
 
         # MLP block
         B, C, H, W = x.shape
         shortcut = x
         x_t = x.flatten(2).transpose(1, 2)
-        x_t = self.mlp(self.norm2(x_t))
+        x_t = self.ls2(self.mlp(self.norm2(x_t)))
         x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
         x = shortcut + self.drop_path2(x_t)
 
@@ -545,6 +551,7 @@ def __init__(
             transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
             drop_path_rate: float = 0.0,
             transformer_drop_path: bool = False,
+            ls_init_value: Optional[float] = None,
             global_pool: str = 'avg',
             device=None,
             dtype=None,
@@ -590,9 +597,9 @@ def __init__(
                 # Downsample at start of stage (except first stage)
                 ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
                 # Conv blocks
-                [Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
+                [Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
                 # Transformer blocks at end of stage
-                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
+                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
                 # Trailing LayerNorm (except last stage)
                 ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
             )
@@ -726,7 +733,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 512, 512),
+        'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
         'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
         'interpolation': 'bilinear', 'crop_pct': 1.0,
         'classifier': 'head.fc', 'first_conv': [],
@@ -735,10 +742,19 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
-    'csatv2': _cfg(
+    'csatv2.r512_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'csatv2_21m.sw_r640_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 640, 640),
+        interpolation='bicubic',
+    ),
+    'csatv2_21m.sw_r512_in1k': _cfg(
         hf_hub_id='timm/',
+        pool_size=(10, 10),
+        interpolation='bicubic',
     ),
-    'csatv2_21m': _cfg(),
 })
 
 
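Once the referenced checkpoints are published on the Hub, the new tags should load through the usual timm entrypoint. A hedged usage sketch (tag names taken from the default_cfgs diff above; the ls_init_value override assumes the csatv2 entrypoint forwards extra kwargs to the model constructor):

import timm

# New csatv2_21m weight tags added by this commit.
model_640 = timm.create_model('csatv2_21m.sw_r640_in1k', pretrained=True)
model_512 = timm.create_model('csatv2_21m.sw_r512_in1k', pretrained=True)

# Layer scale is wired up but "not used yet"; enabling it for a from-scratch model
# presumably means passing ls_init_value through create_model (assumption, as noted above).
model_ls = timm.create_model('csatv2', pretrained=False, ls_init_value=1e-5)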