Merged
34 changes: 25 additions & 9 deletions timm/models/csatv2.py
@@ -20,7 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead
from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
from timm.layers.grn import GlobalResponseNorm
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
@@ -322,6 +322,7 @@ def __init__(
self,
dim: int,
drop_path: float = 0.,
ls_init_value: Optional[float] = None,
device=None,
dtype=None,
) -> None:
@@ -333,6 +334,7 @@ def __init__(
self.act = nn.GELU()
self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.attn = SpatialAttention(**dd)

@@ -350,6 +352,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
attn = self.attn(x)
attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
x = x * attn
x = self.ls(x)

return shortcut + self.drop_path(x)
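
Note: the new ls_init_value argument gates a LayerScale2d applied to the conv block's residual branch just before the add. As an illustration only (not the timm implementation), a LayerScale2d-style module scales each channel of an NCHW tensor by a learnable gamma initialized to a small value, so the scaled branch starts close to identity:

import torch
import torch.nn as nn

class LayerScale2dSketch(nn.Module):
    """Per-channel learnable scaling for NCHW tensors (illustrative sketch)."""
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # gamma starts small so the scaled branch barely perturbs the shortcut
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # broadcast gamma over batch and spatial dims: (1, C, 1, 1)
        return x * self.gamma.view(1, -1, 1, 1)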

@@ -442,6 +445,7 @@ def __init__(
attn_drop: float = 0.,
proj_drop: float = 0.,
drop_path: float = 0.,
ls_init_value: Optional[float] = None,
device=None,
dtype=None,
) -> None:
@@ -470,10 +474,12 @@ def __init__(
proj_drop=proj_drop,
**dd,
)
self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self.norm2 = nn.LayerNorm(oup, **dd)
self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -484,7 +490,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x_t = x_t.flatten(2).transpose(1, 2)
x_t = self.norm1(x_t)
x_t = self.pos_embed(x_t, (H, W))
x_t = self.attn(x_t)
x_t = self.ls1(self.attn(x_t))
x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
x = shortcut + self.drop_path1(x_t)
else:
@@ -493,15 +499,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x_t = x.flatten(2).transpose(1, 2)
x_t = self.norm1(x_t)
x_t = self.pos_embed(x_t, (H, W))
x_t = self.attn(x_t)
x_t = self.ls1(self.attn(x_t))
x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
x = shortcut + self.drop_path1(x_t)

# MLP block
B, C, H, W = x.shape
shortcut = x
x_t = x.flatten(2).transpose(1, 2)
x_t = self.mlp(self.norm2(x_t))
x_t = self.ls2(self.mlp(self.norm2(x_t)))
x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
x = shortcut + self.drop_path2(x_t)
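
Note: ls1 and ls2 use the channels-last LayerScale variant because the attention and MLP sub-blocks here operate on (B, N, C) token sequences rather than NCHW maps. A rough sketch of that channels-last counterpart, again illustrative rather than the timm code:

import torch
import torch.nn as nn

class LayerScaleSketch(nn.Module):
    """Per-channel learnable scaling for (B, N, C) token sequences (illustrative)."""
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gamma broadcasts over batch and token dims; the residual pattern is
        # x = shortcut + drop_path(gamma * f(norm(x)))
        return x * self.gamma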

@@ -545,6 +551,7 @@ def __init__(
transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
drop_path_rate: float = 0.0,
transformer_drop_path: bool = False,
ls_init_value: Optional[float] = None,
global_pool: str = 'avg',
device=None,
dtype=None,
@@ -590,9 +597,9 @@ def __init__(
# Downsample at start of stage (except first stage)
([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
# Conv blocks
[Block(dim=dim, drop_path=next(dp_iter), **dd) for _ in range(depth - t_depth)] +
[Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
# Transformer blocks at end of stage
[TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), **dd) for _ in range(t_depth)] +
[TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
# Trailing LayerNorm (except last stage)
([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
)
@@ -726,7 +733,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 512, 512),
'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
'interpolation': 'bilinear', 'crop_pct': 1.0,
'classifier': 'head.fc', 'first_conv': [],
@@ -735,10 +742,19 @@ def _cfg(url='', **kwargs):


default_cfgs = generate_default_cfgs({
'csatv2': _cfg(
'csatv2.r512_in1k': _cfg(
hf_hub_id='timm/',
),
'csatv2_21m.sw_r640_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 640, 640),
interpolation='bicubic',
),
'csatv2_21m.sw_r512_in1k': _cfg(
hf_hub_id='timm/',
pool_size=(10, 10),
interpolation='bicubic',
),
'csatv2_21m': _cfg(),
})
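
Note: with the renamed and newly added pretrained tags, the models resolve through the usual timm factory. A hedged usage example; it assumes the tagged checkpoints are actually published under the 'timm/' hub namespace:

import timm

# Pretrained load only works once the corresponding weights exist on the hub.
model = timm.create_model('csatv2_21m.sw_r512_in1k', pretrained=True).eval()

# Instantiating from scratch needs no hub access; the bare architecture name resolves to its default tag.
scratch = timm.create_model('csatv2', pretrained=False, num_classes=10)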


14 changes: 14 additions & 0 deletions timm/models/vision_transformer.py
@@ -2791,6 +2791,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_dlittle_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
@@ -4324,6 +4327,17 @@ def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
return model


@register_model
def vit_dlittle_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
)
model = _create_vision_transformer(
'vit_dlittle_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
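
Note: a quick sanity check for the new registration, with pretrained=False since availability of the .sbb_nadamuon_in1k weights is an assumption here; attn_layer='diff' selects a non-default attention implementation per the model args above:

import timm
import torch

model = timm.create_model('vit_dlittle_patch16_reg1_gap_256', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])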


@register_model
def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(