Skip to content

Commit 4dc99bd

Browse files
committed
just tempt some student into trying it
1 parent 0791dfe commit 4dc99bd

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

README.md

+9
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,12 @@ docker run -v .:/data --gpus all -it af3
514514
url = {https://api.semanticscholar.org/CorpusID:273532030}
515515
}
516516
```
Added (lines 517–525):

```bibtex
@inproceedings{Duvvuri2024LASERAW,
    title  = {LASER: Attention with Exponential Transformation},
    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:273849947}
}
```

alphafold3_pytorch/attention.py

+19
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def pack_one(t, pattern):
4040
def unpack_one(t, ps, pattern):
4141
return unpack(t, ps, pattern)[0]
4242

Added (lines 43–45):

```python
def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()
```
4346
def softclamp(t, value):
4447
return (t / value).tanh() * value
4548

@@ -181,6 +184,7 @@ def __init__(
181184
query_bias = True,
182185
window_size = None,
183186
num_memory_kv: int = 0,
187+
laser = False,
184188
enable_attn_softclamp = False,
185189
attn_softclamp_value = 50.,
186190
softmax_full_precision = False
@@ -222,6 +226,10 @@ def __init__(
222226
self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
223227
nn.init.normal_(self.memory_kv, std = 0.02)
224228

Added (lines 229–231):

```python
# laser attention

self.laser = laser
```
225233
# gating of value
226234
# allows attention to attend to nothing
227235

@@ -262,6 +270,12 @@ def forward(
262270

263271
q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
264272

Added (lines 273–277):

```python
# maybe laser

if self.laser:
    v_max = v.amax(dim = -2, keepdim = True)
    v = (v - v_max).exp()
```
278+
265279
# attention
266280

267281
out = self.attend(
@@ -272,6 +286,11 @@ def forward(
272286
memory_kv = self.memory_kv
273287
)
274288

Added (lines 289–292):

```python
# maybe laser

if self.laser:
    out = log(out) + v_max
```
293+
275294
# merge heads
276295

277296
out = self.merge_heads(out)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.6.8"
3+
version = "0.6.9"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)