
Commit b992f6f

yaochengji authored and pgmoka committed
[Kernel] add group_offset and transpose_rhs support in gmm kernel (#9251)
1 parent e5d6e5a commit b992f6f


2 files changed: +69 -35


test/test_gmm.py

Lines changed: 42 additions & 24 deletions
@@ -23,12 +23,18 @@
 
 class MegabloxTest(unittest.TestCase):
 
-  def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
-                     group_sizes: torch.Tensor) -> torch.Tensor:
+  def _reference_gmm(self,
+                     lhs: torch.Tensor,
+                     rhs: torch.Tensor,
+                     group_sizes: torch.Tensor,
+                     transpose_rhs: bool = False) -> torch.Tensor:
     start = 0
     out = []
     for i, size in enumerate(group_sizes):
-      result = lhs[start:start + size, :] @ rhs[i, :, :]
+      rhsi = rhs[i, :, :]
+      if transpose_rhs is True:
+        rhsi = torch.transpose(rhsi, 0, 1)
+      result = lhs[start:start + size, :] @ rhsi
       out.append(result)
       start += group_sizes[i]
     return torch.cat(out)
@@ -105,27 +111,39 @@ def test_gmm(self):
     for test_cache in [False, True]:
       for gmm_func in gmm_funcs:
         for test_case in self.tests_cases:
-          num_groups = test_case['num_groups']
-          k = test_case['k']
-          m = test_case['m']
-          n = test_case['n']
-          lhs_dtype = rhs_dtype = test_case['dtype']
-
-          lhs = torch.rand(m, k, dtype=lhs_dtype)
-          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-          ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-          out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-          # torch.compiled version of the gmm will cache the payload in dynamo layer
-          # hence won't trigger the trace_pallas cache
-          if test_cache and gmm_func != compiled_gmm:
-            old_cnt = xr.get_num_cached_compilation_graph()
-            # execute the same gmm func, expected to hit the cache
-            out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-            new_cnt = xr.get_num_cached_compilation_graph()
-            self.assertEqual(old_cnt, new_cnt)
-          self.assertTrue(torch.allclose(ref_out, out.cpu()))
+          for transpose_rhs in [True, False]:
+            num_groups = test_case['num_groups']
+            k = test_case['k']
+            m = test_case['m']
+            n = test_case['n']
+            lhs_dtype = rhs_dtype = test_case['dtype']
+
+            lhs = torch.rand(m, k, dtype=lhs_dtype)
+            if transpose_rhs is False:
+              rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+            else:
+              rhs = torch.rand(num_groups, n, k, dtype=rhs_dtype)
+            group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+            ref_out = self._reference_gmm(lhs, rhs, group_sizes, transpose_rhs)
+
+            out = gmm_func(
+                lhs.to("xla"),
+                rhs.to("xla"),
+                group_sizes.to("xla"),
+                transpose_rhs=transpose_rhs)
+            # torch.compiled version of the gmm will cache the payload in dynamo layer
+            # hence won't trigger the trace_pallas cache
+            if test_cache and gmm_func != compiled_gmm:
+              old_cnt = xr.get_num_cached_compilation_graph()
+              # execute the same gmm func, expected to hit the cache
+              out = gmm_func(
+                  lhs.to("xla"),
+                  rhs.to("xla"),
+                  group_sizes.to("xla"),
+                  transpose_rhs=transpose_rhs)
+              new_cnt = xr.get_num_cached_compilation_graph()
+              self.assertEqual(old_cnt, new_cnt)
+            self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
     self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
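
To make the new test behavior concrete, here is a minimal CPU-only sketch of the grouped matmul that the reference path above computes. It is not part of the commit, and the shapes are illustrative; with transpose_rhs=True, each rhs group is stored as [n, k] and transposed back to [k, n] before the matmul:

# CPU-only sketch of the grouped matmul semantics (not part of the commit).
import torch


def reference_gmm(lhs, rhs, group_sizes, transpose_rhs=False):
  # lhs: [m, k]; rhs: [num_groups, k, n], or [num_groups, n, k] when
  # transpose_rhs is True; group_sizes sums to m.
  start = 0
  out = []
  for i, size in enumerate(group_sizes):
    rhs_i = rhs[i]
    if transpose_rhs:
      rhs_i = rhs_i.transpose(0, 1)  # [n, k] -> [k, n]
    out.append(lhs[start:start + size] @ rhs_i)
    start += size
  return torch.cat(out)


lhs = torch.rand(8, 4)
rhs = torch.rand(2, 3, 4)  # [num_groups, n, k]
group_sizes = torch.tensor([5, 3], dtype=torch.int32)
print(reference_gmm(lhs, rhs, group_sizes, transpose_rhs=True).shape)  # [8, 3]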

torch_xla/experimental/custom_kernel.py

Lines changed: 27 additions & 11 deletions
@@ -1289,7 +1289,9 @@ def gmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512)
+    tiling: Tuple[int, int, int] = (512, 512, 512),
+    group_offset: torch.Tensor | None = None,
+    transpose_rhs: bool = False,
 ) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
 
@@ -1298,7 +1300,9 @@ def gmm(
     rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
     group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
     tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
-
+    group_offset: The group in group sizes to start computing from. This is
+      particularly useful for when rhs num_groups is sharded.
+    transpose_rhs: True if the rhs needs to be transposed.
   Returns:
     A 2d, torch.Tensor with shape [m, n].
   """
@@ -1310,15 +1314,18 @@ def gmm(
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
   preferred_element_type = lhs.dtype
   return xb.call_jax(gmm, (lhs, rhs, group_sizes, preferred_element_type,
-                           (tm, tk, tn)))
+                           (tm, tk, tn), group_offset),
+                     {"transpose_rhs": transpose_rhs})
 
 
 @requires_jax
 def tgmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512)
+    tiling: Tuple[int, int, int] = (512, 512, 512),
+    group_offset: torch.Tensor | None = None,
+    num_actual_groups: int | None = None,
 ) -> torch.Tensor:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
 
@@ -1340,7 +1347,7 @@ def tgmm(
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
   preferred_element_type = lhs.dtype
   return xb.call_jax(tgmm, (lhs, rhs, group_sizes, preferred_element_type,
-                            (tm, tk, tn)))
+                            (tm, tk, tn), group_offset, num_actual_groups))
 
 
 def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
@@ -1547,7 +1554,7 @@ def ragged_paged_attention_non_xla(
 
 
 XLA_LIB.define(
-    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
+    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None, Tensor? group_offset=None, bool transpose_rhs=False) -> Tensor",
 )
 
 
@@ -1557,28 +1564,37 @@ def gmm_xla(
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
     # pytorch custom op does not allow tuple type, use list instead
-    tiling: Optional[List[int]] = [512, 512, 512]):
+    tiling: Optional[List[int]] = [512, 512, 512],
+    group_offset: torch.Tensor | None = None,
+    transpose_rhs: bool = False):
+  if tiling is None:
+    tiling = [512, 512, 512]
   assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
   ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
   tiling = tuple(tiling)
-  return gmm(lhs, rhs, group_sizes, tiling)
+  return gmm(lhs, rhs, group_sizes, tiling, group_offset, transpose_rhs)
 
 
 @impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
 def gmm_non_xla(lhs: torch.Tensor,
                 rhs: torch.Tensor,
                 group_sizes: torch.Tensor,
-                tiling: Optional[List[int]] = [512, 512, 512]):
+                tiling: Optional[List[int]] = [512, 512, 512],
+                group_offset: torch.Tensor | None = None,
+                transpose_rhs: bool = False):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if lhs.device != torch.device("meta"):
     warnings.warn(f'XLA gmm should only be applied to tensors on XLA device')
+  if tiling is None:
+    tiling = [512, 512, 512]
   assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
-  ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
+  ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n] or [num_groups, n, k] when transpose_rhs is True"
+  rhs_dim_size = rhs.size()[1] if transpose_rhs is True else rhs.size()[2]
 
   # we only need to return the tensor with correct shape for meta tensor.
-  return torch.empty(lhs.size()[0], rhs.size()[2], device=lhs.device)
+  return torch.empty(lhs.size()[0], rhs_dim_size, device=lhs.device)
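
For completeness, a hedged usage sketch of the extended Python entry point. It assumes the import path mirrors the file path above, that a TPU/XLA device and the JAX-backed Pallas path are available, and that the sizes are illustrative; group_offset is only described in a comment since its exact shape is not spelled out in this diff:

# Usage sketch (not part of the commit); assumes a TPU/XLA device is available
# and that gmm is importable from the module shown above.
import torch
import torch_xla
from torch_xla.experimental.custom_kernel import gmm

num_groups, m, k, n = 4, 512, 128, 256
lhs = torch.rand(m, k, dtype=torch.bfloat16).to("xla")
# With transpose_rhs=True the kernel expects rhs laid out as
# [num_groups, n, k] rather than [num_groups, k, n].
rhs = torch.rand(num_groups, n, k, dtype=torch.bfloat16).to("xla")
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32).to("xla")

out = gmm(lhs, rhs, group_sizes, transpose_rhs=True)  # shape [m, n]
# Per the new docstring, group_offset (a tensor giving the group to start
# computing from) can also be passed when rhs's num_groups dimension is
# sharded; it is left at its default (None) here.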
