Skip to content

Commit

Permalink
Fix nn.functional.bilinear (#7517) (#8256)
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Teo <[email protected]>
  • Loading branch information
simonteozw and Simon Teo authored Oct 14, 2024
1 parent 80db07b commit 7760d87
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"nn.functional.adaptive_max_pool2d",
"nn.functional.adaptive_max_pool3d",
"nn.functional.alpha_dropout",
"nn.functional.bilinear",
"nn.functional.conv_transpose1d",
"nn.functional.conv_transpose2d",
"nn.functional.conv_transpose3d",
Expand Down
5 changes: 5 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4388,6 +4388,11 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)


@op(torch.ops.aten._trilinear)
def _aten_trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1):
  """Lowering of aten._trilinear to JAX.

  Inserts singleton axes into each input at the positions given by
  expand1/expand2/expand3, multiplies the three broadcast-expanded
  tensors elementwise, and sums the product over the dims in sumdim.
  unroll_dim is accepted for signature compatibility but unused here —
  presumably an eager-mode performance hint; TODO confirm.
  """
  lhs = jnp.expand_dims(i1, expand1)
  mid = jnp.expand_dims(i2, expand2)
  rhs = jnp.expand_dims(i3, expand3)
  product = lhs * mid * rhs
  return _aten_sum(product, sumdim)


@op(torch.ops.aten.max_unpool2d)
@op(torch.ops.aten.max_unpool3d)
def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
Expand Down

0 comments on commit 7760d87

Please sign in to comment.