Skip to content

Commit 9cf952a

Browse files
justinjfu and Google-ML-Automation
authored and committed
[Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
1 parent 9748e2a commit 9cf952a

File tree

2 files changed

+127
-5
lines changed

2 files changed

+127
-5
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,85 @@ def checkify_pallas_kernel_body_jaxpr(
993993
body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
994994
return checked_jaxpr, out_tree, error_effects
995995

996+
def pallas_call_checkify_oob_grid(error: checkify.Error,
                                  enabled_errors,
                                  args: tuple[jax_core.Value, ...],
                                  grid_mapping: GridMapping,
                                  input_output_aliases) -> checkify.Error:
  """Checks every grid step's block slices for out-of-bounds accesses.

  Traces a loop over all grid iterations which, at each step, computes the
  start indices of every input/output block and dynamic-slices the
  corresponding array. The traced loop is then run through checkify so that an
  out-of-bounds slice is recorded in the returned error.

  Args:
    error: The error accumulated so far. Returned unchanged when
      `checkify.OOBError` is not in `enabled_errors`.
    enabled_errors: Collection of enabled checkify error types.
    args: Flat operands: the dynamic grid bounds first, followed by the index
      (scalar) operands and then the inputs. Anything after the inputs is
      ignored here.
    grid_mapping: Grid and block mappings of the pallas_call being checked.
    input_output_aliases: Forwarded to `_initialize_output_vals` when creating
      the output values.

  Returns:
    The error produced by checkifying the traced loop with
    `checkify.index_checks`.
  """
  if checkify.OOBError not in enabled_errors:
    return error
  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  output_args = _initialize_output_vals(grid_mapping.block_mappings_output,
                                        args, input_output_aliases)
  scalars, input_args, _ = split_list(
      args, [grid_mapping.num_index_operands,
             grid_mapping.num_inputs],
  )
  # Replace each dynamic-dimension placeholder in the grid with the next
  # runtime bound, in order.
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
  grid_start_indices = (jnp.int32(0),) * len(grid)
  if grid:
    num_iterations = reduce(jnp.multiply, grid)
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1

  is_indexing_dim = [
      tuple(b is pallas_core.mapped for b in bm.block_shape)
      for bm in grid_mapping.block_mappings
  ]
  # Mapped dimensions are sliced with extent 1.
  block_shapes = [
      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
  ]
  # Arrays are ordered inputs first, then outputs, to line up with
  # `grid_mapping.block_mappings`.
  io_blocks = [*input_args, *output_args]

  # The while-loop carry is (i, loop_idx):
  # i: int32 is the iteration index
  # loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
    return i < num_iterations

  def body(carry):
    i, loop_idx = carry
    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
          bm.compute_start_indices_interpret(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings
      ]
    # We perform a dynamic slice on the i/o blocks, which will be checked by
    # checkify for OOB accesses. The results are discarded; an explicit loop
    # guarantees every slice is emitted regardless of whether `map` is lazy.
    for start_idx, block_shape, block, iid in zip(
        start_indices, block_shapes, io_blocks, is_indexing_dim):
      _maybe_dynamic_slice(start_idx, block_shape, block, iid)
    return (i + 1, _get_next_indices(grid, loop_idx))

  def f(_):
    lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices)
    )

  # `f` ignores its argument; a dummy scalar gives the traced function an
  # input to flatten.
  flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
  wrapped_loop, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(f), jaxpr_in_tree)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    avals_in = map(jax_core.get_aval, flat_args)
    traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
        wrapped_loop, list(avals_in))
    traced_loop = jax_core.ClosedJaxpr(traced_loop, consts)
    out_error, _ = checkify.checkify_jaxpr(
        traced_loop, checkify.index_checks, error, flat_args)
    return out_error
