From af223d324856e9e289b80f02a923488ec7dedfa0 Mon Sep 17 00:00:00 2001 From: David Huang Date: Fri, 10 Jan 2025 10:37:12 -0800 Subject: [PATCH] [torch_xla2] Fix reenabled op info tests (#8548) --- experimental/torch_xla2/test/test_ops.py | 7 ++----- experimental/torch_xla2/torch_xla2/ops/jaten.py | 12 +++++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 516f7bd7f86..e3b686f68ad 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -12,14 +12,11 @@ skiplist = { "_segment_reduce", - "_unsafe_masked_index_put_accumulate", "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", "cat", "cholesky_solve", - "cov", "diagonal_copy", - "gather", "geqrf", "histogram", # hard op: AssertionError: Tensor-likes are not close! "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. @@ -44,7 +41,6 @@ "normal", "ormqr", "pca_lowrank", - "scatter", "searchsorted", "special.airy_ai", "special.scaled_modified_bessel_k0", @@ -96,7 +92,8 @@ 'nn.functional.dropout', } -atol_dict = {"linalg.eig": (2e0, 3e0), +atol_dict = {"cov": (2e-1, 2e-4), + "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0), "linalg.eigvalsh": (5e1, 3e0), "linalg.pinv": (8e-1, 2e0), diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index fe2b56c2301..a5d7c21b2a3 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1745,9 +1745,8 @@ def _aten_atan(self): @op(torch.ops.aten.scatter) @op(torch.ops.aten.scatter_reduce) def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): - if isinstance(src, float): - dtype = _torch_binary_scalar_type(src, input) - src = jnp.array(src, dtype=dtype) + if not isinstance(src, jnp.ndarray): + src = jnp.array(src, dtype=input.dtype) input_indexes, source_indexes = _scatter_index(dim, index) # "Zero out" target elements when not included if not include_self: @@ -2596,6 +2595,9 @@ def _aten_frexp(input): def _aten_gather(input, dim, index): if input.ndim == 0: return jnp.broadcast_to(input, index.shape) + # short circuit for empty outputs + if not all(index.shape): + return jnp.zeros(index.shape, dtype=input.dtype) if dim < 0: dim += input.ndim input_indexes, source_indexes = _scatter_index(dim, index) @@ -4732,9 +4734,9 @@ def _new_empty_strided(self, size, stride, dtype=None, **kwargs): return jnp.empty(size, dtype=jax_dtype) -@op(torch.ops.aten._unsafe_index_put, is_jax_function=False) +@op(torch.ops.aten._unsafe_index_put) def _aten_unsafe_index_put(self, indices, values, accumulate=False): - return self.index_put_(indices, values, accumulate) + return _aten_index_put(self, indices, values, accumulate) @op(torch.ops.aten.conj_physical,