Skip to content

Commit a041ea1

Browse files
Skip test_jnp_einsum_grad_y_pallas on gpu due to ooms
PiperOrigin-RevId: 695143627
1 parent 098d582 commit a041ea1

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/pallas/ops_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,9 @@ def isnan(x_ref, o_ref):
10341034
np.testing.assert_allclose(isnan(x), jnp.isnan(x))
10351035

10361036
def test_jnp_einsum_grad_y_pallas(self):
1037+
if jtu.test_device_matches(["gpu"]):
1038+
self.skipTest("This test ooms on gpu")
1039+
10371040
x = jnp.arange(128 * 256, dtype=jnp.float32).reshape((128, 256))
10381041
y = jnp.arange(256 * 128, dtype=jnp.float32).reshape((128, 256))
10391042

0 commit comments

Comments
 (0)