Skip to content

Commit d9c01b1

Browse files
committed
fix test fail on jax 0.5.0
1 parent b05d203 commit d9c01b1

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

experimental/torch_xla2/test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pytest
55
pytest-xdist
66
sentencepiece
77
expecttest
8+
optax

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4796,7 +4796,13 @@ def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs):
47964796
# torch.linalg.solve
47974797
@op(torch.ops.aten._linalg_solve_ex)
47984798
def _aten__linalg_solve_ex(a, b):
4799+
batched = False
4800+
if b.ndim > 1 and b.shape[-1] == a.shape[-1]:
4801+
batched = True
4802+
b = b[..., None]
47994803
res = jnp.linalg.solve(a, b)
4804+
if batched:
4805+
res = res.squeeze(-1)
48004806
info_shape = a.shape[0] if len(a.shape) >= 3 else []
48014807
info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
48024808
return res, info

0 commit comments

Comments
 (0)