Skip to content

Commit

Permalink
Implement bucketize in torch aten ops (#7396) (#8007)
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Teo <[email protected]>
  • Loading branch information
simonteozw and Simon Teo authored Sep 14, 2024
1 parent 9c7f083 commit dbdebfd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"block_diag",
"bucketize",
"byte",
"cat",
"cauchy",
Expand Down
5 changes: 5 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,11 @@ def fix_dim(p):
new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
return self.reshape(new_shape)

@op(torch.ops.aten.bucketize)
def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
  """Lowering of torch.ops.aten.bucketize to jnp.digitize.

  Args:
    input: values to bucket (array-like).
    boundaries: 1-D strictly increasing sequence of bucket boundaries.
    out_int32: if True the result dtype is int32, otherwise int64
      (matches torch.bucketize's out_int32 flag).
    right: torch semantics for boundary ties; note this is the *opposite*
      of numpy/jax's `right`, hence the `not right` below.
    out: ignored; accepted only for torch API compatibility.

  Returns:
    Array of bucket indices with dtype int32 or int64.
  """
  boundaries = jnp.asarray(boundaries)
  # Validate strict monotonicity pairwise. The previous check compared only
  # boundaries[0] < boundaries[-1], which (a) missed interior violations such
  # as [0, 5, 3, 9] and (b) wrongly rejected a valid single-element boundaries
  # tensor, where first == last. A pairwise comparison is vacuously true for
  # length-1 input, matching torch.bucketize's acceptance of one boundary.
  assert jnp.all(boundaries[:-1] < boundaries[1:]), "boundaries must contain a strictly increasing sequence"
  return_type = jnp.int32 if out_int32 else jnp.int64
  # torch's `right` flag is inverted relative to jnp.digitize's, so flip it.
  return jnp.digitize(input, boundaries, right=not right).astype(return_type)

@op(torch.ops.aten.convolution)
def _aten_convolution(
Expand Down

0 comments on commit dbdebfd

Please sign in to comment.