
Commit 8955571

Support gmm and tgmm trace_pallas caching (#7921)
1 parent 1b4b828 commit 8955571

File tree: 2 files changed (+164, -84 lines)

test/test_gmm.py

Lines changed: 141 additions & 83 deletions
@@ -94,60 +94,75 @@ def _init_test_cases(self):
   def test_gmm(self):
     met.clear_all()
     jax.config.update('jax_default_matmul_precision', "highest")
+    compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
     gmm_funcs = [
-        gmm, torch.ops.xla.gmm,
-        torch.compile(torch.ops.xla.gmm, backend="openxla")
+        gmm,
+        torch.ops.xla.gmm,
+        compiled_gmm,
     ]
 
     self._init_test_cases()
-    for gmm_func in gmm_funcs:
-      for test_case in self.tests_cases:
-        num_groups = test_case['num_groups']
-        k = test_case['k']
-        m = test_case['m']
-        n = test_case['n']
-        lhs_dtype = rhs_dtype = test_case['dtype']
-
-        lhs = torch.rand(m, k, dtype=lhs_dtype)
-        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for test_cache in [False, True]:
+      for gmm_func in gmm_funcs:
+        for test_case in self.tests_cases:
+          num_groups = test_case['num_groups']
+          k = test_case['k']
+          m = test_case['m']
+          n = test_case['n']
+          lhs_dtype = rhs_dtype = test_case['dtype']
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+          ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+          out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          # torch.compiled version of the gmm will cache the payload in dynamo layer
+          # hence won't trigger the trace_pallas cache
+          if test_cache and gmm_func != compiled_gmm:
+            met.clear_counters()
+            # execute the same gmm func, expected to hit the cache
+            out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
     jax.config.update('jax_default_matmul_precision', "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_bf16(self):
     met.clear_all()
 
-    gmm_funcs = [
-        gmm, torch.ops.xla.gmm,
-        torch.compile(torch.ops.xla.gmm, backend="openxla")
-    ]
+    compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
+    gmm_funcs = [gmm, torch.ops.xla.gmm, compiled_gmm]
     self._init_test_cases()
-    for gmm_func in gmm_funcs:
-      for test_case in self.tests_cases:
-        num_groups = test_case['num_groups']
-        k = test_case['k']
-        m = test_case['m']
-        n = test_case['n']
-        lhs_dtype = rhs_dtype = torch.bfloat16
-
-        lhs = torch.rand(m, k, dtype=lhs_dtype)
-        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-
-        self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for test_cache in [False, True]:
+      for gmm_func in gmm_funcs:
+        for test_case in self.tests_cases:
+          num_groups = test_case['num_groups']
+          k = test_case['k']
+          m = test_case['m']
+          n = test_case['n']
+          lhs_dtype = rhs_dtype = torch.bfloat16
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+          ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+          out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          # torch.compiled version of the gmm will cache the payload in dynamo layer
+          # hence won't trigger the trace_pallas cache
+          if test_cache and gmm_func != compiled_gmm:
+            met.clear_counters()
+            # execute the same gmm func, expected to hit the cache
+            out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_make_group_metadata(self):
@@ -313,47 +328,59 @@ def test_tgmm(self):
     jax.config.update('jax_default_matmul_precision', "highest")
 
     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = test_case['dtype']
-
-      lhs = torch.rand(k, m, dtype=lhs_dtype)
-      rhs = torch.rand(m, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+    for test_cache in [False, True]:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = test_case['dtype']
 
-      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        lhs = torch.rand(k, m, dtype=lhs_dtype)
+        rhs = torch.rand(m, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+        out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        if test_cache:
+          met.clear_counters()
+          # execute the same gmm func, expected to hit the cache
+          out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure tgmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
     jax.config.update('jax_default_matmul_precision', "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm_bf16(self):
     met.clear_all()
 
     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = torch.bfloat16
-
-      lhs = torch.rand(k, m, dtype=lhs_dtype)
-      rhs = torch.rand(m, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+    for test_cache in [False, True]:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = torch.bfloat16
 
-      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        lhs = torch.rand(k, m, dtype=lhs_dtype)
+        rhs = torch.rand(m, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+        out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        if test_cache:
+          met.clear_counters()
+          # execute the same gmm func, expected to hit the cache
+          out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure tgmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward(self):
@@ -365,25 +392,31 @@ def test_gmm_backward(self):
       n = test_case['n']
       lhs_dtype = rhs_dtype = torch.bfloat16
 
-      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      lhs.retain_grad()
-      rhs.retain_grad()
+      for test_cache in [False, True]:
+        met.clear_all()
+        lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        lhs.retain_grad()
+        rhs.retain_grad()
 
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-      ref_out.sum().backward()
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+        ref_out.sum().backward()
 
-      ref_out_backward = torch.ones_like(ref_out)
-      grad_lhs, grad_rhs = gmm_backward(
-          ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
-          group_sizes.to("xla"))
+        ref_out_backward = torch.ones_like(ref_out)
+        grad_lhs, grad_rhs = gmm_backward(
+            ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
+            group_sizes.to("xla"))
+        # same gmm/tgmm was run for the `test_cache=False` case so the
+        # cache should be populated now
+        if test_cache:
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)
 
-      self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
-      self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
+        self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
+        self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
 
     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward_2(self):
@@ -420,7 +453,7 @@ def test_gmm_backward_2(self):
       self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
 
     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward_3(self):
@@ -458,7 +491,32 @@ def test_gmm_backward_3(self):
      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
 
    # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_cache_miss(self):
+    met.clear_all()
+    jax.config.update('jax_default_matmul_precision', "highest")
+
+    self._init_test_cases()
+    test_case = self.tests_cases[-1]
+    # make sure that cache miss for different input shapes and dtype
+    met.clear_all()
+    for mul_factor in [[2, 1, 1, 1], [1, 2, 1, 1], [2, 1, 2, 1], [2, 1, 1, 2]]:
+      for dtype in [torch.float32, torch.bfloat16]:
+        for tiling in [(128, 128, 128), (256, 256, 256)]:
+          num_groups = test_case['num_groups'] * mul_factor[0]
+          k = test_case['k'] * mul_factor[1]
+          m = test_case['m'] * mul_factor[2]
+          n = test_case['n'] * mul_factor[3]
+          lhs_dtype = rhs_dtype = dtype
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+
+          out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"), tiling)
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), None)
 
 
 if __name__ == '__main__':
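The new test paths above all assert the same user-visible contract: calling gmm or tgmm a second time with identical shapes, dtypes, and tiling reuses the previously lowered Pallas payload and bumps the `trace_pallas_cache_hit` counter, while torch.compile'd calls and any change of shape, dtype, or tiling bypass or miss that cache. A minimal standalone sketch of that check, assuming a TPU-backed XLA device; the shapes and the int32 group_sizes tensor are illustrative stand-ins for what the test helpers generate:

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm

# Illustrative shapes: 512 lhs rows split evenly across 4 groups.
lhs = torch.rand(512, 128, dtype=torch.bfloat16).to("xla")
rhs = torch.rand(4, 128, 256, dtype=torch.bfloat16).to("xla")
group_sizes = torch.tensor([128] * 4, dtype=torch.int32).to("xla")

# First call lowers the Pallas kernel and populates the trace_pallas payload cache.
out = gmm(lhs, rhs, group_sizes)

met.clear_counters()
# Identical second call: trace_pallas should return the cached payload.
out = gmm(lhs, rhs, group_sizes)
assert met.counter_value('trace_pallas_cache_hit') == 1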

torch_xla/experimental/custom_kernel.py

Lines changed: 23 additions & 1 deletion
@@ -9,7 +9,7 @@
 import torch_xla.distributed.spmd as xs
 import torch_xla.debug.metrics as met
 
-from typing import Any, List, Callable, Optional, Tuple
+from typing import Any, List, Callable, Optional, Tuple, Dict
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
 
@@ -93,10 +93,14 @@ def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
                               convert_torch_dtype_to_jax(tensor.dtype))
 
 
+trace_pallas_arg_to_payload: Dict[Tuple[Any], str] = {}
+
+
 def trace_pallas(kernel: Callable,
                  *args,
                  static_argnums=None,
                  static_argnames=None,
+                 use_cache=False,
                  **kwargs):
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
@@ -116,11 +120,27 @@ def trace_pallas(kernel: Callable,
     else:
       jax_args.append(arg)
 
+  hash_key = ()
+  if use_cache:
+    global trace_pallas_arg_to_payload
+    # implicit assumption here that everything in kwargs is hashable and not a tensor,
+    # which is true for the gmm and tgmm.
+    hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
+                repr(sorted(kwargs.items())).encode())
+    if hash_key in trace_pallas_arg_to_payload:
+      torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
+      return trace_pallas_arg_to_payload[hash_key], tensor_args
+
   # Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
   ir = jax.jit(
       kernel, static_argnums=static_argnums,
       static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
   payload = _extract_backend_config(ir)
+
+  if use_cache:
+    # if we reach here it means we have a cache miss.
+    trace_pallas_arg_to_payload[hash_key] = payload
+
   return payload, tensor_args
 
 
@@ -770,6 +790,7 @@ def gmm(
       rhs,
       group_sizes,
       static_argnames=["tiling", "preferred_element_type"],
+      use_cache=True,
       preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
       tiling=(tm, tk, tn))
 
@@ -822,6 +843,7 @@ def tgmm(
       rhs,
       group_sizes,
       static_argnames=["tiling", "preferred_element_type"],
+      use_cache=True,
      preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
       tiling=(tm, tk, tn))
 
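The caching itself is plain memoization: everything that can change the lowered kernel (the kernel object, the static-argument configuration, the jax.ShapeDtypeStruct view of each tensor argument, and the remaining kwargs) is folded into a dictionary key, and the extracted backend-config payload is stored under that key. A simplified, self-contained sketch of the same pattern follows; the names (lower_with_cache, _payload_cache) and the fake lowering step are illustrative, not part of the torch_xla API:

from typing import Any, Callable, Dict, Tuple

# Module-level cache, analogous to trace_pallas_arg_to_payload above.
_payload_cache: Dict[Tuple[Any, ...], str] = {}
_cache_hits = 0  # stand-in for the 'trace_pallas_cache_hit' counter


def lower_with_cache(kernel: Callable,
                     arg_signature: Tuple[Any, ...],
                     use_cache: bool = False,
                     **static_kwargs: Any) -> str:
  """Memoizes a payload keyed on (kernel, argument signature, static kwargs)."""
  global _cache_hits
  key = (kernel, arg_signature, repr(sorted(static_kwargs.items())))
  if use_cache and key in _payload_cache:
    _cache_hits += 1
    return _payload_cache[key]
  # Placeholder for the expensive step: in trace_pallas this is
  # jax.jit(...).lower(...).compiler_ir() plus backend-config extraction.
  payload = f"payload::{key!r}"
  if use_cache:
    _payload_cache[key] = payload
  return payload

Because the key includes the shape/dtype signature of the tensor arguments and the static kwargs such as tiling, any change to those produces a distinct key and therefore a fresh trace, which is exactly what test_gmm_cache_miss exercises.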
