
Commit d4c90c9

[torch_xla2] Enable nn.functional.cosine_embedding_loss (#8368)
1 parent 7220aee commit d4c90c9

File tree

2 files changed: +23 −8 lines


experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 7 deletions
@@ -37,29 +37,22 @@
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",
-    "nn.functional.cosine_embedding_loss",
     "nn.functional.ctc_loss",
     "nn.functional.dropout2d",
     "nn.functional.dropout3d",
     "nn.functional.dropout",
     "nn.functional.embedding_bag",
     "nn.functional.fractional_max_pool2d",
     "nn.functional.fractional_max_pool3d",
-    "nn.functional.group_norm",
-    "nn.functional.hinge_embedding_loss",
     "nn.functional.interpolate",
-    "nn.functional.margin_ranking_loss",
     "nn.functional.max_pool1d",
     "nn.functional.max_pool2d",
     "nn.functional.max_pool3d",
     "nn.functional.multi_head_attention_forward",
-    "nn.functional.multi_margin_loss",
     "nn.functional.multilabel_margin_loss",
     "nn.functional.pairwise_distance",
     "nn.functional.poisson_nll_loss",
     "nn.functional.rrelu",
-    "nn.functional.triplet_margin_loss",
-    "nn.functional.triplet_margin_with_distance_loss",
     "nn.functional.upsample_nearest",
     "nonzero",
     "nonzero_static",
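
Note: these seven entries appear to come off the test suite's skip list, so the OpInfo tests for them now run against the JAX lowerings. A minimal sketch (not part of this commit) of exercising the newly enabled loss op; the torch_xla2.default_env() entry point is assumed from the project README:

    import torch
    import torch.nn.functional as F
    import torch_xla2

    env = torch_xla2.default_env()  # assumed entry point; see torch_xla2 README
    with env:  # tensors created here are backed by jax.Array
        input1 = torch.randn(4, 8)
        input2 = torch.randn(4, 8)
        target = torch.tensor([1, -1, 1, -1])  # +1 = similar pair, -1 = dissimilar
        loss = F.cosine_embedding_loss(input1, input2, target)
        print(loss)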

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 23 additions & 1 deletion
@@ -43,6 +43,7 @@
     torch.ops.aten.relu_: torch.ops.aten.relu,
     # squeeze_ is expected to change tensor's shape. So replace with new value
     torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
+    torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
     torch.ops.aten.clamp_: torch.ops.aten.clamp,
     torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
     torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
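
Note: JAX arrays are immutable, so this table rewrites in-place aten ops to their functional counterparts and rebinds the result. The new sqrt_ entry relies on the usual in-place/functional equivalence, checkable in plain PyTorch:

    import torch

    x = torch.tensor([4.0, 9.0])
    y = x.sqrt()   # functional form the table substitutes
    x.sqrt_()      # in-place form being rewritten
    assert torch.equal(x, y)  # the two must agree for the rewrite to be sound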
@@ -112,7 +113,11 @@ def _aten_add(x, y, *, alpha=1):
 
   assert x.dtype == y.dtype, (x.dtype, y.dtype)
   """
-  return x + y * alpha
+  res = x + y * alpha
+  if isinstance(x, float) or isinstance(y, float):
+    new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
+    res = res.astype(new_dtype)
+  return res
 
 
 @op(torch.ops.aten.copy_, is_jax_function=False)
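
Note: the added cast mirrors PyTorch's scalar-promotion rule, under which mixing a Python float with a tensor yields the default dtype (normally float32) rather than float64. Reference behavior in plain PyTorch:

    import torch

    t = torch.tensor([1, 2], dtype=torch.int64)
    print((t + 0.5).dtype)            # torch.float32: Python floats promote to
    print(torch.get_default_dtype())  # the default dtype, torch.float32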
@@ -169,6 +174,16 @@ def _aten_cauchy_(x, median=0, sigma=1):
   return x.at[:].set(samples)
 
 
+@op(torch.ops.aten.atleast_2d)
+def _aten_atleast_2d(inputs):
+  return jnp.atleast_2d(inputs)
+
+
+@op(torch.ops.aten.atleast_1d)
+def _aten_atleast_1d(inputs):
+  return jnp.atleast_1d(inputs)
+
+
 # aten.complex
 @op(torch.ops.aten.complex)
 def _aten_complex(real, imag):
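
Note: both new ops delegate straight to jnp, which matches the aten contract here: inputs below the requested rank gain leading axes, higher-rank inputs pass through unchanged. For example:

    import jax.numpy as jnp

    print(jnp.atleast_1d(3.0).shape)               # (1,)
    print(jnp.atleast_2d(jnp.arange(4)).shape)     # (1, 4)
    print(jnp.atleast_2d(jnp.ones((2, 3))).shape)  # (2, 3), already 2-D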
@@ -281,6 +296,10 @@ def _aten_mul(x, y):
   res = x * y
   if isinstance(x, float) or isinstance(y, float):
     res = res.astype(new_dtype)
+  else:
+    if (not isinstance(x, int)) and (not isinstance(y, int)):
+      if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64):
+        res = res.astype(new_dtype)
   return res
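
Note (a hedged reading): the new else-branch appears to guard tensor-tensor products where a float64 operand would otherwise leak float64 into the result while torch's default dtype is float32. A sketch of that situation, which requires JAX's x64 mode (otherwise float64 arrays cannot exist at all):

    import numpy as np
    import jax.numpy as jnp
    from jax import config

    config.update("jax_enable_x64", True)  # allow float64 arrays for this sketch

    a = jnp.asarray(np.arange(3, dtype=np.float64))
    b = jnp.ones(3, dtype=jnp.float32)
    res = a * b
    print(res.dtype)                      # float64: promotion follows the wider type
    print(res.astype(jnp.float32).dtype)  # float32: what the added cast restores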

@@ -1284,6 +1303,9 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
 
   input_shape = input.shape
 
+  if 0 in input_shape:
+    return input, input, input
+
   # Reshape for group-wise normalization
   reshaped_input = jnp.reshape(input, (1, N * group, -1))
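
Note: the early return skips the normalization body for zero-size input, whose reductions over zero elements would produce NaNs; this lines up with "nn.functional.group_norm" coming off the skip list above. PyTorch's reference behavior on an empty batch, assuming a recent release:

    import torch
    import torch.nn.functional as F

    x = torch.ones(0, 6, 2, 2)  # empty batch dimension
    out = F.group_norm(x, num_groups=3)
    print(out.shape)            # torch.Size([0, 6, 2, 2]): empty in, empty out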
