[torch_xla2] Enable nn.functional.cosine_embedding_loss (#8368)
ManfeiBai authored Nov 13, 2024
1 parent 7220aee commit d4c90c9
Showing 2 changed files with 23 additions and 8 deletions.
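
For reference, cosine_embedding_loss compares two batches of embeddings against a target of +1 (similar) or -1 (dissimilar) using cosine similarity. A minimal eager-mode usage sketch, not part of this diff and shown only for context:

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 8)                    # first batch of embeddings
x2 = torch.randn(4, 8)                    # second batch of embeddings
target = torch.tensor([1, -1, 1, -1])     # +1 = should be similar, -1 = dissimilar

# loss is 1 - cos(x1, x2) where target == 1 and
# max(0, cos(x1, x2) - margin) where target == -1
loss = F.cosine_embedding_loss(x1, x2, target, margin=0.0)
print(loss)
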
7 changes: 0 additions & 7 deletions experimental/torch_xla2/test/test_ops.py
@@ -37,29 +37,22 @@
"nn.functional.conv_transpose1d",
"nn.functional.conv_transpose2d",
"nn.functional.conv_transpose3d",
"nn.functional.cosine_embedding_loss",
"nn.functional.ctc_loss",
"nn.functional.dropout2d",
"nn.functional.dropout3d",
"nn.functional.dropout",
"nn.functional.embedding_bag",
"nn.functional.fractional_max_pool2d",
"nn.functional.fractional_max_pool3d",
"nn.functional.group_norm",
"nn.functional.hinge_embedding_loss",
"nn.functional.interpolate",
"nn.functional.margin_ranking_loss",
"nn.functional.max_pool1d",
"nn.functional.max_pool2d",
"nn.functional.max_pool3d",
"nn.functional.multi_head_attention_forward",
"nn.functional.multi_margin_loss",
"nn.functional.multilabel_margin_loss",
"nn.functional.pairwise_distance",
"nn.functional.poisson_nll_loss",
"nn.functional.rrelu",
"nn.functional.triplet_margin_loss",
"nn.functional.triplet_margin_with_distance_loss",
"nn.functional.upsample_nearest",
"nonzero",
"nonzero_static",
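
Deleting an op's name from this skip list is what actually enables its tests: the OpInfo entries for that op now run under torch_xla2 instead of being skipped. A rough sketch of how such a skip list is typically consumed (illustrative names, not the exact test_ops.py code):

from torch.testing._internal.common_methods_invocations import op_db

skip_list = {
    "nn.functional.ctc_loss",
    "nn.functional.dropout",
    # ... ops that still fail under torch_xla2
}

# Only ops that are not on the skip list get exercised by the test harness.
ops_to_test = [op for op in op_db if op.name not in skip_list]
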
24 changes: 23 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -43,6 +43,7 @@
torch.ops.aten.relu_: torch.ops.aten.relu,
# squeeze_ is expected to change tensor's shape. So replace with new value
torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
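
The table above exists because JAX arrays are immutable: each in-place aten op is redirected to its functional counterpart and the result is written back over the wrapper tensor. A rough sketch of that idea for the new sqrt_ entry, illustrative only and not the actual torch_xla2 dispatch code:

import jax.numpy as jnp

functional_for_inplace = {"sqrt_": jnp.sqrt}   # stands in for aten.sqrt_ -> aten.sqrt

def run_inplace(name, x):
  res = functional_for_inplace[name](x)
  # torch_xla2 would rebind the calling tensor's jax buffer to res here
  return res

print(run_inplace("sqrt_", jnp.array([4.0, 9.0])))   # [2. 3.]
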
@@ -112,7 +113,11 @@ def _aten_add(x, y, *, alpha=1):
  assert x.dtype == y.dtype, (x.dtype, y.dtype)
  """
  return x + y * alpha
  res = x + y * alpha
  if isinstance(x, float) or isinstance(y, float):
    new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
    res = res.astype(new_dtype)
  return res


@op(torch.ops.aten.copy_, is_jax_function=False)
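
The rewritten _aten_add keeps the plain x + y * alpha result but casts it to the torch default dtype whenever either operand is a Python float (in PyTorch, an integer tensor plus 2.0 yields the default float dtype). A standalone sketch of that rule, with a hard-coded float32 standing in for mappings.t2j_dtype(torch.get_default_dtype()):

import jax.numpy as jnp

def add_like_aten(x, y, *, alpha=1):
  res = x + y * alpha
  if isinstance(x, float) or isinstance(y, float):
    res = res.astype(jnp.float32)   # assumed default dtype; the diff queries torch for it
  return res

print(add_like_aten(jnp.array([1, 2]), 0.5).dtype)   # float32
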
@@ -169,6 +174,16 @@ def _aten_cauchy_(x, median=0, sigma=1):
  return x.at[:].set(samples)


@op(torch.ops.aten.atleast_2d)
def _aten_atleast_2d(inputs):
  return jnp.atleast_2d(inputs)


@op(torch.ops.aten.atleast_1d)
def _aten_atleast_1d(inputs):
  return jnp.atleast_1d(inputs)


# aten.complex
@op(torch.ops.aten.complex)
def _aten_complex(real, imag):
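
The new atleast_1d and atleast_2d lowerings simply delegate to the jnp helpers of the same name (presumably these appear in decompositions that the newly enabled loss exercises). Their behaviour, for reference:

import jax.numpy as jnp

print(jnp.atleast_1d(jnp.float32(3.0)).shape)         # (1,)   a scalar becomes 1-D
print(jnp.atleast_2d(jnp.array([1.0, 2.0])).shape)    # (1, 2) a 1-D array gains a leading axis
print(jnp.atleast_2d(jnp.ones((3, 4))).shape)         # (3, 4) already 2-D, unchanged
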
@@ -281,6 +296,10 @@ def _aten_mul(x, y):
  res = x * y
  if isinstance(x, float) or isinstance(y, float):
    res = res.astype(new_dtype)
  else:
    if (not isinstance(x, int)) and (not isinstance(y, int)):
      if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64):
        res = res.astype(new_dtype)
  return res
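
_aten_mul gains an else branch for the all-array case: if either operand is float64, the product is cast back to the torch default dtype, mirroring the existing Python-float branch. A self-contained sketch with float32 standing in for mappings.t2j_dtype(torch.get_default_dtype()):

import numpy as np
import jax.numpy as jnp

def mul_like_aten(x, y, default_dtype=jnp.float32):
  res = x * y
  if isinstance(x, float) or isinstance(y, float):
    return res.astype(default_dtype)
  if not isinstance(x, int) and not isinstance(y, int):
    if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64):
      return res.astype(default_dtype)
  return res

# A float64 numpy operand no longer leaks float64 into the result.
print(mul_like_aten(np.ones(2), np.ones(2)).dtype)   # float32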


@@ -1284,6 +1303,9 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):

  input_shape = input.shape

  if 0 in input_shape:
    return input, input, input

  # Reshape for group-wise normalization
  reshaped_input = jnp.reshape(input, (1, N * group, -1))

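The zero-size guard added to _aten_native_group_norm short-circuits empty inputs. One concrete hazard it avoids (likely among others) is the group-wise reshape just below it, whose -1 dimension cannot be inferred when the array has zero elements; NumPy is used here for the demonstration, and the jnp.reshape in the lowering faces the same ambiguity:

import numpy as np

x = np.zeros((0, 4, 2, 2))        # N == 0, an empty batch
try:
  x.reshape(1, 0 * 2, -1)         # the shape the lowering would request for N=0, group=2
except ValueError as e:
  print(e)                        # cannot reshape array of size 0 ...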
