Commit c5d1f7b

remove laser attention
1 parent 119fcb9 commit c5d1f7b

3 files changed (+1, -42 lines)

README.md (-9)
@@ -515,15 +515,6 @@ docker run -v .:/data --gpus all -it af3
 }
 ```
 
-```bibtex
-@inproceedings{Duvvuri2024LASERAW,
-    title = {LASER: Attention with Exponential Transformation},
-    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
-    year = {2024},
-    url = {https://api.semanticscholar.org/CorpusID:273849947}
-}
-```
-
 ```bibtex
 @article{Zhu2024HyperConnections,
     title = {Hyper-Connections},

alphafold3_pytorch/attention.py (-32)
@@ -184,8 +184,6 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
-        laser = False,
-        laser_softclamp_value = 15.,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False,
@@ -211,8 +209,6 @@ def __init__(
             dropout = dropout,
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
-            laser = laser,
-            laser_softclamp_value = laser_softclamp_value,
             attn_softclamp_value = attn_softclamp_value,
             softmax_full_precision = softmax_full_precision
         )
@@ -322,8 +318,6 @@ class Attend(Module):
     def __init__(
         self,
         dropout = 0.,
-        laser = False,
-        laser_softclamp_value = 15.,
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
@@ -352,11 +346,6 @@ def __init__(
 
         self.attn_dropout = nn.Dropout(dropout)
 
-        # laser attention
-
-        self.laser = laser
-        self.laser_softclamp_value = laser_softclamp_value
-
         # softclamp attention logits
         # being adopted by a number of recent llms (gemma, grok)
 
@@ -477,20 +466,10 @@ def local_attn(
 
         attn = sim.softmax(dim = -1)
 
-        # maybe laser
-
-        if self.laser:
-            v = softclamp(v, self.laser_softclamp_value)
-
         # aggregate
 
         out = einsum(attn, v, "... i j, ... j d -> ... i d")
 
-        # maybe laser
-
-        if self.laser:
-            out = log(out)
-
         # un-window the output
 
         out = rearrange(out, "b h n w d -> b h (n w) d")
@@ -586,19 +565,8 @@ def forward(
 
         attn = self.attn_dropout(attn)
 
-        # maybe laser
-
-        if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True)
-            v = (v - v_max).exp()
-
         # aggregate values
 
         out = einsum(attn, v, "b h i j, b h j d -> b h i d")
 
-        # maybe laser
-
-        if self.laser:
-            out = log(out) + v_max
-
         return out
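
For context, the deleted branch implemented LASER-style attention: the softmax weights aggregate exp(v) rather than v, and the result is mapped back with a log, subtracting the per-key max before exponentiating so the exp cannot overflow (the windowed local_attn path instead softclamped v to laser_softclamp_value). Below is a minimal, self-contained sketch of that removed path; the laser_attend name, tensor shapes, and plain-PyTorch matmul formulation are illustrative assumptions, not the repository's API.

```python
import torch

def laser_attend(q, k, v, scale = None):
    # q, k, v: (batch, heads, seq, dim) - shapes assumed for illustration
    scale = q.shape[-1] ** -0.5 if scale is None else scale

    # standard softmax attention weights
    sim = (q @ k.transpose(-1, -2)) * scale
    attn = sim.softmax(dim = -1)

    # LASER: exponentiate the values, subtracting the per-key max for numerical stability
    v_max = v.amax(dim = -2, keepdim = True)
    v_exp = (v - v_max).exp()

    # aggregate in exp-space, then map back with log and restore the subtracted max
    out = attn @ v_exp
    return out.log() + v_max

q = k = v = torch.randn(1, 8, 16, 64)
out = laser_attend(q, k, v)    # (1, 8, 16, 64)
```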

pyproject.toml (+1, -1)
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.7.4"
+version = "0.7.5"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },

0 commit comments
