Commit 85c0de3

upgrade value residual to learnt mixing per token / head
1 parent af6f972 commit 85c0de3

3 files changed: +71 −39 lines

alphafold3_pytorch/alphafold3.py

+46-30
@@ -683,7 +683,7 @@ def forward(
         **kwargs
     ) -> (
         Float['b n d'] |
-        tuple[Float['b n d'], Float['b _ _']]
+        tuple[Float['b n d'], Float['b _ _ _']]
    ):
        x = self.adaptive_norm(x, cond = cond)
 
@@ -797,11 +797,11 @@ def forward(
         pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
         attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
         return_values: bool = False,
-        value_residual: Float['b _ _'] | None = None,
+        value_residual: Float['b _ _ _'] | None = None,
         **kwargs,
     ) -> (
         Float['b n ds'] |
-        tuple[Float['b n ds'], Float['b _ _']]
+        tuple[Float['b n ds'], Float['b _ _ _']]
     ): # type: ignore
 
         """Perform the forward pass.
@@ -961,6 +961,7 @@ def __init__(
         tri_attn_heads = 4,
         dropout_row_prob = 0.25,
         dropout_col_prob = 0.25,
+        accept_value_residual = False
     ):
         super().__init__()
 
@@ -974,7 +975,8 @@ def __init__(
         tri_attn_kwargs = dict(
             dim = dim_pairwise,
             heads = tri_attn_heads,
-            dim_head = tri_attn_dim_head
+            dim_head = tri_attn_dim_head,
+            accept_value_residual = accept_value_residual
         )
 
         self.tri_mult_outgoing = pre_ln(TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, dropout_type = 'row', **tri_mult_kwargs))
@@ -1436,16 +1438,20 @@ def __init__(
             **pair_bias_attn_kwargs
         )
 
-        for _ in range(depth):
+        for i in range(depth):
+
+            is_first = i == 0
+            accept_value_residual = add_value_residual and not is_first
 
             single_pre_ln = partial(PreLayerNorm, dim = dim_single)
 
             pairwise_block = PairwiseBlock(
                 dim_pairwise = dim_pairwise,
+                accept_value_residual = accept_value_residual,
                 **pairwise_block_kwargs
             )
 
-            pair_bias_attn = AttentionPairBias(**pair_bias_attn_kwargs)
+            pair_bias_attn = AttentionPairBias(accept_value_residual = accept_value_residual, **pair_bias_attn_kwargs)
             single_transition = Transition(dim = dim_single)
 
             layers.append(ModuleList([
@@ -1486,10 +1492,11 @@ def to_layers(
 
     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
 
-        value_residual = None
-        pairwise_value_residuals = None
-
         for _ in range(self.recurrent_depth):
+
+            value_residual = None
+            pairwise_value_residuals = None
+
             for (
                 pairwise_block,
                 pair_bias_attn,
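
The hunks above all follow one wiring pattern: only layers after the first are constructed with accept_value_residual = True, and at forward time the first layer's attention values seed the residual consumed by every later layer. Below is a minimal, self-contained sketch of that wiring, not the repository's code; the Attn class and its arguments are illustrative stand-ins for AttentionPairBias.

import torch
from torch import nn

class Attn(nn.Module):
    # stand-in layer: real code computes attention; here we only show how the
    # value residual is accepted, mixed in, and handed back to the caller
    def __init__(self, dim, accept_value_residual = False):
        super().__init__()
        self.to_v = nn.Linear(dim, dim, bias = False)
        self.accept_value_residual = accept_value_residual

    def forward(self, x, value_residual = None, return_values = False):
        v = self.to_v(x)
        orig_v = v  # the unmixed values are what get cached and passed on

        if self.accept_value_residual and value_residual is not None:
            v = 0.5 * (v + value_residual)  # fixed average; the commit upgrades this to a learned gate

        out = x + v
        return (out, orig_v) if return_values else out

depth, dim = 4, 32

# only non-first layers are built to accept a residual, mirroring
# `accept_value_residual = add_value_residual and not is_first` above
layers = nn.ModuleList([Attn(dim, accept_value_residual = i > 0) for i in range(depth)])

x = torch.randn(2, 8, dim)
value_residual = None

for layer in layers:
    x, values = layer(x, value_residual = value_residual, return_values = True)
    # the first layer's values seed the residual used by all later layers
    value_residual = values if value_residual is None else value_residual
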
@@ -1520,54 +1527,59 @@ def to_checkpointed_layers(
 
    ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
 
-        inputs = (single_repr, pairwise_repr, mask, None)
-
        def pairwise_block_wrapper(layer):
            @wraps(layer)
            def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
-                pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
+                pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)
+
+                if self.add_value_residual:
+                    maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values)
+
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
            return inner
 
        def pair_bias_attn_wrapper(layer):
            @wraps(layer)
            def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
                attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
                single_repr = single_repr + attn_out
 
                if self.add_value_residual:
                    maybe_value_residual = default(maybe_value_residual, attn_values)
 
-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
            return inner
 
        def single_transition_wrapper(layer):
            @wraps(layer)
            def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
                single_repr = layer(single_repr) + single_repr
-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
            return inner
 
        wrapped_layers = []
 
+        for (
+            pairwise_block,
+            pair_bias_attn,
+            single_transition
+        ) in self.layers:
+
+            wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
+            wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
+            wrapped_layers.append(single_transition_wrapper(single_transition))
+
        for _ in range(self.recurrent_depth):
-            for (
-                pairwise_block,
-                pair_bias_attn,
-                single_transition
-            ) in self.layers:
+            inputs = (single_repr, pairwise_repr, mask, None, None)
 
-                wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
-                wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
-                wrapped_layers.append(single_transition_wrapper(single_transition))
+            for layer in wrapped_layers:
+                inputs = checkpoint(layer, inputs)
 
-            for layer in wrapped_layers:
-                inputs = checkpoint(layer, inputs)
+        single_repr, pairwise_repr, *_ = inputs
 
-        single_repr, pairwise_repr, *_ = inputs
        return single_repr, pairwise_repr
 
    @typecheck
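
The to_checkpointed_layers hunk relies on a simple rule: anything that must flow between checkpointed layers (here, the cached attention values acting as residuals) has to travel inside the single inputs tuple, because torch.utils.checkpoint only re-feeds the arguments it is explicitly given. A minimal sketch of that pattern, assuming nothing beyond PyTorch; ToyLayer and layer_wrapper are hypothetical names, not the repository's.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.proj(x)

def layer_wrapper(layer):
    # unpack the carried tuple, run the layer, repack; the extra slot plays the
    # role of maybe_value_residual above
    def inner(inputs):
        x, maybe_state = inputs
        x = layer(x)

        if maybe_state is None:
            maybe_state = x  # the first wrapped layer seeds the shared state

        return x, maybe_state
    return inner

layers = [ToyLayer(16) for _ in range(3)]

x = torch.randn(2, 4, 16, requires_grad = True)
inputs = (x, None)

for layer in layers:
    # everything a later layer needs must travel inside `inputs`
    inputs = checkpoint(layer_wrapper(layer), inputs, use_reentrant = False)

x, state = inputs
x.sum().backward()  # gradients flow back through the checkpointed stack
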
@@ -2016,7 +2028,8 @@ def __init__(
 
        layers = ModuleList([])
 
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
 
            linear_attn = None
 
@@ -2038,12 +2051,15 @@ def __init__(
                **colt5_attn_kwargs
            )
 
+            accept_value_residual = add_value_residual and not is_first
+
            pair_bias_attn = AttentionPairBias(
                dim = dim,
                dim_pairwise = dim_pairwise,
                heads = heads,
                window_size = attn_window_size,
                num_memory_kv = attn_num_memory_kv,
+                accept_value_residual = accept_value_residual,
                **attn_pair_bias_kwargs
            )
 
alphafold3_pytorch/attention.py

+24-8
@@ -188,7 +188,8 @@ def __init__(
        laser_softclamp_value = 15.,
        enable_attn_softclamp = False,
        attn_softclamp_value = 50.,
-        softmax_full_precision = False
+        softmax_full_precision = False,
+        accept_value_residual = False
    ):
        super().__init__()
        """
@@ -237,6 +238,18 @@ def __init__(
        if gate_output:
            self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())
 
+        # learned value residual mixing
+        # even greater improvements on top of value residual learning, discovered by open source community
+
+        self.accept_value_residual = accept_value_residual
+
+        if accept_value_residual:
+            self.to_value_residual_mix = nn.Sequential(
+                LinearNoBias(dim, heads),
+                Rearrange('b n h -> b h n 1'),
+                nn.Sigmoid()
+            )
+
    @typecheck
    def forward(
        self,
@@ -246,28 +259,31 @@ def forward(
246259
windowed_mask: Bool['b nw w (w*2)'] | None = None,
247260
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
248261
return_values: bool = False,
249-
value_residual: Float['b j dh'] | None = None,
262+
value_residual: Float['b h j dh'] | None = None,
250263

251264
) -> (
252265
Float['b i d'] |
253-
tuple[Float['b i d'], Float['b j _']]
266+
tuple[Float['b i d'], Float['b h j dh']]
254267
):
255268

256269
q = self.to_q(seq)
257270

258271
context_seq = default(context, seq)
259272
k, v = self.to_kv(context_seq).chunk(2, dim = -1)
260273

274+
# split heads
275+
276+
q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
277+
261278
# handle value residual
262279

263280
orig_v = v
264281

265-
if exists(value_residual):
266-
v = 0.5 * (v + value_residual)
282+
assert not (self.accept_value_residual ^ exists(value_residual))
267283

268-
# split heads
269-
270-
q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
284+
if exists(value_residual):
285+
mix = self.to_value_residual_mix(seq)
286+
v = v.lerp(value_residual, mix)
271287

272288
# attention
273289

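The core of this change is the learned mix itself: a per-token, per-head gate projected from the input sequence decides how much of the first layer's values to blend into the current layer's values, replacing the previous fixed 0.5 average. A minimal sketch of just that piece, assuming standard torch and einops; shapes follow the diff, with values shaped (batch, heads, seq, dim_head).

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 64, 8, 32

# one gate per token and per head, squashed to [0, 1]
to_value_residual_mix = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq = torch.randn(2, 16, dim)                          # token representations
v = torch.randn(2, heads, 16, dim_head)                # this layer's values, already split by head
value_residual = torch.randn(2, heads, 16, dim_head)   # values cached from the first layer

mix = to_value_residual_mix(seq)     # (b, h, n, 1)
v = v.lerp(value_residual, mix)      # v * (1 - mix) + value_residual * mix, instead of a fixed 0.5 average
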
pyproject.toml

+1-1
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.7.0"
+version = "0.7.2"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },
