Commit 3677f67

Add the 256x256 in1k ft of the so150m, add an alternate so150m def
1 parent 2a84d68 commit 3677f67

File tree: 1 file changed, +22 -2 lines changed


timm/models/vision_transformer.py

@@ -2152,15 +2152,20 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_base_patch16_reg4_gap_256.untrained': _cfg(
         input_size=(3, 256, 256)),

-    'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
+    'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg(
         hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
+    'vit_so150m2_patch16_reg1_gap_256.untrained': _cfg(
+        input_size=(3, 256, 256), crop_pct=0.95),

     'vit_intern300m_patch14_448.ogvl_dist': _cfg(
         hf_hub_id='timm/',
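These `_cfg` entries only declare defaults; they take effect when a model is created with the matching pretrained tag (the part of the name after the `.`). A minimal sketch of how the new 256x256 ft entry surfaces at runtime (my example, not part of the diff; assumes a timm build that includes this commit):

import timm

# The '.sbb_e250_in12k_ft_in1k' suffix selects the pretrained tag added above.
model = timm.create_model(
    'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k',
    pretrained=False,  # set True to pull the actual weights from the HF Hub
)
# Expect input_size=(3, 256, 256) and crop_pct=0.95 from the new entry.
print(model.pretrained_cfg)

With `hf_hub_id='timm/'`, timm resolves the weight repo by appending the full model name, i.e. timm/vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k on the Hugging Face Hub.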
@@ -3467,6 +3472,7 @@ def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
     model_args = dict(
         patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
         class_token=False, reg_tokens=4, global_pool='map',
@@ -3478,6 +3484,7 @@ def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
     model_args = dict(
         patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
         class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
@@ -3489,6 +3496,7 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
     model_args = dict(
         patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
         class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
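The new docstrings all flag the same thing: this SO150M is not the exact shape-optimized definition from the paper, but one tweaked to keep GEMM dimensions GPU-friendly. The fractional mlp_ratio values make more sense after rounding; a quick check (my sketch, not from the commit, assuming timm's usual int(embed_dim * mlp_ratio) hidden-dim computation):

# Hypothetical check, not from the commit: why mlp_ratio looks fractional.
for name, embed_dim, mlp_ratio in [
    ('vit_so150m (depth 18)', 896, 2.572),
    ('vit_so150m2 (depth 20)', 896, 2.429),
]:
    hidden = int(embed_dim * mlp_ratio)
    print(f'{name}: mlp hidden dim = {hidden} = {hidden // 128} * 128')

Both hidden dims (2304 and 2176) land on multiples of 128, and embed_dim=896 with num_heads=14 gives a head dim of 64, which is presumably the "optimized for GPU" part of the note.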
@@ -3498,6 +3506,18 @@ def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=20, num_heads=14, mlp_ratio=2.429, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
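The alternate so150m2 definition keeps the same width (896) but goes slightly deeper (20 blocks vs 18) with slimmer MLPs, a single register token, no QKV bias, and LayerScale enabled via init_values=1e-5. It only registers an .untrained cfg for now, so it instantiates with random weights; a usage sketch (mine, not part of the commit):

import timm
import torch

model = timm.create_model('vit_so150m2_patch16_reg1_gap_256', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)  # matches the .untrained default input_size
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000]), the default num_classes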
