@@ -184,8 +184,6 @@ def __init__(
184
184
query_bias = True ,
185
185
window_size = None ,
186
186
num_memory_kv : int = 0 ,
187
- laser = False ,
188
- laser_softclamp_value = 15. ,
189
187
enable_attn_softclamp = False ,
190
188
attn_softclamp_value = 50. ,
191
189
softmax_full_precision = False ,
@@ -211,8 +209,6 @@ def __init__(
211
209
dropout = dropout ,
212
210
window_size = window_size ,
213
211
enable_attn_softclamp = enable_attn_softclamp ,
214
- laser = laser ,
215
- laser_softclamp_value = laser_softclamp_value ,
216
212
attn_softclamp_value = attn_softclamp_value ,
217
213
softmax_full_precision = softmax_full_precision
218
214
)
@@ -322,8 +318,6 @@ class Attend(Module):
322
318
def __init__ (
323
319
self ,
324
320
dropout = 0. ,
325
- laser = False ,
326
- laser_softclamp_value = 15. ,
327
321
window_size = None ,
328
322
scale : float | None = None ,
329
323
enable_attn_softclamp = False ,
@@ -352,11 +346,6 @@ def __init__(
352
346
353
347
self .attn_dropout = nn .Dropout (dropout )
354
348
355
- # laser attention
356
-
357
- self .laser = laser
358
- self .laser_softclamp_value = laser_softclamp_value
359
-
360
349
# softclamp attention logits
361
350
# being adopted by a number of recent llms (gemma, grok)
362
351
@@ -477,20 +466,10 @@ def local_attn(
477
466
478
467
attn = sim .softmax (dim = - 1 )
479
468
480
- # maybe laser
481
-
482
- if self .laser :
483
- v = softclamp (v , self .laser_softclamp_value )
484
-
485
469
# aggregate
486
470
487
471
out = einsum (attn , v , "... i j, ... j d -> ... i d" )
488
472
489
- # maybe laser
490
-
491
- if self .laser :
492
- out = log (out )
493
-
494
473
# un-window the output
495
474
496
475
out = rearrange (out , "b h n w d -> b h (n w) d" )
@@ -586,19 +565,8 @@ def forward(
586
565
587
566
attn = self .attn_dropout (attn )
588
567
589
- # maybe laser
590
-
591
- if self .laser :
592
- v_max = v .amax (dim = - 2 , keepdim = True )
593
- v = (v - v_max ).exp ()
594
-
595
568
# aggregate values
596
569
597
570
out = einsum (attn , v , "b h i j, b h j d -> b h i d" )
598
571
599
- # maybe laser
600
-
601
- if self .laser :
602
- out = log (out ) + v_max
603
-
604
572
return out
0 commit comments