Skip to content

Commit 7234f5c

Browse files
committed
Add 448 so150m2 weight/model, add updated internvit 300m weight
1 parent 9ce824c commit 7234f5c

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

timm/models/_hub.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def push_to_hf_hub(
395395

396396

397397
def generate_readme(model_card: dict, model_name: str):
398-
tags = model_card.get('tags', None) or ['image-classification', 'timm']
398+
tags = model_card.get('tags', None) or ['image-classification', 'timm', 'transformers']
399399
readme_text = "---\n"
400400
if tags:
401401
readme_text += "tags:\n"

timm/models/vision_transformer.py

+20
Original file line numberDiff line numberDiff line change
@@ -2174,12 +2174,20 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
21742174
'vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k': _cfg(
21752175
hf_hub_id='timm/',
21762176
input_size=(3, 384, 384), crop_pct=1.0),
2177+
'vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k': _cfg(
2178+
hf_hub_id='timm/',
2179+
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
21772180

21782181
'vit_intern300m_patch14_448.ogvl_dist': _cfg(
21792182
hf_hub_id='timm/',
21802183
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
21812184
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
21822185
),
2186+
'vit_intern300m_patch14_448.ogvl_2pt5': _cfg(
2187+
hf_hub_id='timm/',
2188+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
2189+
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
2190+
),
21832191

21842192
'aimv2_large_patch14_224.apple_pt': _cfg(
21852193
hf_hub_id='timm/',
@@ -3538,6 +3546,18 @@ def vit_so150m2_patch16_reg1_gap_384(pretrained: bool = False, **kwargs) -> Visi
35383546
return model
35393547

35403548

3549+
@register_model
def vit_so150m2_patch16_reg1_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the SO150M v2 ViT registered under the ``_448`` variant name.

    Shape optimized, but differs from the paper definition; optimized for GPU.
    Uses 16x16 patches, no class token, a single register token and average
    global pooling.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; any key given here overrides the
            architecture default of the same name.

    Returns:
        The model produced by ``_create_vision_transformer``.
    """
    # Architecture defaults for this variant.
    arch_defaults = {
        'patch_size': 16,
        'embed_dim': 832,
        'depth': 21,
        'num_heads': 13,
        'mlp_ratio': 34/13,
        'init_values': 1e-5,
        'qkv_bias': False,
        'class_token': False,
        'reg_tokens': 1,
        'global_pool': 'avg',
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    arch_defaults.update(kwargs)
    return _create_vision_transformer(
        'vit_so150m2_patch16_reg1_gap_448',
        pretrained=pretrained,
        **arch_defaults,
    )
3559+
3560+
35413561
@register_model
35423562
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
35433563
model_args = dict(

0 commit comments

Comments
 (0)