Skip to content

Commit ca68448

Browse files
odimkacopybara-github
authored andcommitted
Support negative values for dim= in kd.index.
PiperOrigin-RevId: 713646153 Change-Id: I5b442f709b7c6b763c56a7a9ab20816f3835b6c7
1 parent 37fb836 commit ca68448

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

py/koladata/operators/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ py_library(
470470
deps = [
471471
":arolla_bridge",
472472
":assertion",
473+
":comparison",
473474
":functor",
474475
":jagged_shape",
475476
":masking",

py/koladata/operators/slices.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from arolla.jagged_shape import jagged_shape
2121
from koladata.operators import arolla_bridge
2222
from koladata.operators import assertion
23+
from koladata.operators import comparison
2324
from koladata.operators import functor
2425
from koladata.operators import jagged_shape as jagged_shape_ops
2526
from koladata.operators import masking
@@ -762,15 +763,28 @@ def group_by(x, *args):
762763
return dispatch_op(x, args)
763764

764765

766+
@optools.as_lambda_operator(
767+
'kde.slices._normalize_dim',
768+
qtype_constraints=[
769+
qtype_utils.expect_data_slice(P.x),
770+
qtype_utils.expect_data_slice(P.dim),
771+
],
772+
)
773+
def normalize_dim(x, dim):
774+
"""Returns dim if dim >= 0, otherwise get_ndim(x) + dim."""
775+
# TODO: masking.cond can be slow, optimize
776+
return masking.cond(comparison.less(dim, 0), get_ndim(x) + dim, dim)
777+
778+
765779
@optools.add_to_registry(aliases=['kde.index'])
766780
@optools.as_lambda_operator(
767781
'kde.slices.index',
768782
qtype_constraints=[
769783
qtype_utils.expect_data_slice(P.x),
770-
qtype_utils.expect_data_slice_or_unspecified(P.dim),
784+
qtype_utils.expect_data_slice(P.dim),
771785
],
772786
)
773-
def index(x, dim=arolla.unspecified()):
787+
def index(x, dim=-1):
774788
"""Returns the indices of the elements computed over the last dim dimensions.
775789
776790
The resulting slice has the same shape as the input.
@@ -790,24 +804,23 @@ def index(x, dim=arolla.unspecified()):
790804
# -> kd.slice([[[0, None, 0], [0, 0]], [[None, 1], [1, 1, 1]]])
791805
kd.index(ds, dim=1)
792806
# -> kd.slice([[[0, None, 0], [1, 1]], [[None, 0], [1, 1, 1]]])
793-
kd.index(ds, dim=2)
807+
kd.index(ds, dim=2) # (same as kd.index(ds, -1) or kd.index(ds))
794808
# -> kd.slice([[[0, None, 2], [0, 1]], [[None, 1], [0, 1, 2]]])
795809
796810
kd.index(ds) -> kd.index(ds, dim=ds.get_ndim() - 1)
797811
798812
Args:
799813
x: A DataSlice.
800-
dim: The dimension to compute indices over. Requires 0 <= dim < get_ndim(x).
801-
If unspecified, it is set to the last dimension of x.
814+
dim: The dimension to compute indices over. Requires abs(dim) < get_ndim(x).
815+
If dim < 0 then dim = get_ndim(x) + dim.
802816
"""
803817
x = assertion.with_assertion(
804818
x,
805819
get_ndim(x) != 0,
806820
'kde.slices.index: argument `x` must have non-zero rank',
807821
)
808822

809-
dim = M.core.default_if_unspecified(dim, get_ndim(x) - 1)
810-
823+
dim = normalize_dim(x, dim)
811824
ndim = get_ndim(x) - dim - 1
812825
ndim = assertion.with_assertion(
813826
ndim,

py/koladata/operators/tests/slices_index_test.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,12 @@
4444
QTYPES = frozenset([
4545
(DATA_SLICE, DATA_SLICE),
4646
(DATA_SLICE, DATA_SLICE, DATA_SLICE),
47-
(DATA_SLICE, arolla.UNSPECIFIED, DATA_SLICE),
4847
])
4948

5049

5150
class SlicesIndexTest(parameterized.TestCase):
5251

5352
@parameterized.parameters(
54-
(ds([5, 6, None, 7]), arolla.unspecified(), ds([0, 1, None, 3], INT64)),
5553
(
5654
ds([[5, None, 5], [6, 7], [None, None]]),
5755
ds(0),
@@ -62,6 +60,11 @@ class SlicesIndexTest(parameterized.TestCase):
6260
ds(1),
6361
ds([[0, None, 2], [0, 1], [None, None]], INT64),
6462
),
63+
(
64+
ds([[5, None, 5], [6, 7], [None, None]]),
65+
ds(-1),
66+
ds([[0, None, 2], [0, 1], [None, None]], INT64),
67+
),
6568
(
6669
ds([[['a', 'b', 'c'], ['d', 'e']], [['f', 'g'], ['h', 'i', 'j']]]),
6770
ds(0),
@@ -72,11 +75,21 @@ class SlicesIndexTest(parameterized.TestCase):
7275
ds(1),
7376
ds([[[0, 0, 0], [1, 1]], [[0, 0], [1, 1, 1]]], INT64),
7477
),
78+
(
79+
ds([[['a', 'b', 'c'], ['d', 'e']], [['f', 'g'], ['h', 'i', 'j']]]),
80+
ds(-2),
81+
ds([[[0, 0, 0], [1, 1]], [[0, 0], [1, 1, 1]]], INT64),
82+
),
7583
(
7684
ds([[['a', 'b', 'c'], ['d', 'e']], [['f', 'g'], ['h', 'i', 'j']]]),
7785
ds(2),
7886
ds([[[0, 1, 2], [0, 1]], [[0, 1], [0, 1, 2]]], INT64),
7987
),
88+
(
89+
ds([[['a', 'b', 'c'], ['d', 'e']], [['f', 'g'], ['h', 'i', 'j']]]),
90+
ds(-1),
91+
ds([[[0, 1, 2], [0, 1]], [[0, 1], [0, 1, 2]]], INT64),
92+
),
8093
)
8194
def test_eval(self, x, dim, expected_value):
8295
actual_value = expr_eval.eval(kde.slices.index(x, dim))
@@ -114,7 +127,7 @@ def test_data_item_input_error(self):
114127
):
115128
expr_eval.eval(kde.slices.index(x))
116129

117-
@parameterized.parameters(-1, 2)
130+
@parameterized.parameters(2, -2)
118131
def test_out_of_bounds_ndim_error(self, ndim):
119132
x = data_slice.DataSlice.from_vals([1, 2, 3])
120133
with self.assertRaisesRegex(

py/koladata/operators/tests/slices_range_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_eval(self, *args_and_expected):
9595
def test_qtype_signatures(self):
9696
self.assertCountEqual(
9797
arolla.testing.detect_qtype_signatures(
98-
kde.slices.index,
98+
kde.slices.range,
9999
possible_qtypes=test_qtypes.DETECT_SIGNATURES_QTYPES,
100100
),
101101
QTYPES,

0 commit comments

Comments
 (0)