Skip to content

Commit

Permalink
fix test fail on jax 0.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Jan 22, 2025
1 parent b05d203 commit d9c01b1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions experimental/torch_xla2/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pytest
pytest-xdist
sentencepiece
expecttest
optax
6 changes: 6 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4796,7 +4796,13 @@ def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs):
# torch.linalg.solve
@op(torch.ops.aten._linalg_solve_ex)
def _aten__linalg_solve_ex(a, b):
batched = False
if b.ndim > 1 and b.shape[-1] == a.shape[-1]:
batched = True
b = b[..., None]
res = jnp.linalg.solve(a, b)
if batched:
res = res.squeeze(-1)
info_shape = a.shape[0] if len(a.shape) >= 3 else []
info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
return res, info
Expand Down

0 comments on commit d9c01b1

Please sign in to comment.