@@ -1289,7 +1289,9 @@ def gmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512)
+    tiling: Tuple[int, int, int] = (512, 512, 512),
+    group_offset: torch.Tensor | None = None,
+    transpose_rhs: bool = False,
 ) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

@@ -1298,7 +1300,9 @@ def gmm(
     rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
     group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
     tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
-
+    group_offset: The group in group_sizes to start computing from. This is
+      particularly useful when the rhs num_groups is sharded.
+    transpose_rhs: True if the rhs needs to be transposed.
   Returns:
     A 2d, torch.Tensor with shape [m, n].
   """
@@ -1310,15 +1314,18 @@ def gmm(
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
   preferred_element_type = lhs.dtype
   return xb.call_jax(gmm, (lhs, rhs, group_sizes, preferred_element_type,
-                           (tm, tk, tn)))
+                           (tm, tk, tn), group_offset),
+                     {"transpose_rhs": transpose_rhs})


 @requires_jax
 def tgmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512)
+    tiling: Tuple[int, int, int] = (512, 512, 512),
+    group_offset: torch.Tensor | None = None,
+    num_actual_groups: int | None = None,
 ) -> torch.Tensor:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].

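
For symmetry, a hypothetical tgmm call with the new arguments. The shapes, and the reading of num_actual_groups as the number of groups the (sharded) output actually holds, are assumptions rather than statements from this diff.

# Hypothetical tgmm usage sketch (shapes invented for illustration).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import tgmm

device = xm.xla_device()
lhs_t = torch.randn(64, 16, dtype=torch.bfloat16, device=device)  # [k, m]
grad = torch.randn(16, 32, dtype=torch.bfloat16, device=device)   # [m, n]
group_sizes = torch.tensor([4, 4, 4, 4], dtype=torch.int32, device=device)

# Full result: one [k, n] block per group -> [num_groups, k, n].
grad_rhs = tgmm(lhs_t, grad, group_sizes)

# Sharded variant (assumed semantics): compute only two groups, starting at group 2.
grad_rhs_shard = tgmm(
    lhs_t,
    grad,
    group_sizes,
    group_offset=torch.tensor([2], dtype=torch.int32, device=device),
    num_actual_groups=2)
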
@@ -1340,7 +1347,7 @@ def tgmm(
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
   preferred_element_type = lhs.dtype
   return xb.call_jax(tgmm, (lhs, rhs, group_sizes, preferred_element_type,
-                            (tm, tk, tn)))
+                            (tm, tk, tn), group_offset, num_actual_groups))


 def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
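
The gmm_backward shown as context above is where the two kernels meet; roughly, it composes them like the sketch below. This is a paraphrase for orientation, not the verbatim body from the file.

# Sketch of the backward composition; assumes gmm and tgmm are in scope
# (e.g. imported as in the usage sketches above).
def gmm_backward_sketch(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
  # d(lhs): incoming gradient times the transposed per-group weights, via the forward kernel.
  grad_lhs = gmm(grad, rhs.transpose(-2, -1), group_sizes, tiling)
  # d(rhs): transposed grouped matmul of the activations against the gradient.
  grad_rhs = tgmm(lhs.t(), grad, group_sizes, tiling)
  return grad_lhs, grad_rhs
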
@@ -1547,7 +1554,7 @@ def ragged_paged_attention_non_xla(


 XLA_LIB.define(
-    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
+    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None, Tensor? group_offset=None, bool transpose_rhs=False) -> Tensor",
 )


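
With the widened schema, the op can also be reached through the dispatcher (e.g. under dynamo). A sketch reusing the tensors from the gmm example above; the torch.ops.xla.gmm spelling follows from XLA_LIB being the "xla" library, and the keyword forms shown here are assumptions.

# Defaults preserved: tiling=None, group_offset=None, transpose_rhs=False.
out = torch.ops.xla.gmm(lhs, rhs, group_sizes)

# New arguments passed explicitly (tiling as a list, per the schema's int[]?).
out = torch.ops.xla.gmm(lhs, rhs_shard, group_sizes, [512, 512, 512],
                        group_offset=offset, transpose_rhs=True)
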
@@ -1557,28 +1564,37 @@ def gmm_xla(
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
     # pytorch custom op does not allow tuple type, use list instead
-    tiling: Optional[List[int]] = [512, 512, 512]):
+    tiling: Optional[List[int]] = [512, 512, 512],
+    group_offset: torch.Tensor | None = None,
+    transpose_rhs: bool = False):
+  if tiling is None:
+    tiling = [512, 512, 512]
   assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
   ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
   tiling = tuple(tiling)
-  return gmm(lhs, rhs, group_sizes, tiling)
+  return gmm(lhs, rhs, group_sizes, tiling, group_offset, transpose_rhs)


 @impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
 def gmm_non_xla(lhs: torch.Tensor,
                 rhs: torch.Tensor,
                 group_sizes: torch.Tensor,
-                tiling: Optional[List[int]] = [512, 512, 512]):
+                tiling: Optional[List[int]] = [512, 512, 512],
+                group_offset: torch.Tensor | None = None,
+                transpose_rhs: bool = False):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if lhs.device != torch.device("meta"):
     warnings.warn(f'XLA gmm should only be applied to tensors on XLA device')
+  if tiling is None:
+    tiling = [512, 512, 512]
   assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
-  ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
+  ) == 3, "rhs must be a 3d torch.Tensor with shape [num_groups, k, n] or [num_groups, n, k] when transpose_rhs is True"
+  rhs_dim_size = rhs.size()[1] if transpose_rhs else rhs.size()[2]

   # we only need to return the tensor with correct shape for meta tensor.
-  return torch.empty(lhs.size()[0], rhs.size()[2], device=lhs.device)
+  return torch.empty(lhs.size()[0], rhs_dim_size, device=lhs.device)
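
A quick way to sanity-check the meta-tensor shape logic above. This sketch assumes the CompositeExplicitAutograd registration is what meta tensors dispatch to, and the shapes are placeholders; torch and torch_xla must be imported so the "xla" library is registered.

# Fake (meta) tensors only need the right output shape from gmm_non_xla.
lhs = torch.empty(16, 64, device="meta")      # [m, k]
rhs = torch.empty(4, 64, 32, device="meta")   # [num_groups, k, n]
sizes = torch.empty(4, dtype=torch.int32, device="meta")

torch.ops.xla.gmm(lhs, rhs, sizes).shape  # (16, 32): n taken from rhs.size(2)

rhs_t = torch.empty(4, 32, 64, device="meta")  # [num_groups, n, k]
torch.ops.xla.gmm(lhs, rhs_t, sizes, transpose_rhs=True).shape  # (16, 32): n from rhs.size(1)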