Skip to content

Commit dc5c798

Browse files
committed
Prep for siglip2 release
1 parent 105a667 commit dc5c798

File tree

2 files changed

+283
-16
lines changed

2 files changed

+283
-16
lines changed

timm/layers/attention_pool.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
pos_embed: str = '',
3030
pool_type: str = 'token',
3131
norm_layer: Optional[nn.Module] = None,
32+
act_layer: Optional[nn.Module] = nn.GELU,
3233
drop: float = 0.0,
3334
):
3435
super().__init__()
@@ -54,13 +55,18 @@ def __init__(
5455

5556
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
5657
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
57-
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
58-
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
58+
if qk_norm:
59+
qk_norm_layer = norm_layer or nn.LayerNorm
60+
self.q_norm = qk_norm_layer(self.head_dim)
61+
self.k_norm = qk_norm_layer(self.head_dim)
62+
else:
63+
self.q_norm = nn.Identity()
64+
self.k_norm = nn.Identity()
5965
self.proj = nn.Linear(embed_dim, embed_dim)
6066
self.proj_drop = nn.Dropout(drop)
6167

6268
self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
63-
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
69+
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)
6470

6571
self.init_weights()
6672

0 commit comments

Comments (0)