@@ -40,6 +40,9 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def softclamp(t, value):
     return (t / value).tanh() * value
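The new `log` helper clamps its input before taking the log, so exact zeros map to `log(eps)` instead of `-inf`. A minimal standalone sketch of that behavior (not part of the diff; only `torch` is assumed):

```python
import torch

def log(t, eps = 1e-20):
    # clamp to eps before taking the log so zeros stay finite
    return t.clamp(min = eps).log()

x = torch.tensor([0., 1e-30, 1.])
print(torch.log(x))  # tensor([   -inf, -69.0776,  0.0000])
print(log(x))        # tensor([-46.0517, -46.0517,  0.0000]) - finite everywhere
```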
@@ -181,6 +184,7 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
+        laser = False,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False
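Enabling the feature is a single constructor flag. A hypothetical usage sketch - the class name `Attention` and the `dim` keyword argument are assumptions; `heads`, `dim_head`, and `laser` appear in the surrounding code:

```python
# hypothetical usage - only the `laser` flag is new in this commit
attn = Attention(
    dim = 512,      # assumed model dimension argument
    dim_head = 64,
    heads = 8,
    laser = True    # turns on the exp/log value transformation added below
)
```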
@@ -222,6 +226,10 @@ def __init__(
         self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
         nn.init.normal_(self.memory_kv, std = 0.02)
 
+        # laser attention
+
+        self.laser = laser
+
         # gating of value
         # allows attention to attend to nothing
@@ -262,6 +270,12 @@ def forward(
 
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
 
+        # maybe laser
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True)
+            v = (v - v_max).exp()
+
         # attention
 
         out = self.attend(
@@ -272,6 +286,11 @@ def forward(
             memory_kv = self.memory_kv
         )
 
+        # maybe laser
+
+        if self.laser:
+            out = log(out) + v_max
+
         # merge heads
 
         out = self.merge_heads(out)
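Putting the two forward-pass hunks together: the values are shifted by their per-head max and exponentiated before attention, and the output is mapped back with the clamped `log` plus the stored max, so attention effectively aggregates values in exp space. Below is a self-contained sketch of that computation, assuming plain softmax attention in place of `self.attend` (which additionally handles masks, memory key/values, softclamping, etc.); the function name and tensor shapes are illustrative:

```python
import torch

def laser_attention(q, k, v, eps = 1e-20):
    # shift values by their max over the sequence dim, then exponentiate
    v_max = v.amax(dim = -2, keepdim = True)
    v = (v - v_max).exp()

    # plain scaled dot-product attention with softmax weights
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-1, -2) * scale).softmax(dim = -1)
    out = attn @ v

    # map back: log of the weighted sum, plus the stored max
    return out.clamp(min = eps).log() + v_max

q, k, v = torch.randn(3, 2, 8, 16, 64).unbind(dim = 0)  # (batch, heads, seq, dim_head)
out = laser_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Since the softmax weights sum to one and the exponentiated values lie in (0, 1], the weighted sum is strictly positive, so the clamped `log` only guards against floating-point underflow.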