From d4c90c91749c5ff762258375edc1f4001204201c Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 13 Nov 2024 14:07:58 -0800
Subject: [PATCH] [torch_xla2] Enable nn.functional.cosine_embedding_loss
 (#8368)

---
 experimental/torch_xla2/test/test_ops.py |  7 ------
 .../torch_xla2/torch_xla2/ops/jaten.py    | 24 ++++++++++++++++++-
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 3ace00ea495..5558b373b97 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -37,7 +37,6 @@
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",
-    "nn.functional.cosine_embedding_loss",
     "nn.functional.ctc_loss",
     "nn.functional.dropout2d",
     "nn.functional.dropout3d",
@@ -45,21 +44,15 @@
     "nn.functional.embedding_bag",
     "nn.functional.fractional_max_pool2d",
     "nn.functional.fractional_max_pool3d",
-    "nn.functional.group_norm",
-    "nn.functional.hinge_embedding_loss",
     "nn.functional.interpolate",
-    "nn.functional.margin_ranking_loss",
     "nn.functional.max_pool1d",
     "nn.functional.max_pool2d",
     "nn.functional.max_pool3d",
     "nn.functional.multi_head_attention_forward",
-    "nn.functional.multi_margin_loss",
     "nn.functional.multilabel_margin_loss",
     "nn.functional.pairwise_distance",
     "nn.functional.poisson_nll_loss",
     "nn.functional.rrelu",
-    "nn.functional.triplet_margin_loss",
-    "nn.functional.triplet_margin_with_distance_loss",
     "nn.functional.upsample_nearest",
     "nonzero",
     "nonzero_static",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 9b24b075cec..c72459cbdb3 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -43,6 +43,7 @@
   torch.ops.aten.relu_: torch.ops.aten.relu,
   # squeeze_ is expected to change tensor's shape. So replace with new value
   torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
+  torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
   torch.ops.aten.clamp_: torch.ops.aten.clamp,
   torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
   torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
@@ -112,7 +113,11 @@ def _aten_add(x, y, *, alpha=1):
 
   assert x.dtype == y.dtype, (x.dtype, y.dtype)
   """
-  return x + y * alpha
+  res = x + y * alpha
+  if isinstance(x, float) or isinstance(y, float):
+    new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
+    res = res.astype(new_dtype)
+  return res
 
 
 @op(torch.ops.aten.copy_, is_jax_function=False)
@@ -169,6 +174,16 @@ def _aten_cauchy_(x, median=0, sigma=1):
   return x.at[:].set(samples)
 
 
+@op(torch.ops.aten.atleast_2d)
+def _aten_atleast_2d(inputs):
+  return jnp.atleast_2d(inputs)
+
+
+@op(torch.ops.aten.atleast_1d)
+def _aten_atleast_1d(inputs):
+  return jnp.atleast_1d(inputs)
+
+
 # aten.complex
 @op(torch.ops.aten.complex)
 def _aten_complex(real, imag):
@@ -281,6 +296,10 @@ def _aten_mul(x, y):
   res = x * y
   if isinstance(x, float) or isinstance(y, float):
     res = res.astype(new_dtype)
+  else:
+    if (not isinstance(x, int)) and (not isinstance(y, int)):
+      if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64):
+        res = res.astype(new_dtype)
   return res
 
 
@@ -1284,6 +1303,9 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
 
   input_shape = input.shape
 
+  if 0 in input_shape:
+    return input, input, input
+
   # Reshape for group-wise normalization
   reshaped_input = jnp.reshape(input, (1, N * group, -1))
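
Illustrative usage (not part of the patch above): a minimal sketch of the op that this change un-skips in test_ops.py, written in plain PyTorch eager mode. Under torch_xla2 the same functional call is expected to decompose into the aten ops lowered in jaten.py; the shapes and values below are made up for the example.

import torch
import torch.nn.functional as F

# Two batches of 8-dimensional embeddings to compare pairwise.
x1 = torch.randn(4, 8)
x2 = torch.randn(4, 8)
# target is +1 for pairs that should be similar, -1 for dissimilar pairs.
target = torch.tensor([1, -1, 1, -1])

# Mean-reduced scalar loss; margin=0.0 is the PyTorch default.
loss = F.cosine_embedding_loss(x1, x2, target, margin=0.0)
print(loss)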