
Commit dbdebfd

simonteozw and Simon Teo authored
Implement bucketize in torch aten ops (#7396) (#8007)
Co-authored-by: Simon Teo <[email protected]>
1 parent 9c7f083 commit dbdebfd

2 files changed: +5 additions, -1 deletion


experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@
     "_upsample_bilinear2d_aa",
     "bincount", # NOTE: dtype for int input torch gives float. This is weird.
     "block_diag",
-    "bucketize",
     "byte",
     "cat",
     "cauchy",

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 5 additions & 0 deletions
@@ -698,6 +698,11 @@ def fix_dim(p):
   new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
   return self.reshape(new_shape)
 
+@op(torch.ops.aten.bucketize)
+def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
+  assert boundaries[0] < boundaries[-1], "boundaries must contain a strictly increasing sequence"
+  return_type = jnp.int32 if out_int32 else jnp.int64
+  return jnp.digitize(input, boundaries, right=not right).astype(return_type)
 
 @op(torch.ops.aten.convolution)
 def _aten_convolution(
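
For context, a minimal standalone sketch (not part of this commit) of why the lowering passes right=not right: torch.bucketize with right=False corresponds to jnp.digitize with right=True. The boundaries and sample values below are made up purely for illustration.

# Illustrative sketch, not repository code: check that torch.bucketize and the
# jnp.digitize call used in _aten_bucketize agree on a small example.
import jax.numpy as jnp
import torch

boundaries = [1.0, 3.0, 5.0, 7.0]
values = [0.5, 1.0, 4.0, 7.0, 9.0]

# torch semantics (right=False): index of the first boundary >= value.
torch_out = torch.bucketize(torch.tensor(values), torch.tensor(boundaries), right=False)

# JAX equivalent flips the flag: digitize(..., right=True) yields the same indices.
jax_out = jnp.digitize(jnp.array(values), jnp.array(boundaries), right=True)

print(torch_out.tolist())  # expected: [0, 0, 2, 3, 4]
print(jax_out.tolist())    # expected: [0, 0, 2, 3, 4]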
