
Commit ebaa6ee

add hyper connections to diffusion transformer

1 parent 85c0de3

File tree: 4 files changed, +64 −21 lines

- README.md (+11)
- alphafold3_pytorch/alphafold3.py (+46 −18)
- pyproject.toml (+2 −1)
- tests/test_af3.py (+5 −2)

README.md (+11)

@@ -523,3 +523,14 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:273849947}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title   = {Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2409.19606},
+    url     = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```

alphafold3_pytorch/alphafold3.py (+46 −18)

@@ -111,18 +111,19 @@
 
 from alphafold3_pytorch.utils.model_utils import distance_to_dgram
 
+# personal libraries
+
 from frame_averaging_pytorch import FrameAverage
 
 from taylor_series_linear_attention import TaylorSeriesLinearAttn
 
 from colt5_attention import ConditionalRoutedAttention
 
-import einx
-from einops import rearrange, repeat, reduce, einsum, pack, unpack
-from einops.layers.torch import Rearrange
+from hyper_connections import HyperConnections
 
-from tqdm import tqdm
+# other external libs
 
+from tqdm import tqdm
 from loguru import logger
 
 from importlib.metadata import version
@@ -132,6 +133,12 @@
 from Bio.PDB.Structure import Structure
 from Bio.PDB.StructureBuilder import StructureBuilder
 
+# einstein notation related
+
+import einx
+from einops import rearrange, repeat, reduce, einsum, pack, unpack
+from einops.layers.torch import Rearrange
+
 """
 global ein notation:
@@ -2008,6 +2015,7 @@ def __init__(
         use_linear_attn = False,
         checkpoint = False,
         add_value_residual = False,
+        num_residual_streams = 1,
         linear_attn_kwargs = dict(
             heads = 8,
             dim_head = 16
@@ -2026,6 +2034,12 @@ def __init__(
 
         dim_single_cond = default(dim_single_cond, dim)
 
+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        # layers
+
         layers = ModuleList([])
 
         for i in range(depth):
@@ -2042,6 +2056,8 @@ def __init__(
                     **linear_attn_kwargs
                 )
 
+                linear_attn = init_hyper_conn(dim = dim, branch = linear_attn)
+
             colt5_attn = None
 
             if use_colt5_attn:
@@ -2051,6 +2067,8 @@ def __init__(
                     **colt5_attn_kwargs
                 )
 
+                colt5_attn = init_hyper_conn(dim = dim, branch = colt5_attn)
+
             accept_value_residual = add_value_residual and not is_first
 
             pair_bias_attn = AttentionPairBias(
@@ -2083,8 +2101,8 @@ def __init__(
             layers.append(ModuleList([
                 linear_attn,
                 colt5_attn,
-                conditionable_pair_bias,
-                conditionable_transition
+                init_hyper_conn(dim = dim, branch = conditionable_pair_bias),
+                init_hyper_conn(dim = dim, branch = conditionable_transition)
             ]))
 
         self.checkpoint = checkpoint
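
For reference, a minimal sketch of how the hyper-connections pieces introduced here compose (the `HyperConnections` calls mirror the ones in this diff; `dim`, `num_residual_streams`, and the toy branch module are illustrative assumptions, not the repository's code):

```python
# Sketch only. The hyper_connections calls follow the usage in this commit;
# dim, num_residual_streams and the toy branch are illustrative assumptions.
import torch
from torch import nn
from hyper_connections import HyperConnections

dim = 64
num_residual_streams = 4

# wrapper factory plus stream expand / reduce functions; with a single stream,
# disable = True falls back to a plain residual connection
init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

# a wrapped branch owns its residual, so call sites no longer add `+ x`
block = init_hyper_conn(dim = dim, branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))

x = torch.randn(2, 16, dim)

x = expand_streams(x)   # fan the representation out into the residual streams
x = block(x)            # branch output is routed back into the streams internally
x = reduce_streams(x)   # collapse the streams back to the original shape

assert x.shape == (2, 16, dim)
```

With the default `num_residual_streams = 1`, the `disable` flag is set, so existing single-stream configs should behave as before this commit.
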
@@ -2112,24 +2130,21 @@ def to_checkpointed_serial_layers(
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):
 
-        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)
-
         wrapped_layers = []
 
         def efficient_attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
                 noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-                noised_repr = fn(noised_repr, mask = mask) + noised_repr
+                noised_repr = fn(noised_repr, mask = mask)
                 return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner
 
         def attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
                 noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-                attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
-                noised_repr = attn_out + noised_repr
+                noised_repr, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
 
                 if self.add_value_residual:
                     maybe_value_residual = default(maybe_value_residual, attn_values)
@@ -2141,10 +2156,12 @@ def transition_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
                 noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-                noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
+                noised_repr = fn(noised_repr, cond = single_repr)
                 return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner
 
+        # wrap layers
+
         for linear_attn, colt5_attn, attn, transition in self.layers:
 
             if exists(linear_attn):
@@ -2156,10 +2173,19 @@ def inner(inputs):
             wrapped_layers.append(attn_wrapper(attn))
             wrapped_layers.append(transition_wrapper(transition))
 
+        # forward
+
+        noised_repr = self.expand_streams(noised_repr)
+
+        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)
+
         for layer in wrapped_layers:
             inputs = checkpoint(layer, inputs)
 
         noised_repr, *_ = inputs
+
+        noised_repr = self.reduce_streams(noised_repr)
+
         return noised_repr
 
     @typecheck
@@ -2175,15 +2201,17 @@ def to_serial_layers(
 
         value_residual = None
 
+        noised_repr = self.expand_streams(noised_repr)
+
         for linear_attn, colt5_attn, attn, transition in self.layers:
 
             if exists(linear_attn):
-                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
+                noised_repr = linear_attn(noised_repr, mask = mask)
 
             if exists(colt5_attn):
-                noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
+                noised_repr = colt5_attn(noised_repr, mask = mask)
 
-            attn_out, attn_values = attn(
+            noised_repr, attn_values = attn(
                 noised_repr,
                 cond = single_repr,
                 pairwise_repr = pairwise_repr,
@@ -2193,15 +2221,15 @@ def to_serial_layers(
                 value_residual = value_residual
             )
 
-            noised_repr = noised_repr + attn_out
-
             if self.add_value_residual:
                 value_residual = default(value_residual, attn_values)
 
             noised_repr = transition(
                 noised_repr,
                 cond = single_repr
-            ) + noised_repr
+            )
+
+        noised_repr = self.reduce_streams(noised_repr)
 
         return noised_repr
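
Both forward paths now follow the same shape: `expand_streams` fans the noised representation out before the layer loop, each wrapped branch handles its own residual (hence the dropped `+ noised_repr` terms), and `reduce_streams` collapses the streams before returning. For the checkpointed path, here is a rough, self-contained sketch of the tuple-threading idiom it relies on, with generic stand-in modules (names and shapes are assumptions, not the repository's code):

```python
# Sketch of threading a tuple of inputs through torch.utils.checkpoint, as the
# checkpointed path above does; toy modules stand in for the wrapped branches.
from functools import wraps

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def layer_wrapper(fn):
    # each wrapper consumes and returns the whole tuple, so checkpoint can replay it
    @wraps(fn)
    def inner(inputs):
        x, cond = inputs
        x = fn(x)  # in the real code, the hyper-connection wrapper adds the residual here
        return x, cond
    return inner

dim = 32
branches = [nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)) for _ in range(3)]
wrapped_layers = [layer_wrapper(branch) for branch in branches]

x = torch.randn(2, 8, dim, requires_grad = True)
cond = torch.randn(2, 8, dim)

inputs = (x, cond)
for layer in wrapped_layers:
    inputs = checkpoint(layer, inputs, use_reentrant = False)

out, _ = inputs
out.sum().backward()  # gradients flow through the recomputed layers
```
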

pyproject.toml (+2 −1)

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.7.2"
+version = "0.7.3"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },
@@ -41,6 +41,7 @@ dependencies = [
     "fair-esm",
     "fastapi",
     "frame-averaging-pytorch>=0.0.18",
+    "hyper-connections>=0.0.20",
     "gradio",
     "gradio_molecule3d",
     "huggingface_hub>=0.21.4",

tests/test_af3.py (+5 −2)

@@ -372,11 +372,13 @@ def test_msa_module(
 @pytest.mark.parametrize('use_linear_attn', (False, True))
 @pytest.mark.parametrize('use_colt5_attn', (False, True))
 @pytest.mark.parametrize('add_value_residual', (False, True))
+@pytest.mark.parametrize('num_residual_streams', (1, 4))
 def test_diffusion_transformer(
     checkpoint,
     use_linear_attn,
     use_colt5_attn,
-    add_value_residual
+    add_value_residual,
+    num_residual_streams
 ):
 
     single = torch.randn(2, 16, 384).requires_grad_()
@@ -389,7 +391,8 @@ def test_diffusion_transformer(
         checkpoint = checkpoint,
         use_linear_attn = use_linear_attn,
         use_colt5_attn = use_colt5_attn,
-        add_value_residual = add_value_residual
+        add_value_residual = add_value_residual,
+        num_residual_streams = num_residual_streams
     )
 
     single_out = diffusion_transformer(