Skip to content

Commit 3f6c318

Browse files
committed
Skip test_batch_axis_sharding_jvp because of hipSPARSE issue
1 parent 7a5431c commit 3f6c318

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/linalg_sharding_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def test_non_batch_axis_sharding(self, fun_and_shapes, dtype):
158158
)
159159
@jtu.run_on_devices("gpu", "cpu")
160160
def test_batch_axis_sharding_jvp(self, fun_and_shapes, dtype):
161+
if fun_and_shapes[0] is lax.linalg.tridiagonal_solve and jtu.is_device_rocm():
162+
self.skipTest("test_batch_axis_sharding_jvp is not supported on ROCm")
161163
fun, shapes = self.get_fun_and_shapes(fun_and_shapes, grad=True)
162164
primals = self.get_args(shapes, dtype, batch_size=8)
163165
tangents = tuple(map(jnp.ones_like, primals))

0 commit comments

Comments
 (0)