@@ -34,10 +34,6 @@ def cast_tuple(val, length = 1):
def l2norm(t):
    return F.normalize(t, dim = -1)

-def stable_softmax(t, dim = -1):
-    t = t - t.amax(dim = dim, keepdim = True).detach()
-    return F.softmax(t, dim = dim)
-
# helper classes

class PreNormResidual(nn.Module):
@@ -157,7 +153,7 @@ def forward(self, x, *, xl_memory = None, rel_pos_bias = None):
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

-        attn = stable_softmax(sim)
+        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -273,7 +269,7 @@ def forward(
        # attention (combining local and distant)

        sim = torch.cat((sim_mem, sim), dim = -1)
-        attn = stable_softmax(sim)
+        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        local_attn, mem_attn = attn[..., self.num_retrieved_memories:], attn[..., :self.num_retrieved_memories]
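Note (not part of the commit): the removed stable_softmax helper shifted the logits by their per-row max before calling F.softmax, but PyTorch's built-in softmax already performs that max-subtraction internally, so the helper is redundant and the call sites can use Tensor.softmax directly — presumably the motivation for this change. A minimal sanity-check sketch of that equivalence (the shapes below are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

def stable_softmax(t, dim = -1):
    # the helper removed by this commit: shift logits by their max before softmax
    t = t - t.amax(dim = dim, keepdim = True).detach()
    return F.softmax(t, dim = dim)

sim = torch.randn(2, 8, 16, 16) * 1e4   # large logits, where a naive exp() would overflow
assert torch.allclose(stable_softmax(sim), sim.softmax(dim = -1))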