from alphafold3_pytorch.utils.model_utils import distance_to_dgram

+ # personal libraries
+
from frame_averaging_pytorch import FrameAverage

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from colt5_attention import ConditionalRoutedAttention

- import einx
- from einops import rearrange, repeat, reduce, einsum, pack, unpack
- from einops.layers.torch import Rearrange
+ from hyper_connections import HyperConnections

- from tqdm import tqdm
+ # other external libs
+
+ from tqdm import tqdm

from loguru import logger

from importlib.metadata import version

from Bio.PDB.Structure import Structure

from Bio.PDB.StructureBuilder import StructureBuilder

+ # einstein notation related
+
+ import einx
+ from einops import rearrange, repeat, reduce, einsum, pack, unpack
+ from einops.layers.torch import Rearrange
+

"""
global ein notation:
@@ -2008,6 +2015,7 @@ def __init__(
        use_linear_attn = False,
        checkpoint = False,
        add_value_residual = False,
+       num_residual_streams = 1,
        linear_attn_kwargs = dict(
            heads = 8,
            dim_head = 16
@@ -2026,6 +2034,12 @@ def __init__(

        dim_single_cond = default(dim_single_cond, dim)

+       # hyper connections
+
+       init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+       # layers
+
        layers = ModuleList([])

        for i in range(depth):
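Note: `HyperConnections.get_init_and_expand_reduce_stream_functions` hands back three helpers: an initializer that wraps a branch module so the residual connection lives inside the wrapper, plus `expand_streams` / `reduce_streams` for fanning the single residual out into multiple streams before the layer stack and collapsing them back afterwards. Passing `disable = True` (the `num_residual_streams == 1` case above) makes the wrapper fall back to an ordinary residual. A minimal standalone sketch, assuming only the public API exercised in this diff (the 512-dim linear branch is purely illustrative):

import torch
from torch import nn
from hyper_connections import HyperConnections

init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(4)

# the wrapped branch adds the residual internally
branch = init_hyper_conn(dim = 512, branch = nn.Linear(512, 512))

x = torch.randn(2, 1024, 512)

x = expand_streams(x)  # 1 residual stream -> 4
x = branch(x)          # no external `+ x` needed
x = reduce_streams(x)  # 4 streams -> 1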
@@ -2042,6 +2056,8 @@ def __init__(
                    **linear_attn_kwargs
                )

+               linear_attn = init_hyper_conn(dim = dim, branch = linear_attn)
+
            colt5_attn = None

            if use_colt5_attn:
@@ -2051,6 +2067,8 @@ def __init__(
                    **colt5_attn_kwargs
                )

+               colt5_attn = init_hyper_conn(dim = dim, branch = colt5_attn)
+
            accept_value_residual = add_value_residual and not is_first

            pair_bias_attn = AttentionPairBias(
@@ -2083,8 +2101,8 @@ def __init__(
            layers.append(ModuleList([
                linear_attn,
                colt5_attn,
-               conditionable_pair_bias,
-               conditionable_transition
+               init_hyper_conn(dim = dim, branch = conditionable_pair_bias),
+               init_hyper_conn(dim = dim, branch = conditionable_transition)
            ]))

        self.checkpoint = checkpoint
@@ -2112,24 +2130,21 @@ def to_checkpointed_serial_layers(
        windowed_mask: Bool['b nw w (w*2)'] | None = None
    ):

-       inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)
-
        wrapped_layers = []

        def efficient_attn_wrapper(fn):
            @wraps(fn)
            def inner(inputs):
                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-               noised_repr = fn(noised_repr, mask = mask) + noised_repr
+               noised_repr = fn(noised_repr, mask = mask)
                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
            return inner

        def attn_wrapper(fn):
            @wraps(fn)
            def inner(inputs):
                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-               attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
-               noised_repr = attn_out + noised_repr
+               noised_repr, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)

                if self.add_value_residual:
                    maybe_value_residual = default(maybe_value_residual, attn_values)
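The explicit `+ noised_repr` residual adds are dropped here (and in the wrappers below) because each branch is now wrapped by `init_hyper_conn`, which applies the residual, or its learned mixture of streams, internally; keeping the external add would apply the residual twice.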
@@ -2141,10 +2156,12 @@ def transition_wrapper(fn):
            @wraps(fn)
            def inner(inputs):
                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
-               noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
+               noised_repr = fn(noised_repr, cond = single_repr)
                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
            return inner

+       # wrap layers
+
        for linear_attn, colt5_attn, attn, transition in self.layers:

            if exists(linear_attn):
@@ -2156,10 +2173,19 @@ def inner(inputs):
            wrapped_layers.append(attn_wrapper(attn))
            wrapped_layers.append(transition_wrapper(transition))

+       # forward
+
+       noised_repr = self.expand_streams(noised_repr)
+
+       inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)
+
        for layer in wrapped_layers:
            inputs = checkpoint(layer, inputs)

        noised_repr, *_ = inputs
+
+       noised_repr = self.reduce_streams(noised_repr)
+
        return noised_repr

    @typecheck
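Aside from the stream expansion and reduction now bracketing the loop, the checkpointed path keeps its existing pattern: all state is threaded through one tuple so every wrapped layer has the uniform `inputs -> inputs` signature that `torch.utils.checkpoint` expects. A minimal sketch of that pattern with hypothetical linear blocks, independent of this repository:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

blocks = [nn.Linear(64, 64) for _ in range(4)]

def wrap(fn):
    def inner(inputs):
        x, cond = inputs   # unpack carried state
        x = fn(x) + x      # branch with a plain residual
        return x, cond     # repack for the next block
    return inner

x = torch.randn(2, 16, 64, requires_grad = True)
inputs = (x, None)

for block in blocks:
    inputs = checkpoint(wrap(block), inputs, use_reentrant = False)

x, *_ = inputs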
@@ -2175,15 +2201,17 @@ def to_serial_layers(

        value_residual = None

+       noised_repr = self.expand_streams(noised_repr)
+
        for linear_attn, colt5_attn, attn, transition in self.layers:

            if exists(linear_attn):
-               noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
+               noised_repr = linear_attn(noised_repr, mask = mask)

            if exists(colt5_attn):
-               noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
+               noised_repr = colt5_attn(noised_repr, mask = mask)

-           attn_out, attn_values = attn(
+           noised_repr, attn_values = attn(
                noised_repr,
                cond = single_repr,
                pairwise_repr = pairwise_repr,
@@ -2193,15 +2221,15 @@ def to_serial_layers(
                value_residual = value_residual
            )

-           noised_repr = noised_repr + attn_out
-
            if self.add_value_residual:
                value_residual = default(value_residual, attn_values)

            noised_repr = transition(
                noised_repr,
                cond = single_repr
-           ) + noised_repr
+           )
+
+       noised_repr = self.reduce_streams(noised_repr)

        return noised_repr
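For callers, the only surface change is the new constructor argument; leaving it at its default of 1 disables hyper connections and preserves the previous plain-residual behavior. A hypothetical usage sketch (the `dim` and `depth` values are illustrative, and the remaining kwargs are omitted):

diffusion_transformer = DiffusionTransformer(
    dim = 384,
    depth = 12,
    num_residual_streams = 4  # > 1 turns on hyper connections
)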