Skip to content

Commit b82e345

Browse files
committed
hyperconnect single repr in pairformerstack
1 parent c5d1f7b commit b82e345

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

alphafold3_pytorch/alphafold3.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -1430,10 +1430,18 @@ def __init__(
14301430
num_register_tokens = 0,
14311431
checkpoint = False,
14321432
add_value_residual = False,
1433+
num_residual_streams = 1,
14331434
pairwise_block_kwargs: dict = dict(),
14341435
pair_bias_attn_kwargs: dict = dict()
14351436
):
14361437
super().__init__()
1438+
1439+
# residual / hyper connections
1440+
1441+
init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
1442+
1443+
# layers
1444+
14371445
layers = ModuleList([])
14381446

14391447
pair_bias_attn_kwargs = dict(
@@ -1463,8 +1471,8 @@ def __init__(
14631471

14641472
layers.append(ModuleList([
14651473
pairwise_block,
1466-
single_pre_ln(pair_bias_attn),
1467-
single_pre_ln(single_transition),
1474+
init_hyper_conn(dim = dim_single, branch = single_pre_ln(pair_bias_attn)),
1475+
init_hyper_conn(dim = dim_single, branch = single_pre_ln(single_transition)),
14681476
]))
14691477

14701478
self.layers = layers
@@ -1499,6 +1507,8 @@ def to_layers(
14991507

15001508
) -> Tuple[Float['b n ds'], Float['b n n dp']]:
15011509

1510+
single_repr = self.expand_streams(single_repr)
1511+
15021512
for _ in range(self.recurrent_depth):
15031513

15041514
value_residual = None
@@ -1522,6 +1532,8 @@ def to_layers(
15221532

15231533
single_repr = single_transition(single_repr) + single_repr
15241534

1535+
single_repr = self.reduce_streams(single_repr)
1536+
15251537
return single_repr, pairwise_repr
15261538

15271539
@typecheck
@@ -1550,8 +1562,7 @@ def pair_bias_attn_wrapper(layer):
15501562
@wraps(layer)
15511563
def inner(inputs, *args, **kwargs):
15521564
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
1553-
attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
1554-
single_repr = single_repr + attn_out
1565+
single_repr, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
15551566

15561567
if self.add_value_residual:
15571568
maybe_value_residual = default(maybe_value_residual, attn_values)
@@ -1563,7 +1574,7 @@ def single_transition_wrapper(layer):
15631574
@wraps(layer)
15641575
def inner(inputs, *args, **kwargs):
15651576
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
1566-
single_repr = layer(single_repr) + single_repr
1577+
single_repr = layer(single_repr)
15671578
return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
15681579
return inner
15691580

@@ -1579,6 +1590,8 @@ def inner(inputs, *args, **kwargs):
15791590
wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
15801591
wrapped_layers.append(single_transition_wrapper(single_transition))
15811592

1593+
single_repr = self.expand_streams(single_repr)
1594+
15821595
for _ in range(self.recurrent_depth):
15831596
inputs = (single_repr, pairwise_repr, mask, None, None)
15841597

@@ -1587,6 +1600,8 @@ def inner(inputs, *args, **kwargs):
15871600

15881601
single_repr, pairwise_repr, *_ = inputs
15891602

1603+
single_repr = self.reduce_streams(single_repr)
1604+
15901605
return single_repr, pairwise_repr
15911606

15921607
@typecheck

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.7.5"
3+
version = "0.7.6"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,13 @@ def test_centre_random_augmentation():
303303
@pytest.mark.parametrize('recurrent_depth', (1, 2))
304304
@pytest.mark.parametrize('enable_attn_softclamp', (True, False))
305305
@pytest.mark.parametrize('add_value_residual', (True, False))
306+
@pytest.mark.parametrize('num_residual_streams', (1, 4))
306307
def test_pairformer(
307308
checkpoint,
308309
recurrent_depth,
309310
enable_attn_softclamp,
310-
add_value_residual
311+
add_value_residual,
312+
num_residual_streams
311313
):
312314
single = torch.randn(2, 16, 384).requires_grad_()
313315
pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
@@ -319,6 +321,7 @@ def test_pairformer(
319321
recurrent_depth = recurrent_depth,
320322
checkpoint = checkpoint,
321323
add_value_residual = add_value_residual,
324+
num_residual_streams = num_residual_streams,
322325
pair_bias_attn_kwargs = dict(
323326
enable_attn_softclamp = enable_attn_softclamp
324327
)

0 commit comments

Comments
 (0)