
Commit 7760d87

Fix nn.functional.bilinear (#7517) (#8256)

Authored by simonteozw (Simon Teo)
Co-authored-by: Simon Teo <[email protected]>
1 parent: 80db07b

2 files changed (+5, -1 lines)

experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -46,7 +46,6 @@
     "nn.functional.adaptive_max_pool2d",
     "nn.functional.adaptive_max_pool3d",
     "nn.functional.alpha_dropout",
-    "nn.functional.bilinear",
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 5 additions & 0 deletions
@@ -4388,6 +4388,11 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
   return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)
 
 
+@op(torch.ops.aten._trilinear)
+def _aten_trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1):
+  return _aten_sum(jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * jnp.expand_dims(i3, expand3), sumdim)
+
+
 @op(torch.ops.aten.max_unpool2d)
 @op(torch.ops.aten.max_unpool3d)
 def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
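
The new lowering follows the usual trilinear-product decomposition: insert singleton dimensions into each operand, broadcast-multiply the three tensors, and sum over sumdim. PyTorch implements bilinear in terms of aten._trilinear plus a bias add, which is why registering this op fixes nn.functional.bilinear and allows the test skip above to be removed. Below is a minimal standalone sketch, not the commit's code: the function name, shapes, and expand/sum axes are illustrative assumptions, plain jnp.sum stands in for the internal _aten_sum helper, and the @op registration is omitted. It checks the decomposition against the einsum form of a bias-free bilinear layer.

# Hypothetical standalone sketch; `trilinear` and the shapes below are
# illustrative, not taken from the commit.
import jax.numpy as jnp
import numpy as np

def trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim):
  # Insert singleton dims so the operands broadcast against each other,
  # multiply elementwise, then reduce over the requested dimensions.
  a = jnp.expand_dims(i1, tuple(expand1))
  b = jnp.expand_dims(i2, tuple(expand2))
  c = jnp.expand_dims(i3, tuple(expand3))
  return jnp.sum(a * b * c, axis=tuple(sumdim))

# bilinear(x1, x2; W) without bias equals einsum("bi,oij,bj->bo", x1, W, x2).
rng = np.random.default_rng(0)
x1 = jnp.asarray(rng.standard_normal((4, 3)), dtype=jnp.float32)   # (batch, in1)
x2 = jnp.asarray(rng.standard_normal((4, 5)), dtype=jnp.float32)   # (batch, in2)
w = jnp.asarray(rng.standard_normal((2, 3, 5)), dtype=jnp.float32)  # (out, in1, in2)

y = trilinear(x1, w, x2, expand1=(1, 3), expand2=(0,), expand3=(1, 2), sumdim=(2, 3))
expected = jnp.einsum("bi,oij,bj->bo", x1, w, x2)
assert jnp.allclose(y, expected, atol=1e-4)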
