Support gmm and tgmm trace_pallas caching (#7921)
JackCaoG authored Aug 30, 2024
1 parent 1b4b828 commit 8955571
Showing 2 changed files with 164 additions and 84 deletions.
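In brief, trace_pallas now keeps a module-level dict that maps (kernel, static argnums/argnames, input shapes and dtypes, kwargs) to the lowered Pallas payload, gmm and tgmm opt in via use_cache=True, and a trace_pallas_cache_hit counter records when a payload is reused. A minimal sketch of the behavior the new tests assert, runnable on a TPU host; the shapes and the int32 group sizes are illustrative assumptions, while the gmm op, the counter name, and the clear-then-check pattern come from this commit:

import torch
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm

m, k, n, num_groups = 512, 128, 256, 8
lhs = torch.rand(m, k, dtype=torch.bfloat16)
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16)
# Assumed: int32 group sizes summing to m (the tests derive these via _group_sizes_strategy).
group_sizes = torch.tensor([m // num_groups] * num_groups, dtype=torch.int32)

out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))  # first call lowers the kernel and fills the cache
met.clear_counters()
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))  # same shapes, dtypes and tiling: payload reused
assert met.counter_value('trace_pallas_cache_hit') == 1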
224 changes: 141 additions & 83 deletions test/test_gmm.py
@@ -94,60 +94,75 @@ def _init_test_cases(self):
def test_gmm(self):
met.clear_all()
jax.config.update('jax_default_matmul_precision', "highest")
compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
gmm_funcs = [
gmm, torch.ops.xla.gmm,
torch.compile(torch.ops.xla.gmm, backend="openxla")
gmm,
torch.ops.xla.gmm,
compiled_gmm,
]

self._init_test_cases()
for gmm_func in gmm_funcs:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertTrue(torch.allclose(ref_out, out.cpu()))
for test_cache in [False, True]:
for gmm_func in gmm_funcs:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
# The torch.compiled version of gmm caches the payload in the dynamo layer,
# hence it won't trigger the trace_pallas cache
if test_cache and gmm_func != compiled_gmm:
met.clear_counters()
# execute the same gmm func, expected to hit the cache
out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_bf16(self):
met.clear_all()

gmm_funcs = [
gmm, torch.ops.xla.gmm,
torch.compile(torch.ops.xla.gmm, backend="openxla")
]
compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
gmm_funcs = [gmm, torch.ops.xla.gmm, compiled_gmm]
self._init_test_cases()
for gmm_func in gmm_funcs:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))

self.assertTrue(torch.allclose(ref_out, out.cpu()))
for test_cache in [False, True]:
for gmm_func in gmm_funcs:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
# The torch.compiled version of gmm caches the payload in the dynamo layer,
# hence it won't trigger the trace_pallas cache
if test_cache and gmm_func != compiled_gmm:
met.clear_counters()
# execute the same gmm func, expected to hit the cache
out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_make_group_metadata(self):
@@ -313,47 +328,59 @@ def test_tgmm(self):
jax.config.update('jax_default_matmul_precision', "highest")

self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']

lhs = torch.rand(k, m, dtype=lhs_dtype)
rhs = torch.rand(m, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
for test_cache in [False, True]:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertTrue(torch.allclose(ref_out, out.cpu()))
lhs = torch.rand(k, m, dtype=lhs_dtype)
rhs = torch.rand(m, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_tgmm(lhs, rhs, group_sizes)

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
if test_cache:
met.clear_counters()
# execute the same tgmm func, expected to hit the cache
out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure tgmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tgmm_bf16(self):
met.clear_all()

self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

lhs = torch.rand(k, m, dtype=lhs_dtype)
rhs = torch.rand(m, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
for test_cache in [False, True]:
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertTrue(torch.allclose(ref_out, out.cpu()))
lhs = torch.rand(k, m, dtype=lhs_dtype)
rhs = torch.rand(m, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_tgmm(lhs, rhs, group_sizes)

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
if test_cache:
met.clear_counters()
# execute the same tgmm func, expected to hit the cache
out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure tgmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_backward(self):
@@ -365,25 +392,31 @@ def test_gmm_backward(self):
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
lhs.retain_grad()
rhs.retain_grad()
for test_cache in [False, True]:
met.clear_all()
lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
lhs.retain_grad()
rhs.retain_grad()

ref_out = self._reference_gmm(lhs, rhs, group_sizes)
ref_out.sum().backward()
ref_out = self._reference_gmm(lhs, rhs, group_sizes)
ref_out.sum().backward()

ref_out_backward = torch.ones_like(ref_out)
grad_lhs, grad_rhs = gmm_backward(
ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
group_sizes.to("xla"))
ref_out_backward = torch.ones_like(ref_out)
grad_lhs, grad_rhs = gmm_backward(
ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
group_sizes.to("xla"))
# The same gmm/tgmm kernels were already run in the `test_cache=False` iteration,
# so the cache should be populated by now.
if test_cache:
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)

self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_backward_2(self):
@@ -420,7 +453,7 @@ def test_gmm_backward_2(self):
self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_backward_3(self):
@@ -458,7 +491,32 @@ def test_gmm_backward_3(self):
self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_cache_miss(self):
met.clear_all()
jax.config.update('jax_default_matmul_precision', "highest")

self._init_test_cases()
test_case = self.tests_cases[-1]
# make sure the cache misses for different input shapes, dtypes and tilings
met.clear_all()
for mul_factor in [[2, 1, 1, 1], [1, 2, 1, 1], [2, 1, 2, 1], [2, 1, 1, 2]]:
for dtype in [torch.float32, torch.bfloat16]:
for tiling in [(128, 128, 128), (256, 256, 256)]:
num_groups = test_case['num_groups'] * mul_factor[0]
k = test_case['k'] * mul_factor[1]
m = test_case['m'] * mul_factor[2]
n = test_case['n'] * mul_factor[3]
lhs_dtype = rhs_dtype = dtype

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)

out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"), tiling)
self.assertEqual(met.counter_value('trace_pallas_cache_hit'), None)


if __name__ == '__main__':
24 changes: 23 additions & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -9,7 +9,7 @@
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met

from typing import Any, List, Callable, Optional, Tuple
from typing import Any, List, Callable, Optional, Tuple, Dict
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

@@ -93,10 +93,14 @@ def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
convert_torch_dtype_to_jax(tensor.dtype))


trace_pallas_arg_to_payload: Dict[Tuple[Any], str] = {}


def trace_pallas(kernel: Callable,
*args,
static_argnums=None,
static_argnames=None,
use_cache=False,
**kwargs):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
@@ -116,11 +120,27 @@ def trace_pallas(kernel: Callable,
else:
jax_args.append(arg)

hash_key = ()
if use_cache:
global trace_pallas_arg_to_payload
# Implicit assumption here: everything in kwargs is hashable and not a tensor,
# which holds for gmm and tgmm.
hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
repr(sorted(kwargs.items())).encode())
if hash_key in trace_pallas_arg_to_payload:
torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
return trace_pallas_arg_to_payload[hash_key], tensor_args

# Here we ignore the kwargs for execution since, most of the time, they are only used in traced code.
ir = jax.jit(
kernel, static_argnums=static_argnums,
static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
payload = _extract_backend_config(ir)

if use_cache:
# if we reach here it means we have a cache miss.
trace_pallas_arg_to_payload[hash_key] = payload

return payload, tensor_args


@@ -770,6 +790,7 @@ def gmm(
rhs,
group_sizes,
static_argnames=["tiling", "preferred_element_type"],
use_cache=True,
preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
tiling=(tm, tk, tn))

@@ -822,6 +843,7 @@ def tgmm(
rhs,
group_sizes,
static_argnames=["tiling", "preferred_element_type"],
use_cache=True,
preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
tiling=(tm, tk, tn))

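For context on the trace_pallas change above: the cache is opt-in per call site and is keyed by the kernel object, the static argnums/argnames, the ShapeDtypeStructs of the tensor arguments, and the repr of the kwargs, which is why the code assumes every kwarg is hashable and not a tensor. A rough sketch of how another Pallas-kernel wrapper could enable the same caching, again assuming a TPU host; the add kernel, the block_size kwarg, and the shapes are hypothetical stand-ins, while trace_pallas(..., use_cache=True), its (payload, tensor_args) return value, and the trace_pallas_cache_hit counter come from this commit:

import jax
from jax.experimental import pallas as pl
import torch
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import trace_pallas

def _add_kernel(x_ref, y_ref, o_ref):
  # A trivial element-wise Pallas kernel standing in for a real custom kernel.
  o_ref[...] = x_ref[...] + y_ref[...]

def _add_pallas(x, y, block_size=128):
  # block_size is unused in the body; it only stands in for a static, hashable
  # kwarg such as gmm's tiling, which is what the cache key assumes about kwargs.
  return pl.pallas_call(
      _add_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

x = torch.rand(128, 128).to("xla")
y = torch.rand(128, 128).to("xla")

payload, tensor_args = trace_pallas(
    _add_pallas, x, y, static_argnames=["block_size"], use_cache=True, block_size=128)
met.clear_counters()
# Same kernel, shapes, dtypes and static kwargs: the lowered payload comes from the cache.
payload, tensor_args = trace_pallas(
    _add_pallas, x, y, static_argnames=["block_size"], use_cache=True, block_size=128)
assert met.counter_value('trace_pallas_cache_hit') == 1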
