@@ -1732,6 +1732,11 @@ def checkpoint_filter_fn(
17321732 state_dict = _convert_openai_clip (state_dict , model , prefix = 'module.visual.' )
17331733 elif "mask_token" in state_dict :
17341734 state_dict = _convert_dinov2 (state_dict , model )
1735+ elif "vision_encoder.mask_token" in state_dict :
1736+ # TIPSv2 multimodal checkpoint, vision encoder is DINOv2-style under a 'vision_encoder.' prefix
1737+ ve_prefix = 'vision_encoder.'
1738+ state_dict = {k [len (ve_prefix ):]: v for k , v in state_dict .items () if k .startswith (ve_prefix )}
1739+ state_dict = _convert_dinov2 (state_dict , model )
17351740 elif any ('beit3.' in k for k in state_dict .keys ()):
17361741 # BEiT3 model - multimodal checkpoint with beit3.* prefix
17371742 state_dict = _convert_beit3 (state_dict , model )
@@ -2043,6 +2048,29 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
20432048 mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD , num_classes = 0 ,
20442049 input_size = (3 , 518 , 518 ), crop_pct = 1.0 ),
20452050
2051+ # TIPSv2 (DINOv2-style ViT w/ 1 register, no input normalization).
2052+ # Paper: https://arxiv.org/abs/2604.12012 Weights: https://huggingface.co/google/tipsv2-b14
2053+ 'vit_base_patch14_reg1_tipsv2.webli' : _cfg (
2054+ hf_hub_id = 'timm/' ,
2055+ license = 'apache-2.0' ,
2056+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), num_classes = 0 ,
2057+ input_size = (3 , 448 , 448 ), crop_pct = 1.0 ),
2058+ 'vit_large_patch14_reg1_tipsv2.webli' : _cfg (
2059+ hf_hub_id = 'timm/' ,
2060+ license = 'apache-2.0' ,
2061+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), num_classes = 0 ,
2062+ input_size = (3 , 448 , 448 ), crop_pct = 1.0 ),
2063+ 'vit_so400m_patch14_reg1_tipsv2.webli' : _cfg (
2064+ hf_hub_id = 'timm/' ,
2065+ license = 'apache-2.0' ,
2066+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), num_classes = 0 ,
2067+ input_size = (3 , 448 , 448 ), crop_pct = 1.0 ),
2068+ 'vit_giant_patch14_reg1_tipsv2.webli' : _cfg (
2069+ hf_hub_id = 'timm/' ,
2070+ license = 'apache-2.0' ,
2071+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), num_classes = 0 ,
2072+ input_size = (3 , 448 , 448 ), crop_pct = 1.0 ),
2073+
20462074 # ViT ImageNet-21K-P pretraining by MILL
20472075 'vit_base_patch16_224_miil.in21k' : _cfg (
20482076 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth' ,
@@ -3926,6 +3954,61 @@ def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionT
39263954 return model
39273955
39283956
3957+ @register_model
3958+ def vit_base_patch14_reg1_tipsv2 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
3959+ """ ViT-B/14 for TIPSv2 (DINOv2-style w/ 1 register token, LayerScale init=1.0).
3960+ """
3961+ model_args = dict (
3962+ patch_size = 14 , embed_dim = 768 , depth = 12 , num_heads = 12 , init_values = 1.0 ,
3963+ reg_tokens = 1 , no_embed_class = True ,
3964+ )
3965+ model = _create_vision_transformer (
3966+ 'vit_base_patch14_reg1_tipsv2' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
3967+ return model
3968+
3969+
3970+ @register_model
3971+ def vit_large_patch14_reg1_tipsv2 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
3972+ """ ViT-L/14 for TIPSv2 (DINOv2-style w/ 1 register token, LayerScale init=1.0).
3973+ """
3974+ model_args = dict (
3975+ patch_size = 14 , embed_dim = 1024 , depth = 24 , num_heads = 16 , init_values = 1.0 ,
3976+ reg_tokens = 1 , no_embed_class = True ,
3977+ )
3978+ model = _create_vision_transformer (
3979+ 'vit_large_patch14_reg1_tipsv2' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
3980+ return model
3981+
3982+
3983+ @register_model
3984+ def vit_so400m_patch14_reg1_tipsv2 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
3985+ """ SoViT-400M/14 for TIPSv2 (DINOv2-style w/ 1 register token, LayerScale init=1.0).
3986+ """
3987+ model_args = dict (
3988+ patch_size = 14 , embed_dim = 1152 , depth = 27 , num_heads = 16 , init_values = 1.0 ,
3989+ mlp_ratio = 4304 / 1152 , reg_tokens = 1 , no_embed_class = True ,
3990+ )
3991+ model = _create_vision_transformer (
3992+ 'vit_so400m_patch14_reg1_tipsv2' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
3993+ return model
3994+
3995+
3996+ @register_model
3997+ def vit_giant_patch14_reg1_tipsv2 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
3998+ """ ViT-G/14 for TIPSv2 (DINOv2-style w/ SwiGLU FFN, 1 register token, LayerScale init=1.0).
3999+ """
4000+ # SwiGLU hidden after DINOv2's (2/3, align-8) reduction is 4096; SwiGLUPacked fc1 outputs 2*4096=8192,
4001+ # so mlp_ratio = 8192 / 1536 = 16/3.
4002+ model_args = dict (
4003+ patch_size = 14 , embed_dim = 1536 , depth = 40 , num_heads = 24 , init_values = 1.0 ,
4004+ mlp_ratio = 2.66667 * 2 , mlp_layer = SwiGLUPacked , act_layer = nn .SiLU ,
4005+ reg_tokens = 1 , no_embed_class = True ,
4006+ )
4007+ model = _create_vision_transformer (
4008+ 'vit_giant_patch14_reg1_tipsv2' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
4009+ return model
4010+
4011+
39294012@register_model
39304013def vit_base_patch32_siglip_256 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
39314014 model_args = dict (
0 commit comments