@@ -683,7 +683,7 @@ def forward(
683
683
** kwargs
684
684
) -> (
685
685
Float ['b n d' ] |
686
- tuple [Float ['b n d' ], Float ['b _ _' ]]
686
+ tuple [Float ['b n d' ], Float ['b _ _ _ ' ]]
687
687
):
688
688
x = self .adaptive_norm (x , cond = cond )
689
689
@@ -797,11 +797,11 @@ def forward(
797
797
pairwise_repr : Float ["b n n dp" ] | Float ["b nw w (w*2) dp" ], # type: ignore
798
798
attn_bias : Float ["b n n" ] | Float ["b nw w (w*2)" ] | None = None , # type: ignore
799
799
return_values : bool = False ,
800
- value_residual : Float ['b _ _' ] | None = None ,
800
+ value_residual : Float ['b _ _ _ ' ] | None = None ,
801
801
** kwargs ,
802
802
) -> (
803
803
Float ['b n ds' ] |
804
- tuple [Float ['b n ds' ], Float ['b _ _' ]]
804
+ tuple [Float ['b n ds' ], Float ['b _ _ _ ' ]]
805
805
): # type: ignore
806
806
807
807
"""Perform the forward pass.
@@ -961,6 +961,7 @@ def __init__(
961
961
tri_attn_heads = 4 ,
962
962
dropout_row_prob = 0.25 ,
963
963
dropout_col_prob = 0.25 ,
964
+ accept_value_residual = False
964
965
):
965
966
super ().__init__ ()
966
967
@@ -974,7 +975,8 @@ def __init__(
974
975
tri_attn_kwargs = dict (
975
976
dim = dim_pairwise ,
976
977
heads = tri_attn_heads ,
977
- dim_head = tri_attn_dim_head
978
+ dim_head = tri_attn_dim_head ,
979
+ accept_value_residual = accept_value_residual
978
980
)
979
981
980
982
self .tri_mult_outgoing = pre_ln (TriangleMultiplication (mix = 'outgoing' , dropout = dropout_row_prob , dropout_type = 'row' , ** tri_mult_kwargs ))
@@ -1436,16 +1438,20 @@ def __init__(
1436
1438
** pair_bias_attn_kwargs
1437
1439
)
1438
1440
1439
- for _ in range (depth ):
1441
+ for i in range (depth ):
1442
+
1443
+ is_first = i == 0
1444
+ accept_value_residual = add_value_residual and not is_first
1440
1445
1441
1446
single_pre_ln = partial (PreLayerNorm , dim = dim_single )
1442
1447
1443
1448
pairwise_block = PairwiseBlock (
1444
1449
dim_pairwise = dim_pairwise ,
1450
+ accept_value_residual = accept_value_residual ,
1445
1451
** pairwise_block_kwargs
1446
1452
)
1447
1453
1448
- pair_bias_attn = AttentionPairBias (** pair_bias_attn_kwargs )
1454
+ pair_bias_attn = AttentionPairBias (accept_value_residual = accept_value_residual , ** pair_bias_attn_kwargs )
1449
1455
single_transition = Transition (dim = dim_single )
1450
1456
1451
1457
layers .append (ModuleList ([
@@ -1486,10 +1492,11 @@ def to_layers(
1486
1492
1487
1493
) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
1488
1494
1489
- value_residual = None
1490
- pairwise_value_residuals = None
1491
-
1492
1495
for _ in range (self .recurrent_depth ):
1496
+
1497
+ value_residual = None
1498
+ pairwise_value_residuals = None
1499
+
1493
1500
for (
1494
1501
pairwise_block ,
1495
1502
pair_bias_attn ,
@@ -1520,54 +1527,59 @@ def to_checkpointed_layers(
1520
1527
1521
1528
) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
1522
1529
1523
- inputs = (single_repr , pairwise_repr , mask , None )
1524
-
1525
1530
def pairwise_block_wrapper (layer ):
1526
1531
@wraps (layer )
1527
1532
def inner (inputs , * args , ** kwargs ):
1528
- single_repr , pairwise_repr , mask , maybe_value_residual = inputs
1529
- pairwise_repr = layer (pairwise_repr = pairwise_repr , mask = mask )
1530
- return single_repr , pairwise_repr , mask , maybe_value_residual
1533
+ single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1534
+ pairwise_repr , pairwise_attn_values = layer (pairwise_repr = pairwise_repr , mask = mask , value_residuals = maybe_pairwise_value_residuals , return_values = True )
1535
+
1536
+ if self .add_value_residual :
1537
+ maybe_pairwise_value_residuals = default (maybe_pairwise_value_residuals , pairwise_attn_values )
1538
+
1539
+ return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
1531
1540
return inner
1532
1541
1533
1542
def pair_bias_attn_wrapper (layer ):
1534
1543
@wraps (layer )
1535
1544
def inner (inputs , * args , ** kwargs ):
1536
- single_repr , pairwise_repr , mask , maybe_value_residual = inputs
1545
+ single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1537
1546
attn_out , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1538
1547
single_repr = single_repr + attn_out
1539
1548
1540
1549
if self .add_value_residual :
1541
1550
maybe_value_residual = default (maybe_value_residual , attn_values )
1542
1551
1543
- return single_repr , pairwise_repr , mask , maybe_value_residual
1552
+ return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
1544
1553
return inner
1545
1554
1546
1555
def single_transition_wrapper (layer ):
1547
1556
@wraps (layer )
1548
1557
def inner (inputs , * args , ** kwargs ):
1549
- single_repr , pairwise_repr , mask , maybe_value_residual = inputs
1558
+ single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1550
1559
single_repr = layer (single_repr ) + single_repr
1551
- return single_repr , pairwise_repr , mask , maybe_value_residual
1560
+ return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
1552
1561
return inner
1553
1562
1554
1563
wrapped_layers = []
1555
1564
1565
+ for (
1566
+ pairwise_block ,
1567
+ pair_bias_attn ,
1568
+ single_transition
1569
+ ) in self .layers :
1570
+
1571
+ wrapped_layers .append (pairwise_block_wrapper (pairwise_block ))
1572
+ wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
1573
+ wrapped_layers .append (single_transition_wrapper (single_transition ))
1574
+
1556
1575
for _ in range (self .recurrent_depth ):
1557
- for (
1558
- pairwise_block ,
1559
- pair_bias_attn ,
1560
- single_transition
1561
- ) in self .layers :
1576
+ inputs = (single_repr , pairwise_repr , mask , None , None )
1562
1577
1563
- wrapped_layers .append (pairwise_block_wrapper (pairwise_block ))
1564
- wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
1565
- wrapped_layers .append (single_transition_wrapper (single_transition ))
1578
+ for layer in wrapped_layers :
1579
+ inputs = checkpoint (layer , inputs )
1566
1580
1567
- for layer in wrapped_layers :
1568
- inputs = checkpoint (layer , inputs )
1581
+ single_repr , pairwise_repr , * _ = inputs
1569
1582
1570
- single_repr , pairwise_repr , * _ = inputs
1571
1583
return single_repr , pairwise_repr
1572
1584
1573
1585
@typecheck
@@ -2016,7 +2028,8 @@ def __init__(
2016
2028
2017
2029
layers = ModuleList ([])
2018
2030
2019
- for _ in range (depth ):
2031
+ for i in range (depth ):
2032
+ is_first = i == 0
2020
2033
2021
2034
linear_attn = None
2022
2035
@@ -2038,12 +2051,15 @@ def __init__(
2038
2051
** colt5_attn_kwargs
2039
2052
)
2040
2053
2054
+ accept_value_residual = add_value_residual and not is_first
2055
+
2041
2056
pair_bias_attn = AttentionPairBias (
2042
2057
dim = dim ,
2043
2058
dim_pairwise = dim_pairwise ,
2044
2059
heads = heads ,
2045
2060
window_size = attn_window_size ,
2046
2061
num_memory_kv = attn_num_memory_kv ,
2062
+ accept_value_residual = accept_value_residual ,
2047
2063
** attn_pair_bias_kwargs
2048
2064
)
2049
2065
0 commit comments