@@ -1430,10 +1430,18 @@ def __init__(
1430
1430
num_register_tokens = 0 ,
1431
1431
checkpoint = False ,
1432
1432
add_value_residual = False ,
1433
+ num_residual_streams = 1 ,
1433
1434
pairwise_block_kwargs : dict = dict (),
1434
1435
pair_bias_attn_kwargs : dict = dict ()
1435
1436
):
1436
1437
super ().__init__ ()
1438
+
1439
+ # residual / hyper connections
1440
+
1441
+ init_hyper_conn , self .expand_streams , self .reduce_streams = HyperConnections .get_init_and_expand_reduce_stream_functions (num_residual_streams , disable = num_residual_streams == 1 )
1442
+
1443
+ # layers
1444
+
1437
1445
layers = ModuleList ([])
1438
1446
1439
1447
pair_bias_attn_kwargs = dict (
@@ -1463,8 +1471,8 @@ def __init__(
1463
1471
1464
1472
layers .append (ModuleList ([
1465
1473
pairwise_block ,
1466
- single_pre_ln (pair_bias_attn ),
1467
- single_pre_ln (single_transition ),
1474
+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (pair_bias_attn ) ),
1475
+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (single_transition ) ),
1468
1476
]))
1469
1477
1470
1478
self .layers = layers
@@ -1499,6 +1507,8 @@ def to_layers(
1499
1507
1500
1508
) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
1501
1509
1510
+ single_repr = self .expand_streams (single_repr )
1511
+
1502
1512
for _ in range (self .recurrent_depth ):
1503
1513
1504
1514
value_residual = None
@@ -1522,6 +1532,8 @@ def to_layers(
1522
1532
1523
1533
single_repr = single_transition (single_repr ) + single_repr
1524
1534
1535
+ single_repr = self .reduce_streams (single_repr )
1536
+
1525
1537
return single_repr , pairwise_repr
1526
1538
1527
1539
@typecheck
@@ -1550,8 +1562,7 @@ def pair_bias_attn_wrapper(layer):
1550
1562
@wraps (layer )
1551
1563
def inner (inputs , * args , ** kwargs ):
1552
1564
single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1553
- attn_out , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1554
- single_repr = single_repr + attn_out
1565
+ single_repr , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1555
1566
1556
1567
if self .add_value_residual :
1557
1568
maybe_value_residual = default (maybe_value_residual , attn_values )
@@ -1563,7 +1574,7 @@ def single_transition_wrapper(layer):
1563
1574
@wraps (layer )
1564
1575
def inner (inputs , * args , ** kwargs ):
1565
1576
single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1566
- single_repr = layer (single_repr ) + single_repr
1577
+ single_repr = layer (single_repr )
1567
1578
return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
1568
1579
return inner
1569
1580
@@ -1579,6 +1590,8 @@ def inner(inputs, *args, **kwargs):
1579
1590
wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
1580
1591
wrapped_layers .append (single_transition_wrapper (single_transition ))
1581
1592
1593
+ single_repr = self .expand_streams (single_repr )
1594
+
1582
1595
for _ in range (self .recurrent_depth ):
1583
1596
inputs = (single_repr , pairwise_repr , mask , None , None )
1584
1597
@@ -1587,6 +1600,8 @@ def inner(inputs, *args, **kwargs):
1587
1600
1588
1601
single_repr , pairwise_repr , * _ = inputs
1589
1602
1603
+ single_repr = self .reduce_streams (single_repr )
1604
+
1590
1605
return single_repr , pairwise_repr
1591
1606
1592
1607
@typecheck
0 commit comments