@@ -1430,10 +1430,18 @@ def __init__(
1430
1430
num_register_tokens = 0 ,
1431
1431
checkpoint = False ,
1432
1432
add_value_residual = False ,
1433
+ num_residual_streams = 1 ,
1433
1434
pairwise_block_kwargs : dict = dict (),
1434
1435
pair_bias_attn_kwargs : dict = dict ()
1435
1436
):
1436
1437
super ().__init__ ()
1438
+
1439
+ # residual / hyper connections
1440
+
1441
+ init_hyper_conn , self .expand_streams , self .reduce_streams = HyperConnections .get_init_and_expand_reduce_stream_functions (num_residual_streams , disable = num_residual_streams == 1 )
1442
+
1443
+ # layers
1444
+
1437
1445
layers = ModuleList ([])
1438
1446
1439
1447
pair_bias_attn_kwargs = dict (
@@ -1463,8 +1471,8 @@ def __init__(
1463
1471
1464
1472
layers .append (ModuleList ([
1465
1473
pairwise_block ,
1466
- single_pre_ln (pair_bias_attn ),
1467
- single_pre_ln (single_transition ),
1474
+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (pair_bias_attn ) ),
1475
+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (single_transition ) ),
1468
1476
]))
1469
1477
1470
1478
self .layers = layers
@@ -1499,6 +1507,8 @@ def to_layers(
1499
1507
1500
1508
) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
1501
1509
1510
+ single_repr = self .expand_streams (single_repr )
1511
+
1502
1512
for _ in range (self .recurrent_depth ):
1503
1513
1504
1514
value_residual = None
@@ -1512,15 +1522,15 @@ def to_layers(
1512
1522
1513
1523
pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr = pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
1514
1524
1515
- attn_out , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
1516
-
1517
- single_repr = single_repr + attn_out
1525
+ single_repr , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
1518
1526
1519
1527
if self .add_value_residual :
1520
1528
value_residual = default (value_residual , attn_values )
1521
1529
pairwise_value_residuals = default (pairwise_value_residuals , pairwise_attn_values )
1522
1530
1523
- single_repr = single_transition (single_repr ) + single_repr
1531
+ single_repr = single_transition (single_repr )
1532
+
1533
+ single_repr = self .reduce_streams (single_repr )
1524
1534
1525
1535
return single_repr , pairwise_repr
1526
1536
@@ -1550,8 +1560,7 @@ def pair_bias_attn_wrapper(layer):
1550
1560
@wraps (layer )
1551
1561
def inner (inputs , * args , ** kwargs ):
1552
1562
single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1553
- attn_out , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1554
- single_repr = single_repr + attn_out
1563
+ single_repr , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1555
1564
1556
1565
if self .add_value_residual :
1557
1566
maybe_value_residual = default (maybe_value_residual , attn_values )
@@ -1563,7 +1572,7 @@ def single_transition_wrapper(layer):
1563
1572
@wraps (layer )
1564
1573
def inner (inputs , * args , ** kwargs ):
1565
1574
single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1566
- single_repr = layer (single_repr ) + single_repr
1575
+ single_repr = layer (single_repr )
1567
1576
return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
1568
1577
return inner
1569
1578
@@ -1579,6 +1588,8 @@ def inner(inputs, *args, **kwargs):
1579
1588
wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
1580
1589
wrapped_layers .append (single_transition_wrapper (single_transition ))
1581
1590
1591
+ single_repr = self .expand_streams (single_repr )
1592
+
1582
1593
for _ in range (self .recurrent_depth ):
1583
1594
inputs = (single_repr , pairwise_repr , mask , None , None )
1584
1595
@@ -1587,6 +1598,8 @@ def inner(inputs, *args, **kwargs):
1587
1598
1588
1599
single_repr , pairwise_repr , * _ = inputs
1589
1600
1601
+ single_repr = self .reduce_streams (single_repr )
1602
+
1590
1603
return single_repr , pairwise_repr
1591
1604
1592
1605
@typecheck
0 commit comments