119
119
120
120
from colt5_attention import ConditionalRoutedAttention
121
121
122
- from hyper_connections import HyperConnections
122
+ from hyper_connections . hyper_connections_with_multi_input_streams import HyperConnections
123
123
124
124
# other external libs
125
125
@@ -995,8 +995,8 @@ def __init__(
995
995
@typecheck
996
996
def forward (
997
997
self ,
998
- * ,
999
998
pairwise_repr : Float ['b n n d' ],
999
+ * ,
1000
1000
mask : Bool ['b n' ] | None = None ,
1001
1001
value_residuals : tuple [Tensor , Tensor ] | None = None ,
1002
1002
return_values = False ,
@@ -1470,8 +1470,8 @@ def __init__(
1470
1470
single_transition = Transition (dim = dim_single )
1471
1471
1472
1472
layers .append (ModuleList ([
1473
- pairwise_block ,
1474
- init_hyper_conn (dim = dim_single , branch = single_pre_ln (pair_bias_attn )),
1473
+ init_hyper_conn ( dim = dim_pairwise , branch = pairwise_block ) ,
1474
+ init_hyper_conn (dim = dim_single , additional_input_paths = [( 'pairwise_repr' , dim_pairwise )], branch = single_pre_ln (pair_bias_attn )),
1475
1475
init_hyper_conn (dim = dim_single , branch = single_pre_ln (single_transition )),
1476
1476
]))
1477
1477
@@ -1508,6 +1508,7 @@ def to_layers(
1508
1508
) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
1509
1509
1510
1510
single_repr = self .expand_streams (single_repr )
1511
+ pairwise_repr = self .expand_streams (pairwise_repr )
1511
1512
1512
1513
for _ in range (self .recurrent_depth ):
1513
1514
@@ -1520,7 +1521,7 @@ def to_layers(
1520
1521
single_transition
1521
1522
) in self .layers :
1522
1523
1523
- pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr = pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
1524
+ pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
1524
1525
1525
1526
single_repr , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
1526
1527
@@ -1531,6 +1532,7 @@ def to_layers(
1531
1532
single_repr = single_transition (single_repr )
1532
1533
1533
1534
single_repr = self .reduce_streams (single_repr )
1535
+ pairwise_repr = self .reduce_streams (pairwise_repr )
1534
1536
1535
1537
return single_repr , pairwise_repr
1536
1538
@@ -1548,7 +1550,7 @@ def pairwise_block_wrapper(layer):
1548
1550
@wraps (layer )
1549
1551
def inner (inputs , * args , ** kwargs ):
1550
1552
single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1551
- pairwise_repr , pairwise_attn_values = layer (pairwise_repr = pairwise_repr , mask = mask , value_residuals = maybe_pairwise_value_residuals , return_values = True )
1553
+ pairwise_repr , pairwise_attn_values = layer (pairwise_repr , mask = mask , value_residuals = maybe_pairwise_value_residuals , return_values = True )
1552
1554
1553
1555
if self .add_value_residual :
1554
1556
maybe_pairwise_value_residuals = default (maybe_pairwise_value_residuals , pairwise_attn_values )
@@ -1589,6 +1591,7 @@ def inner(inputs, *args, **kwargs):
1589
1591
wrapped_layers .append (single_transition_wrapper (single_transition ))
1590
1592
1591
1593
single_repr = self .expand_streams (single_repr )
1594
+ pairwise_repr = self .expand_streams (pairwise_repr )
1592
1595
1593
1596
for _ in range (self .recurrent_depth ):
1594
1597
inputs = (single_repr , pairwise_repr , mask , None , None )
@@ -1599,6 +1602,7 @@ def inner(inputs, *args, **kwargs):
1599
1602
single_repr , pairwise_repr , * _ = inputs
1600
1603
1601
1604
single_repr = self .reduce_streams (single_repr )
1605
+ pairwise_repr = self .reduce_streams (pairwise_repr )
1602
1606
1603
1607
return single_repr , pairwise_repr
1604
1608
0 commit comments