Skip to content

Commit c42884e

Browse files
committed
complete hyper connected alphafold3
1 parent e18a330 commit c42884e

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

alphafold3_pytorch/alphafold3.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119

120120
from colt5_attention import ConditionalRoutedAttention
121121

122-
from hyper_connections import HyperConnections
122+
from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections
123123

124124
# other external libs
125125

@@ -995,8 +995,8 @@ def __init__(
995995
@typecheck
996996
def forward(
997997
self,
998-
*,
999998
pairwise_repr: Float['b n n d'],
999+
*,
10001000
mask: Bool['b n'] | None = None,
10011001
value_residuals: tuple[Tensor, Tensor] | None = None,
10021002
return_values = False,
@@ -1470,8 +1470,8 @@ def __init__(
14701470
single_transition = Transition(dim = dim_single)
14711471

14721472
layers.append(ModuleList([
1473-
pairwise_block,
1474-
init_hyper_conn(dim = dim_single, branch = single_pre_ln(pair_bias_attn)),
1473+
init_hyper_conn(dim = dim_pairwise, branch = pairwise_block),
1474+
init_hyper_conn(dim = dim_single, additional_input_paths = [('pairwise_repr', dim_pairwise)], branch = single_pre_ln(pair_bias_attn)),
14751475
init_hyper_conn(dim = dim_single, branch = single_pre_ln(single_transition)),
14761476
]))
14771477

@@ -1508,6 +1508,7 @@ def to_layers(
15081508
) -> Tuple[Float['b n ds'], Float['b n n dp']]:
15091509

15101510
single_repr = self.expand_streams(single_repr)
1511+
pairwise_repr = self.expand_streams(pairwise_repr)
15111512

15121513
for _ in range(self.recurrent_depth):
15131514

@@ -1520,7 +1521,7 @@ def to_layers(
15201521
single_transition
15211522
) in self.layers:
15221523

1523-
pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr = pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True)
1524+
pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True)
15241525

15251526
single_repr, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)
15261527

@@ -1531,6 +1532,7 @@ def to_layers(
15311532
single_repr = single_transition(single_repr)
15321533

15331534
single_repr = self.reduce_streams(single_repr)
1535+
pairwise_repr = self.reduce_streams(pairwise_repr)
15341536

15351537
return single_repr, pairwise_repr
15361538

@@ -1548,7 +1550,7 @@ def pairwise_block_wrapper(layer):
15481550
@wraps(layer)
15491551
def inner(inputs, *args, **kwargs):
15501552
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
1551-
pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)
1553+
pairwise_repr, pairwise_attn_values = layer(pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)
15521554

15531555
if self.add_value_residual:
15541556
maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values)
@@ -1589,6 +1591,7 @@ def inner(inputs, *args, **kwargs):
15891591
wrapped_layers.append(single_transition_wrapper(single_transition))
15901592

15911593
single_repr = self.expand_streams(single_repr)
1594+
pairwise_repr = self.expand_streams(pairwise_repr)
15921595

15931596
for _ in range(self.recurrent_depth):
15941597
inputs = (single_repr, pairwise_repr, mask, None, None)
@@ -1599,6 +1602,7 @@ def inner(inputs, *args, **kwargs):
15991602
single_repr, pairwise_repr, *_ = inputs
16001603

16011604
single_repr = self.reduce_streams(single_repr)
1605+
pairwise_repr = self.reduce_streams(pairwise_repr)
16021606

16031607
return single_repr, pairwise_repr
16041608

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.7.7"
3+
version = "0.7.8"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },
@@ -41,7 +41,7 @@ dependencies = [
4141
"fair-esm",
4242
"fastapi",
4343
"frame-averaging-pytorch>=0.0.18",
44-
"hyper-connections>=0.0.21",
44+
"hyper-connections>=0.0.23",
4545
"gradio",
4646
"gradio_molecule3d",
4747
"huggingface_hub>=0.21.4",

0 commit comments

Comments
 (0)