1074+
9961075
def pallas_call_checkify_rule(error: checkify.Error,
9971076
enabled_errors,
9981077
*args: jax_core.Value,
@@ -1002,6 +1081,10 @@ def pallas_call_checkify_rule(error: checkify.Error,
10021081
grid_mapping: GridMapping,
10031082
out_avals: tuple[jax_core.AbstractValue, ...],
10041083
**kwargs):
1084+
# Check for OOB accesses in the grid.
1085+
error = pallas_call_checkify_oob_grid(error, enabled_errors,
1086+
args, grid_mapping,
1087+
input_output_aliases)
10051088
# We implement the checkify rule in 4 steps:
10061089
# 1) First, trace the kernel body to get the expected error shapes.
10071090
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs

tests/pallas/pallas_test.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,11 +2033,12 @@ def _():
20332033
np.testing.assert_allclose(out, expected, atol=atol)
20342034

20352035

2036-
class PallasCheckifyInterpretTest(PallasBaseTest):
2037-
# TODO(b/346651778): Support non-interpret mode checkify.
2038-
INTERPRET = True
2036+
class PallasCheckifyTest(PallasBaseTest):
2037+
INTERPRET = False
20392038

20402039
def test_no_checkify(self,):
2040+
if jtu.test_device_matches(["gpu"]):
2041+
self.skipTest("Not supported on GPU.")
20412042
def kernel(y_ref):
20422043
y_ref[...] = jnp.zeros_like(y_ref[...])
20432044
out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
@@ -2049,6 +2050,8 @@ def kernel(y_ref):
20492050
np.testing.assert_allclose(result, jnp.zeros_like(result))
20502051

20512052
def test_does_not_clobber_previous_error(self,):
2053+
if jtu.test_device_matches(["gpu"]):
2054+
self.skipTest("Not supported on GPU.")
20522055
def kernel(y_ref):
20532056
y_ref[...] = jnp.zeros_like(y_ref[...])
20542057
checkify.check(False, "error in kernel")
@@ -2067,6 +2070,8 @@ def error_before_call():
20672070

20682071
@parameterized.parameters((False,), (True,))
20692072
def test_trivial_check(self, assert_cond):
2073+
if jtu.test_device_matches(["gpu"]):
2074+
self.skipTest("Not supported on GPU.")
20702075
def kernel(x_ref, y_ref):
20712076
y_ref[...] = x_ref[...]
20722077
checkify.check(assert_cond, "pallas check failed")
@@ -2083,14 +2088,16 @@ def kernel(x_ref, y_ref):
20832088
np.testing.assert_allclose(result, input)
20842089

20852090
def test_nan_error(self):
2091+
if not self.INTERPRET:
2092+
self.skipTest("Not supported in non-interpret mode.")
20862093
def kernel(x_ref, y_ref):
20872094
y_ref[...] = jnp.log(x_ref[...])
20882095
input = jnp.arange(4, dtype=jnp.float32) - 2
20892096
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
20902097
pallas_call = self.pallas_call(kernel,
20912098
out_shape=out_shape)
20922099
checked_call = checkify.checkify(pallas_call,
2093-
errors=checkify.all_checks)
2100+
errors=checkify.nan_checks)
20942101
err, result = checked_call(input)
20952102
with self.assertRaisesRegex(
20962103
checkify.JaxRuntimeError, "nan generated by primitive: log"):
@@ -2119,6 +2126,8 @@ def kernel(x_ref, y_ref):
21192126
@parameterized.parameters((5, 0), (8, 3), (4, 3))
21202127
def test_checkify_returns_first_error_in_grid(
21212128
self, num_loops, fail_iteration):
2129+
if not self.INTERPRET:
2130+
self.skipTest("Not supported in non-interpret mode.")
21222131
# Check that checkify returns the first error that occurs
21232132
# TODO(justinfu): This test doesn't make sense on GPU, where threads run
21242133
# in parallel. Update checkify to return a grid of errors.
@@ -2137,12 +2146,42 @@ def kernel(x_ref, _):
21372146
out_shape=out_shape)
21382147

21392148
checked_call = checkify.checkify(pallas_call,
2140-
errors=checkify.all_checks)
2149+
errors=checkify.user_checks)
21412150
err, _ = checked_call(input_arr)
21422151
with self.assertRaisesRegex(
21432152
checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
21442153
err.throw()
21452154

2155+
def test_checkify_on_oob_grid_access(self):
  """Index checks must flag a grid whose last block overruns the array.

  An 18-element array is tiled with three blocks of 8 elements; the expected
  error reports index 16 as out of bounds for the axis of size 18. The copy
  itself is still expected to reproduce the input.
  """
  if not self.INTERPRET:
    self.skipTest("Not supported in non-interpret mode.")
  if config.enable_x64.value:
    self.skipTest("Not supported in x64 mode.")

  def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

  input_arr = jnp.arange(18, dtype=jnp.float32)
  block_spec_in = pl.BlockSpec((8,), lambda x: (x,))
  block_spec_out = pl.BlockSpec((8,), lambda x: (x,))
  pallas_call = self.pallas_call(
      kernel,
      grid=(3,),
      in_specs=[block_spec_in],
      out_specs=block_spec_out,
      out_shape=jax.ShapeDtypeStruct((18,), dtype=jnp.float32),
  )

  checked_call = checkify.checkify(pallas_call, errors=checkify.index_checks)
  err, result = checked_call(input_arr)

  expected_message = (
      r"out-of-bounds indexing for array of shape \(18,\): index 16 "
      r"is out of bounds for axis 0 with size 18")
  with self.assertRaisesRegex(checkify.JaxRuntimeError, expected_message):
    err.throw()
  np.testing.assert_array_equal(result, input_arr)
2180+
2181+
2182+
class PallasCheckifyInterpretTest(PallasCheckifyTest):
  """Re-runs the `PallasCheckifyTest` cases with `INTERPRET` set to True."""

  INTERPRET = True
2184+
21462185

21472186
class PallasCallNamedGridTest(PallasBaseTest):
21482187
def test_named_grid(self):

0 commit comments

Comments
 (0)