Skip to content

Commit afae0b0

Browse files
apronchenkov authored and copybara-github committed
Minor refactoring in the aux-policies in Koladata
* Renamed the "default" classic policy to "classic" for better clarity
* Moved policy declarations from py/koladata/types to py/koladata/operators
* Updated registration logic (particularly, use pybind11 instead of python c api)

PiperOrigin-RevId: 859016491
Change-Id: Ia6d25d89d1c3bdd5e9a4b915d125594ed77a0e59
1 parent 361a2f8 commit afae0b0

21 files changed

+373
-298
lines changed

docs/api_reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4864,7 +4864,7 @@ Returns:
48644864
A decorator that registers an overload for the operator with the
48654865
corresponding name.</code></pre>
48664866

4867-
### `kd.optools.add_to_registry_as_overloadable(name: str, *, unsafe_override: bool = False, view: type[ExprView] | None = <class 'koladata.expr.view.KodaView'>, repr_fn: Union[Callable[[Expr, NodeTokenView], ReprToken], None] = None, aux_policy: str = 'koladata_default_boxing', via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry_as_overloadable}
4867+
### `kd.optools.add_to_registry_as_overloadable(name: str, *, unsafe_override: bool = False, view: type[ExprView] | None = <class 'koladata.expr.view.KodaView'>, repr_fn: Union[Callable[[Expr, NodeTokenView], ReprToken], None] = None, aux_policy: str = 'koladata_classic_aux_policy', via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry_as_overloadable}
48684868

48694869
<pre class="no-copy"><code class="lang-text no-auto-prettify">Koda wrapper around Arolla&#39;s add_to_registry_as_overloadable.
48704870

koladata/expr/expr_operators.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ ToArollaInt64Operator::ToArollaInt64Operator()
161161
: ToArollaValueOperator(
162162
"koda_internal.to_arolla_int64",
163163
arolla::expr::ExprOperatorSignature({{"x"}},
164-
"koladata_default_boxing"),
164+
"koladata_classic_aux_policy"),
165165
"Returns `x` converted into an arolla int64 value.\n"
166166
"\n"
167167
"Note that `x` must adhere to the following requirements:\n"
@@ -180,7 +180,7 @@ ToArollaTextOperator::ToArollaTextOperator()
180180
: ToArollaValueOperator(
181181
"koda_internal.to_arolla_text",
182182
arolla::expr::ExprOperatorSignature({{"x"}},
183-
"koladata_default_boxing"),
183+
"koladata_classic_aux_policy"),
184184
"Returns `x` converted into an arolla text value.\n"
185185
"\n"
186186
"Note that `x` must adhere to the following requirements:\n"

py/koladata/operators/BUILD

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# Arolla Expr Operators on Koda abstractions.
1616

1717
load("@com_google_arolla//py/arolla/dynamic_deps:build_defs.bzl", "arolla_py_cc_deps")
18+
load("@rules_cc//cc:cc_library.bzl", "cc_library")
1819
load("@rules_cc//cc:cc_test.bzl", "cc_test")
1920
load("@rules_python//python:defs.bzl", "py_library", "py_test")
20-
load("//py/koladata/dynamic_deps:py_extension.bzl", "koladata_py_extension")
21+
load("//py/koladata/dynamic_deps:py_extension.bzl", "koladata_pybind_extension")
2122
load("//py/koladata/operators:optools.bzl", "koladata_cc_operator_package")
2223

2324
package(default_visibility = ["//visibility:private"])
@@ -164,13 +165,21 @@ py_library(
164165
],
165166
)
166167

