Skip to content

Commit dadaa9e

Browse files
petrmitrichev authored and copybara-github committed
Add support for kd.slice([kd.obj(...)]) in tracing mode, and surrounding fixes.
The design is roughly similar to kde.fstr. This adds some complexity, but kd.slice and kd.item are very frequently used, and I feel it is justified to support this.

PiperOrigin-RevId: 714905688
Change-Id: I438cf9f154401ec3c90638b52c751f4d26752a48
1 parent 8912222 commit dadaa9e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2278
-135
lines changed

py/koladata/functions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ py_library(
3535
":schema",
3636
"//py/koladata/fstring",
3737
"//py/koladata/types:data_bag",
38+
"//py/koladata/types:data_item",
3839
"//py/koladata/types:data_slice",
3940
"//py/koladata/types:general_eager_ops",
4041
],

py/koladata/functions/functions.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from koladata.functions import s11n as _s11n
2727
from koladata.functions import schema as _schema
2828
from koladata.types import data_bag as _data_bag
29+
from koladata.types import data_item as _data_item
2930
from koladata.types import data_slice as _data_slice
3031
from koladata.types import general_eager_ops as _general_eager_ops
3132

@@ -159,6 +160,14 @@ def new_schema(
159160

160161
py_reference = _py_conversions.py_reference
161162

163+
from_proto = _proto_conversions.from_proto
164+
to_proto = _proto_conversions.to_proto
165+
166+
dumps = _s11n.dumps
167+
loads = _s11n.loads
168+
169+
slice = _data_slice.DataSlice.from_vals # pylint: disable=redefined-builtin
170+
item = _data_item.DataItem.from_vals
162171
int32 = _py_conversions.int32
163172
int64 = _py_conversions.int64
164173
float32 = _py_conversions.float32
@@ -169,9 +178,20 @@ def new_schema(
169178
mask = _py_conversions.mask
170179
expr_quote = _py_conversions.expr_quote
171180

172-
173-
from_proto = _proto_conversions.from_proto
174-
to_proto = _proto_conversions.to_proto
175-
176-
dumps = _s11n.dumps
177-
loads = _s11n.loads
181+
slices = _py_types.SimpleNamespace(
182+
# We use the top-level functions for `slice` and `item` instead of taking
183+
# DataSlice.from_vals again here since for functions implemented in C++
184+
# taking them twice from their class results in different pointers,
185+
# so "assertIs" test for aliases fails otherwise.
186+
slice=slice,
187+
item=item,
188+
int32=_py_conversions.int32,
189+
int64=_py_conversions.int64,
190+
float32=_py_conversions.float32,
191+
float64=_py_conversions.float64,
192+
str=_py_conversions.str_,
193+
bytes=_py_conversions.bytes_,
194+
bool=_py_conversions.bool_,
195+
mask=_py_conversions.mask,
196+
expr_quote=_py_conversions.expr_quote,
197+
)

py/koladata/functions/tests/BUILD

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,37 @@ py_test(
850850
],
851851
)
852852

853+
py_test(
854+
name = "slice_test",
855+
srcs = ["slice_test.py"],
856+
deps = [
857+
"//py/koladata/functions",
858+
"//py/koladata/operators:kde_operators",
859+
"//py/koladata/testing",
860+
"//py/koladata/types:data_slice",
861+
"//py/koladata/types:mask_constants",
862+
"//py/koladata/types:schema_constants",
863+
"@com_google_absl_py//absl/testing:absltest",
864+
"@com_google_absl_py//absl/testing:parameterized",
865+
],
866+
)
867+
868+
py_test(
869+
name = "item_test",
870+
srcs = ["item_test.py"],
871+
deps = [
872+
"//py/koladata/functions",
873+
"//py/koladata/operators:kde_operators",
874+
"//py/koladata/testing",
875+
"//py/koladata/types:data_slice",
876+
"//py/koladata/types:mask_constants",
877+
"//py/koladata/types:schema_constants",
878+
"@com_google_absl_py//absl/testing:absltest",
879+
"@com_google_absl_py//absl/testing:parameterized",
880+
"@com_google_arolla//py/arolla",
881+
],
882+
)
883+
853884
proto_library(
854885
name = "test_proto",
855886
srcs = ["test.proto"],

py/koladata/functions/tests/bool_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def test_bool_errors(self, x, expected_error_msg):
4141
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
4242
fns.bool(x)
4343

44+
def test_alias(self):
45+
self.assertIs(fns.bool, fns.slices.bool)
46+
4447

4548
if __name__ == '__main__':
4649
absltest.main()

py/koladata/functions/tests/bytes_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def test_bytes_errors(self, x, expected_error_msg):
3939
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
4040
fns.bytes(x)
4141

42+
def test_alias(self):
43+
self.assertIs(fns.bytes, fns.slices.bytes)
44+
4245

4346
if __name__ == '__main__':
4447
absltest.main()

py/koladata/functions/tests/expr_quote_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def test_expr_quote_errors(self, x, expected_error_msg):
4444
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
4545
fns.expr_quote(x)
4646

47+
def test_alias(self):
48+
self.assertIs(fns.expr_quote, fns.slices.expr_quote)
49+
4750

4851
if __name__ == '__main__':
4952
absltest.main()

py/koladata/functions/tests/float32_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def test_float32_errors(self, x, expected_error_msg):
4949
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
5050
fns.float32(x)
5151

52+
def test_alias(self):
53+
self.assertIs(fns.float32, fns.slices.float32)
54+
5255

5356
if __name__ == '__main__':
5457
absltest.main()

py/koladata/functions/tests/float64_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def test_float64_errors(self, x, expected_error_msg):
4949
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
5050
fns.float64(x)
5151

52+
def test_alias(self):
53+
self.assertIs(fns.float64, fns.slices.float64)
54+
5255

5356
if __name__ == '__main__':
5457
absltest.main()

py/koladata/functions/tests/from_py_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def test_empty_slice(self):
394394
a=schema_constants.STRING, b=fns.list_schema(schema_constants.INT32)
395395
)
396396
res = fns.from_py([], from_dim=1, schema=schema)
397-
testing.assert_equal(res.no_bag(), ds([], schema))
397+
testing.assert_equal(res.no_bag(), ds([], schema.no_bag()))
398398

399399
def test_obj_reference(self):
400400
obj = fns.obj()

py/koladata/functions/tests/int32_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def test_int32_errors(self, x, expected_error_msg):
4444
with self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
4545
fns.int32(x)
4646

47+
def test_alias(self):
48+
self.assertIs(fns.int32, fns.slices.int32)
49+
4750

4851
if __name__ == '__main__':
4952
absltest.main()

0 commit comments

Comments
 (0)