diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2a96b010ab7..5d12932012f 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -28,7 +28,6 @@ "diagonal_copy", "diagonal_scatter", "digamma", - "erfinv", "exponential", "gcd", "geometric", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 8f34c675df7..eb9b04cf1c6 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2057,6 +2057,12 @@ def _aten_erf(x): return jax.lax.erf(x) +@op(torch.ops.aten.erfinv) +@op_base.promote_int_input +def _aten_erfinv(input): + return jax.lax.erf_inv(input) + + # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input):