167-
koladata_py_extension(
168-
name = "py_optools_py_ext",
169-
srcs = [
170-
"py_optools_module.cc",
171-
"py_unified_binding_policy.cc",
172-
"py_unified_binding_policy.h",
168+
koladata_pybind_extension(
169+
name = "clib",
170+
srcs = ["clib.cc"],
171+
deps = [
172+
":py_unified_binding_policy",
173+
"//koladata/expr:non_determinism",
174+
"@com_google_absl//absl/strings:string_view",
175+
"@pybind11_abseil//pybind11_abseil:absl_casters",
173176
],
177+
)
178+
179+
cc_library(
180+
name = "py_unified_binding_policy",
181+
srcs = ["py_unified_binding_policy.cc"],
182+
hdrs = ["py_unified_binding_policy.h"],
174183
deps = [
175184
"//koladata/expr:expr_operators",
176185
"//koladata/expr:non_determinism",
@@ -179,7 +188,6 @@ koladata_py_extension(
179188
"@com_google_absl//absl/container:flat_hash_map",
180189
"@com_google_absl//absl/container:inlined_vector",
181190
"@com_google_absl//absl/log:check",
182-
"@com_google_absl//absl/random",
183191
"@com_google_absl//absl/strings",
184192
"@com_google_absl//absl/strings:str_format",
185193
"@com_google_absl//absl/types:span",
@@ -189,15 +197,16 @@ koladata_py_extension(
189197
"@com_google_arolla//arolla/util:status_backport",
190198
"@com_google_arolla//py/arolla/abc:py_abc",
191199
"@com_google_arolla//py/arolla/py_utils",
192-
"@rules_python//python/cc:current_py_cc_headers", # buildcleaner: keep
200+
"@rules_python//python/cc:current_py_cc_headers",
193201
],
194202
)
195203

196204
py_library(
197205
name = "unified_binding_policy",
198206
srcs = ["unified_binding_policy.py"],
199207
deps = [
200-
":py_optools_py_ext",
208+
":aux_policies",
209+
":clib",
201210
"//py:python_path",
202211
"//py/koladata/types:py_boxing",
203212
"@com_google_arolla//py/arolla",
@@ -208,6 +217,7 @@ py_test(
208217
name = "unified_binding_policy_test",
209218
srcs = ["unified_binding_policy_test.py"],
210219
deps = [
220+
":aux_policies",
211221
":op_repr",
212222
":unified_binding_policy",
213223
"//py:python_path",
@@ -221,11 +231,29 @@ py_test(
221231
],
222232
)
223233

234+
py_test(
235+
name = "classic_binding_policy_test",
236+
srcs = ["classic_binding_policy_test.py"],
237+
deps = [
238+
":aux_policies",
239+
":kde_operators",
240+
"//py:python_path",
241+
"//py/koladata/expr:input_container",
242+
"//py/koladata/testing",
243+
"//py/koladata/types:data_slice",
244+
"//py/koladata/types:ellipsis",
245+
"//py/koladata/types:literal_operator",
246+
"@com_google_absl_py//absl/testing:absltest",
247+
"@com_google_arolla//py/arolla",
248+
],
249+
)
250+
224251
py_library(
225252
name = "optools",
226253
srcs = ["optools.py"],
227254
visibility = ["//koladata:internal"],
228255
deps = [
256+
":aux_policies",
229257
":op_repr",
230258
":qtype_utils",
231259
":unified_binding_policy",
@@ -407,6 +435,7 @@ py_library(
407435
deps = [
408436
":arolla_bridge",
409437
":assertion",
438+
":aux_policies",
410439
":jagged_shape",
411440
":masking",
412441
":optools",
@@ -429,11 +458,11 @@ py_library(
429458
srcs = ["tuple.py"],
430459
deps = [
431460
":arolla_bridge",
461+
":aux_policies",
432462
":optools",
433463
":qtype_utils",
434464
":view_overloads",
435465
"//py:python_path",
436-
"//py/koladata/types:py_boxing",
437466
"//py/koladata/types:schema_constants",
438467
"@com_google_arolla//py/arolla",
439468
"@com_google_arolla//py/arolla/jagged_shape",
@@ -446,6 +475,7 @@ py_library(
446475
deps = [
447476
":arolla_bridge",
448477
":assertion",
478+
":aux_policies",
449479
":bags",
450480
":entities",
451481
":masking",
@@ -605,6 +635,7 @@ py_library(
605635
deps = [
606636
":arolla_bridge",
607637
":assertion",
638+
":aux_policies",
608639
":comparison",
609640
":jagged_shape",
610641
":masking",
@@ -667,6 +698,7 @@ py_library(
667698
name = "op_repr",
668699
srcs = ["op_repr.py"],
669700
deps = [
701+
":aux_policies",
670702
":unified_binding_policy",
671703
"//py:python_path",
672704
"//py/koladata/types:data_slice",
@@ -795,11 +827,11 @@ py_library(
795827
name = "bags",
796828
srcs = ["bags.py"],
797829
deps = [
830+
":aux_policies",
798831
":optools",
799832
":qtype_utils",
800833
":view_overloads",
801834
"//py:python_path",
802-
"//py/koladata/types:py_boxing",
803835
"//py/koladata/types:qtypes",
804836
"//py/koladata/types:schema_constants",
805837
"@com_google_arolla//py/arolla",
@@ -932,6 +964,17 @@ py_library(
932964
],
933965
)
934966

967+
py_library(
968+
name = "aux_policies",
969+
srcs = ["aux_policies.py"],
970+
deps = [
971+
":clib",
972+
"//py:python_path",
973+
"//py/koladata/types:py_boxing",
974+
"@com_google_arolla//py/arolla",
975+
],
976+
)
977+
935978
py_test(
936979
name = "eager_op_utils_test",
937980
srcs = ["eager_op_utils_test.py"],
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Module declaring auxiliary policies for Koladata."""
16+
17+
from arolla import arolla
18+
from koladata.operators import clib
19+
from koladata.types import py_boxing
20+
21+
# The default auxiliary policy for Koladata operators.
22+
#
23+
# This policy implements binding rules that support positional-only,
24+
# keyword-only, variadic positional, and variadic keyword parameters.
25+
# It also supports per-parameter boxing control (defaulting to Koladata-specific
26+
# rules) and allows expressing non-deterministic operator behaviour.
27+
#
28+
# NOTE: Non-variadic parameters are effectively compatible with classic
29+
# Arolla binding rules and are represented as positional-or-keyword parameters.
30+
# For example, given:
31+
#
32+
# def op(a, *, b, c=default): ...
33+
#
34+
# op(a, b=b)
35+
#
36+
# is effectively equivalent to
37+
#
38+
# arolla.abc.bind_op(op, a, b=b)
39+
#
40+
# However, variadic parameters are represented differently: positional
41+
# variadic arguments are packed into a tuple, and keyword variadic
42+
# arguments are packed into a namedtuple. For example, given:
43+
#
44+
# def op(*args, **kwargs): ...
45+
#
46+
# op(1, 2, x=3, y=4)
47+
#
48+
# is equivalent to:
49+
#
50+
# arolla.abc.bind_op(op, args=arolla.tuple(1, 2),
51+
# kwargs=arolla.namedtuple(x=3, y=4))
52+
#
53+
# The binding rules implementation can be found in:
54+
# //py/koladata/operators/py_unified_binding_policy.cc
55+
UNIFIED_AUX_POLICY = 'koladata_unified_aux_policy'
56+
57+
58+
# The classic Arolla binding rules with Koladata-specific boxing: it wraps most
59+
# values into DataItems (excluding lists and tuples). QValues remain unchanged.
60+
CLASSIC_AUX_POLICY = 'koladata_classic_aux_policy'
61+
62+
63+
clib.register_unified_aux_binding_policy(UNIFIED_AUX_POLICY)
64+
65+
arolla.abc.register_classic_aux_binding_policy_with_custom_boxing(
66+
CLASSIC_AUX_POLICY,
67+
py_boxing.as_qvalue_or_expr,
68+
make_literal_fn=py_boxing.literal,
69+
)

py/koladata/operators/bags.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
from arolla import arolla as _arolla
1818
from arolla.jagged_shape import jagged_shape as _jagged_shape
19+
from koladata.operators import aux_policies as _aux_policies
1920
from koladata.operators import optools as _optools
2021
from koladata.operators import qtype_utils as _qtype_utils
2122
from koladata.operators import view_overloads as _
22-
from koladata.types import py_boxing as _py_boxing
2323
from koladata.types import qtypes as _qtypes
2424
from koladata.types import schema_constants as _schema_constants
2525

@@ -62,7 +62,7 @@ def is_null_bag(bag): # pylint: disable=unused-argument
6262
_qtype_utils.expect_data_bag_args(_P.bags),
6363
],
6464
qtype_inference_expr=_qtypes.DATA_BAG,
65-
experimental_aux_policy=_py_boxing.DEFAULT_BOXING_POLICY,
65+
experimental_aux_policy=_aux_policies.CLASSIC_AUX_POLICY,
6666
)
6767
def enriched(*bags): # pylint: disable=unused-argument
6868
"""Creates a new immutable DataBag enriched by `bags`.
@@ -93,7 +93,7 @@ def enriched(*bags): # pylint: disable=unused-argument
9393
_qtype_utils.expect_data_bag_args(_P.bags),
9494
],
9595
qtype_inference_expr=_qtypes.DATA_BAG,
96-
experimental_aux_policy=_py_boxing.DEFAULT_BOXING_POLICY,
96+
experimental_aux_policy=_aux_policies.CLASSIC_AUX_POLICY,
9797
)
9898
def updated(*bags): # pylint: disable=unused-argument
9999
"""Creates a new immutable DataBag updated by `bags`.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import re
16+
17+
from absl.testing import absltest
18+
from arolla import arolla
19+
from koladata.expr import input_container
20+
from koladata.operators import aux_policies
21+
from koladata.operators import kde_operators
22+
from koladata.testing import testing
23+
from koladata.types import data_slice
24+
from koladata.types import ellipsis
25+
from koladata.types import literal_operator
26+
27+
28+
ds = data_slice.DataSlice.from_vals
29+
I = input_container.InputContainer('I')
30+
kde = kde_operators.kde
31+
32+
33+
@arolla.optools.as_lambda_operator(
34+
'op_with_classic_binding_policy',
35+
experimental_aux_policy=aux_policies.CLASSIC_AUX_POLICY,
36+
)
37+
def op_with_classic_binding_policy(x, y):
38+
return (x, y)
39+
40+
41+
class ClassicBindingPolicyTest(absltest.TestCase):
42+
43+
def test_basic(self):
44+
expr = op_with_classic_binding_policy(1, 2)
45+
testing.assert_equal(
46+
expr,
47+
arolla.abc.bind_op(
48+
op_with_classic_binding_policy,
49+
literal_operator.literal(ds(1)),
50+
literal_operator.literal(ds(2)),
51+
),
52+
)
53+
54+
def test_with_slice(self):
55+
expr = op_with_classic_binding_policy(1, slice(1, None, 2))
56+
testing.assert_equal(
57+
expr,
58+
arolla.abc.bind_op(
59+
op_with_classic_binding_policy,
60+
literal_operator.literal(ds(1)),
61+
literal_operator.literal(arolla.types.Slice(ds(1), None, ds(2))),
62+
),
63+
)
64+
65+
def test_with_ellipsis(self):
66+
expr = op_with_classic_binding_policy(1, ...)
67+
testing.assert_equal(
68+
expr,
69+
arolla.abc.bind_op(
70+
op_with_classic_binding_policy,
71+
literal_operator.literal(ds(1)),
72+
literal_operator.literal(ellipsis.ellipsis()),
73+
),
74+
)
75+
76+
def test_missing_argument(self):
77+
with self.assertRaisesWithLiteralMatch(
78+
TypeError, "missing 1 required positional argument: 'y'"
79+
):
80+
op_with_classic_binding_policy(1)
81+
82+
def test_list_unsupported(self):
83+
with self.assertRaisesRegex(ValueError, re.escape('list')):
84+
op_with_classic_binding_policy(1, [2, 3, 4])
85+
86+
87+
if __name__ == '__main__':
88+
absltest.main()

0 commit comments

Comments
 (0)