@@ -2033,11 +2033,12 @@ def _():
20332033 np .testing .assert_allclose (out , expected , atol = atol )
20342034
20352035
2036- class PallasCheckifyInterpretTest (PallasBaseTest ):
2037- # TODO(b/346651778): Support non-interpret mode checkify.
2038- INTERPRET = True
2036+ class PallasCheckifyTest (PallasBaseTest ):
2037+ INTERPRET = False
20392038
20402039 def test_no_checkify (self ,):
2040+ if jtu .test_device_matches (["gpu" ]):
2041+ self .skipTest ("Not supported on GPU." )
20412042 def kernel (y_ref ):
20422043 y_ref [...] = jnp .zeros_like (y_ref [...])
20432044 out_shape = jax .ShapeDtypeStruct ((2 , 2 ), jnp .float32 )
@@ -2049,6 +2050,8 @@ def kernel(y_ref):
20492050 np .testing .assert_allclose (result , jnp .zeros_like (result ))
20502051
20512052 def test_does_not_clobber_previous_error (self ,):
2053+ if jtu .test_device_matches (["gpu" ]):
2054+ self .skipTest ("Not supported on GPU." )
20522055 def kernel (y_ref ):
20532056 y_ref [...] = jnp .zeros_like (y_ref [...])
20542057 checkify .check (False , "error in kernel" )
@@ -2067,6 +2070,8 @@ def error_before_call():
20672070
20682071 @parameterized .parameters ((False ,), (True ,))
20692072 def test_trivial_check (self , assert_cond ):
2073+ if jtu .test_device_matches (["gpu" ]):
2074+ self .skipTest ("Not supported on GPU." )
20702075 def kernel (x_ref , y_ref ):
20712076 y_ref [...] = x_ref [...]
20722077 checkify .check (assert_cond , "pallas check failed" )
@@ -2083,14 +2088,16 @@ def kernel(x_ref, y_ref):
20832088 np .testing .assert_allclose (result , input )
20842089
20852090 def test_nan_error (self ):
2091+ if not self .INTERPRET :
2092+ self .skipTest ("Not supported in non-interpret mode." )
20862093 def kernel (x_ref , y_ref ):
20872094 y_ref [...] = jnp .log (x_ref [...])
20882095 input = jnp .arange (4 , dtype = jnp .float32 ) - 2
20892096 out_shape = jax .ShapeDtypeStruct (input .shape , input .dtype )
20902097 pallas_call = self .pallas_call (kernel ,
20912098 out_shape = out_shape )
20922099 checked_call = checkify .checkify (pallas_call ,
2093- errors = checkify .all_checks )
2100+ errors = checkify .nan_checks )
20942101 err , result = checked_call (input )
20952102 with self .assertRaisesRegex (
20962103 checkify .JaxRuntimeError , "nan generated by primitive: log" ):
@@ -2119,6 +2126,8 @@ def kernel(x_ref, y_ref):
21192126 @parameterized .parameters ((5 , 0 ), (8 , 3 ), (4 , 3 ))
21202127 def test_checkify_returns_first_error_in_grid (
21212128 self , num_loops , fail_iteration ):
2129+ if not self .INTERPRET :
2130+ self .skipTest ("Not supported in non-interpret mode." )
21222131 # Check that checkify returns the first error that occurs
21232132 # TODO(justinfu): This test doesn't make sense on GPU, where threads run
21242133 # in parallel. Update checkify to return a grid of errors.
@@ -2137,12 +2146,42 @@ def kernel(x_ref, _):
21372146 out_shape = out_shape )
21382147
21392148 checked_call = checkify .checkify (pallas_call ,
2140- errors = checkify .all_checks )
2149+ errors = checkify .user_checks )
21412150 err , _ = checked_call (input_arr )
21422151 with self .assertRaisesRegex (
21432152 checkify .JaxRuntimeError , f"failed on loop { fail_iteration } " ):
21442153 err .throw ()
21452154
2155+ def test_checkify_on_oob_grid_access (self ):
2156+ if not self .INTERPRET :
2157+ self .skipTest ("Not supported in non-interpret mode." )
2158+ if config .enable_x64 .value :
2159+ self .skipTest ("Not supported in x64 mode." )
2160+ def kernel (x_ref , o_ref ):
2161+ o_ref [...] = x_ref [...]
2162+ input_arr = jnp .arange (18 , dtype = jnp .float32 )
2163+ in_specs = [pl .BlockSpec ((8 ,), lambda x : (x ,))]
2164+ out_specs = pl .BlockSpec ((8 ,), lambda x : (x ,))
2165+ out_shape = jax .ShapeDtypeStruct ((18 ,), dtype = jnp .float32 )
2166+ pallas_call = self .pallas_call (kernel ,
2167+ grid = (3 ,),
2168+ in_specs = in_specs ,
2169+ out_specs = out_specs ,
2170+ out_shape = out_shape )
2171+
2172+ checked_call = checkify .checkify (pallas_call ,
2173+ errors = checkify .index_checks )
2174+ err , result = checked_call (input_arr )
2175+ with self .assertRaisesRegex (checkify .JaxRuntimeError ,
2176+ (r"out-of-bounds indexing for array of shape \(18,\): index 16 "
2177+ r"is out of bounds for axis 0 with size 18" )):
2178+ err .throw ()
2179+ np .testing .assert_array_equal (result , input_arr )
2180+
2181+
2182+ class PallasCheckifyInterpretTest (PallasCheckifyTest ):
2183+ INTERPRET = True
2184+
21462185
21472186class PallasCallNamedGridTest (PallasBaseTest ):
21482187 def test_named_grid (self ):
0 commit comments