@@ -29,6 +29,7 @@ def __init__(
             pos_embed: str = '',
             pool_type: str = 'token',
             norm_layer: Optional[nn.Module] = None,
+            act_layer: Optional[nn.Module] = nn.GELU,
             drop: float = 0.0,
     ):
         super().__init__()
@@ -54,13 +55,18 @@ def __init__(

         self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
         self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        if qk_norm:
+            qk_norm_layer = norm_layer or nn.LayerNorm
+            self.q_norm = qk_norm_layer(self.head_dim)
+            self.k_norm = qk_norm_layer(self.head_dim)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
         self.proj = nn.Linear(embed_dim, embed_dim)
         self.proj_drop = nn.Dropout(drop)

         self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)

         self.init_weights()

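A minimal, standalone sketch of the q/k-norm change above (the build_qk_norms helper is hypothetical, written only for illustration, not code from this repository): before the patch, qk_norm=True combined with norm_layer=None would effectively call None(self.head_dim) and raise a TypeError; after the patch, the norm type falls back to nn.LayerNorm.

# Illustration of the patched constructor logic, assuming only what the diff shows.
from typing import Optional, Type

import torch
import torch.nn as nn


def build_qk_norms(
        head_dim: int,
        qk_norm: bool = False,
        norm_layer: Optional[Type[nn.Module]] = None,
):
    # Mirrors the new branch: use the provided norm_layer for q/k normalization
    # when qk_norm is set, otherwise fall back to nn.LayerNorm instead of failing.
    if qk_norm:
        qk_norm_layer = norm_layer or nn.LayerNorm
        q_norm = qk_norm_layer(head_dim)
        k_norm = qk_norm_layer(head_dim)
    else:
        q_norm = nn.Identity()
        k_norm = nn.Identity()
    return q_norm, k_norm


# With qk_norm=True and no norm_layer supplied, LayerNorm is used as the default.
q_norm, k_norm = build_qk_norms(head_dim=64, qk_norm=True, norm_layer=None)
print(q_norm)                                      # LayerNorm((64,), ...)
print(k_norm(torch.randn(2, 196, 64)).shape)       # torch.Size([2, 196, 64])

The act_layer change in the first hunk is analogous: the constructor now forwards a configurable activation (defaulting to nn.GELU) to the Mlp instead of relying on Mlp's own default.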