Skip to content

Commit 2fd4d39

Browse files
apronchenkovcopybara-github
authored andcommitted
Use aux-policies for View assignment in Koladata
Motivation: This change enables the use of ad-hoc operators in Koladata without registration. Previously, unregistered operators defined using `@kd.optools.as_*_operator` defaulted to `arolla.expr.DefaultExprView`, which limited their usability. This change sets KodaView as the default (and allows for further customization). PiperOrigin-RevId: 862469628 Change-Id: Ib3ac34a4b5642e38edc4e26995e0c90ac2109027
1 parent e77838b commit 2fd4d39

23 files changed

+441
-358
lines changed

docs/api_reference.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4823,15 +4823,14 @@ Operator definition and registration tooling.
48234823

48244824
<pre class="no-copy"><code class="lang-text no-auto-prettify">Adds an alias for an operator.</code></pre>
48254825

4826-
### `kd.optools.add_to_registry(name: str | None = None, *, aliases: Collection[str] = (), unsafe_override: bool = False, view: type[ExprView] = <class 'koladata.expr.view.KodaView'>, repr_fn: Callable[[Expr, NodeTokenView], ReprToken] = <default_op_repr>, via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry}
4826+
### `kd.optools.add_to_registry(name: str | None = None, *, aliases: Collection[str] = (), unsafe_override: bool = False, repr_fn: Callable[[Expr, NodeTokenView], ReprToken] = <default_op_repr>, via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry}
48274827

48284828
<pre class="no-copy"><code class="lang-text no-auto-prettify">Wrapper around Arolla&#39;s add_to_registry with Koda functionality.
48294829

48304830
Args:
48314831
name: Optional name of the operator. Otherwise, inferred from the op.
48324832
aliases: Optional aliases for the operator.
48334833
unsafe_override: Whether to override an existing operator.
4834-
view: View for the operator and its aliases.
48354834
repr_fn: Repr function for the operator and its aliases.
48364835
via_cc_operator_package: If True, the operator will be only registered
48374836
during koladata_cc_operator_package construction, and just looked up in
@@ -4867,7 +4866,7 @@ Returns:
48674866
A decorator that registers an overload for the operator with the
48684867
corresponding name.</code></pre>
48694868

4870-
### `kd.optools.add_to_registry_as_overloadable(name: str, *, unsafe_override: bool = False, view: type[ExprView] = <class 'koladata.expr.view.KodaView'>, repr_fn: Callable[[Expr, NodeTokenView], ReprToken] = <default_op_repr>, aux_policy: str = 'koladata_classic_aux_policy', via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry_as_overloadable}
4869+
### `kd.optools.add_to_registry_as_overloadable(name: str, *, unsafe_override: bool = False, repr_fn: Callable[[Expr, NodeTokenView], ReprToken] = <default_op_repr>, aux_policy: str = 'koladata_classic_aux_policy', via_cc_operator_package: bool = False)` {#kd.optools.add_to_registry_as_overloadable}
48714870

48724871
<pre class="no-copy"><code class="lang-text no-auto-prettify">Koda wrapper around Arolla&#39;s add_to_registry_as_overloadable.
48734872

@@ -4877,7 +4876,6 @@ repr function.
48774876
Args:
48784877
name: Name of the operator.
48794878
unsafe_override: Whether to override an existing operator.
4880-
view: View for the operator and its aliases.
48814879
repr_fn: Repr function for the operator and its aliases.
48824880
aux_policy: Aux policy for the operator.
48834881
via_cc_operator_package: If True, the operator will be only registered
@@ -4890,7 +4888,7 @@ Args:
48904888
Returns:
48914889
An overloadable registered operator.</code></pre>
48924890

4893-
### `kd.optools.as_backend_operator(name: str, *, qtype_inference_expr: Expr | QType = DATA_SLICE, qtype_constraints: Iterable[tuple[Expr, str]] = (), deterministic: bool = True, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None) -> Callable[[function], BackendOperator]` {#kd.optools.as_backend_operator}
4891+
### `kd.optools.as_backend_operator(name: str, *, qtype_inference_expr: Expr | QType = DATA_SLICE, qtype_constraints: Iterable[tuple[Expr, str]] = (), deterministic: bool = True, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None, view: str | type[ExprView] = '') -> Callable[[function], BackendOperator]` {#kd.optools.as_backend_operator}
48944892

48954893
<pre class="no-copy"><code class="lang-text no-auto-prettify">Decorator for Koladata backend operators with a unified binding policy.
48964894

@@ -4911,12 +4909,14 @@ Args:
49114909
custom_boxing_fn_name_per_parameter: A dictionary specifying a custom boxing
49124910
function per parameter (constants with the boxing functions look like:
49134911
`koladata.types.py_boxing.WITH_*`, e.g. `WITH_PY_FUNCTION_TO_PY_OBJECT`).
4912+
view: The view for the for the operator, with the default being KodaView
4913+
(supported values: &#39;&#39;|KodaView, &#39;base&#39;|BaseKodaView, &#39;arolla&#39;|ArollaView).
49144914

49154915
Returns:
49164916
A decorator that constructs a backend operator based on the provided Python
49174917
function signature.</code></pre>
49184918

4919-
### `kd.optools.as_lambda_operator(name: str, *, qtype_constraints: Iterable[tuple[Expr, str]] = (), deterministic: bool | None = None, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None, suppress_unused_parameter_warning: bool = False) -> Callable[[function], LambdaOperator | RestrictedLambdaOperator]` {#kd.optools.as_lambda_operator}
4919+
### `kd.optools.as_lambda_operator(name: str, *, qtype_constraints: Iterable[tuple[Expr, str]] = (), deterministic: bool | None = None, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None, suppress_unused_parameter_warning: bool = False, view: str | type[ExprView] = '') -> Callable[[function], LambdaOperator | RestrictedLambdaOperator]` {#kd.optools.as_lambda_operator}
49204920

49214921
<pre class="no-copy"><code class="lang-text no-auto-prettify">Decorator for Koladata lambda operators with a unified binding policy.
49224922

@@ -4937,11 +4937,13 @@ Args:
49374937
`koladata.types.py_boxing.WITH_*`, e.g. `WITH_PY_FUNCTION_TO_PY_OBJECT`).
49384938
suppress_unused_parameter_warning: If True, unused parameters will not cause
49394939
a warning.
4940+
view: The view for the for the operator, with the default being KodaView
4941+
(supported values: &#39;&#39;|KodaView, &#39;base&#39;|BaseKodaView, &#39;arolla&#39;|ArollaView).
49404942

49414943
Returns:
49424944
A decorator that constructs a lambda operator by tracing a Python function.</code></pre>
49434945

4944-
### `kd.optools.as_py_function_operator(name: str, *, qtype_inference_expr: Expr | QType = DATA_SLICE, qtype_constraints: Iterable[tuple[Expr, str]] = (), codec: bytes | None = None, deterministic: bool = True, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None) -> Callable[[function], Operator]` {#kd.optools.as_py_function_operator}
4946+
### `kd.optools.as_py_function_operator(name: str, *, qtype_inference_expr: Expr | QType = DATA_SLICE, qtype_constraints: Iterable[tuple[Expr, str]] = (), codec: bytes | None = None, deterministic: bool = True, custom_boxing_fn_name_per_parameter: dict[str, str] | None = None, view: str | type[ExprView] = '') -> Callable[[function], Operator]` {#kd.optools.as_py_function_operator}
49454947

49464948
<pre class="no-copy"><code class="lang-text no-auto-prettify">Returns a decorator for defining Koladata-specific py-function operators.
49474949

@@ -4965,7 +4967,9 @@ Args:
49654967
(i.e., non-deterministic or has side effects).
49664968
custom_boxing_fn_name_per_parameter: A dictionary specifying a custom boxing
49674969
function per parameter (constants with the boxing functions look like:
4968-
`koladata.types.py_boxing.WITH_*`, e.g. `WITH_PY_FUNCTION_TO_PY_OBJECT`).</code></pre>
4970+
`koladata.types.py_boxing.WITH_*`, e.g. `WITH_PY_FUNCTION_TO_PY_OBJECT`).
4971+
view: The view for the for the operator, with the default being KodaView
4972+
(supported values: &#39;&#39;|KodaView, &#39;base&#39;|BaseKodaView, &#39;arolla&#39;|ArollaView).</code></pre>
49694973

49704974
### `kd.optools.as_qvalue(arg: Any) -> QValue` {#kd.optools.as_qvalue}
49714975

koladata/expr/expr_operators.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ class ToArollaValueOperator final
8080
arolla::QTypePtr output_qtype)
8181
: ExprOperatorWithFixedSignature(
8282
name,
83-
arolla::expr::ExprOperatorSignature({{"x"}},
84-
"koladata_classic_aux_policy"),
83+
arolla::expr::ExprOperatorSignature(
84+
// Use aux-policy with ArollaView.
85+
{{"x"}}, "koladata_unified_aux_policy$arolla:_"),
8586
doc,
8687
arolla::FingerprintHasher("::koladata::expr::ToArollaValueOperator")
8788
.Combine(name, doc, output_qtype)
@@ -116,8 +117,9 @@ class ToArollaValueOperator final
116117
InputOperator::InputOperator()
117118
: arolla::expr::ExprOperatorWithFixedSignature(
118119
kInternalInput,
119-
arolla::expr::ExprOperatorSignature{{"container_name"},
120-
{"input_key"}},
120+
arolla::expr::ExprOperatorSignature(
121+
{{"container_name"}, {"input_key"}},
122+
"koladata_arolla_classic_aux_policy"),
121123
"Koda input with DATA_SLICE qtype.\n"
122124
"\n"
123125
"Note that this operator cannot be evaluated.\n"
@@ -161,7 +163,10 @@ absl::StatusOr<std::shared_ptr<LiteralOperator>> LiteralOperator::Make(
161163
LiteralOperator::LiteralOperator(PrivateConstructorTag,
162164
arolla::TypedValue value)
163165
: arolla::expr::ExprOperatorWithFixedSignature(
164-
"koda_internal.literal", arolla::expr::ExprOperatorSignature{},
166+
"koda_internal.literal",
167+
arolla::expr::ExprOperatorSignature(
168+
// Use aux-policy with BaseKodaView.
169+
{}, "koladata_unified_aux_policy$base"),
165170
"Koda literal.",
166171
arolla::FingerprintHasher("::koladata::expr::LiteralOperator")
167172
.Combine(value.GetFingerprint())

koladata/expr/init_expr_operators.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ AROLLA_INITIALIZER(
4343
arolla::expr::Literal(internal::Ellipsis{})))
4444
.status());
4545
RETURN_IF_ERROR(
46-
arolla::expr::RegisterOperator("koda_internal.with_name",
46+
arolla::expr::RegisterOperator("kd.annotation.with_name",
4747
MakeNameAnnotationOperator())
4848
.status());
4949
RETURN_IF_ERROR(

py/koladata/expr/view.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,12 @@ def uuobj(self, *args, **kwargs): # pylint: disable=unused-argument
753753
_raise_eager_only_method('uuobj', 'DataBag')
754754

755755

756+
class ArollaView(BaseKodaView, arolla.expr.DefaultExprView):
757+
"""Arolla view with a few extra methods specific to Koda."""
758+
759+
_arolla_view_tag = True
760+
761+
756762
def has_koda_view(node: arolla.Expr) -> bool:
757763
"""Returns true iff the node has the full koda view."""
758764
return hasattr(node, '_koda_view_tag')
@@ -763,21 +769,11 @@ def has_base_koda_view(node: arolla.Expr) -> bool:
763769
return hasattr(node, '_base_koda_view_tag')
764770

765771

766-
class ArollaView(BaseKodaView, arolla.expr.DefaultExprView):
767-
"""Arolla view with a few extra methods specific to Koda."""
768-
769-
pass
772+
def has_arolla_view(node: arolla.Expr) -> bool:
773+
"""Returns true iff the node has the arolla view."""
774+
return hasattr(node, '_arolla_view_tag')
770775

771776

772777
arolla.abc.set_expr_view_for_qtype(qtypes.DATA_SLICE, KodaView)
773778
arolla.abc.set_expr_view_for_qtype(qtypes.DATA_BAG, KodaView)
774779
arolla.abc.set_expr_view_for_qtype(qtypes.JAGGED_SHAPE, KodaView)
775-
arolla.abc.set_expr_view_for_registered_operator(
776-
'koda_internal.input', KodaView
777-
)
778-
779-
# NOTE: This attaches a BaseKodaView to all literals, including Arolla values.
780-
# This adds generic Koda methods (such as eval) to the expression.
781-
arolla.abc.set_expr_view_for_operator_family(
782-
'::koladata::expr::LiteralOperator', ArollaView
783-
)

py/koladata/expr/view_test.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,11 @@ class BaseKodaViewTest(parameterized.TestCase):
4545

4646
def setUp(self):
4747
super().setUp()
48-
arolla.abc.set_expr_view_for_registered_operator(
49-
'test.op', view.BaseKodaView
50-
)
48+
arolla.abc.set_expr_view_for_registered_operator(op, view.BaseKodaView)
5149

5250
def tearDown(self):
5351
# Clear the view.
54-
arolla.abc.set_expr_view_for_registered_operator('test.op', None)
52+
arolla.abc.set_expr_view_for_registered_operator(op, None)
5553
super().tearDown()
5654

5755
def test_eval(self):
@@ -83,11 +81,11 @@ class KodaViewTest(parameterized.TestCase):
8381

8482
def setUp(self):
8583
super().setUp()
86-
arolla.abc.set_expr_view_for_registered_operator('test.op', view.KodaView)
84+
arolla.abc.set_expr_view_for_registered_operator(op, view.KodaView)
8785

8886
def tearDown(self):
8987
# Clear the view.
90-
arolla.abc.set_expr_view_for_registered_operator('test.op', None)
88+
arolla.abc.set_expr_view_for_registered_operator(op, None)
9189
super().tearDown()
9290

9391
# To be overridden by KodaViewWithTracingTest subclass below.
@@ -646,7 +644,7 @@ def test_make_tuple(*args):
646644
return arolla.optools.fix_trace_args(args)
647645

648646
arolla.abc.set_expr_view_for_registered_operator(
649-
'test_make_tuple', view.KodaView
647+
test_make_tuple, view.KodaView
650648
)
651649

652650
expr = arolla.M.annotation.name(
@@ -896,6 +894,28 @@ def test_annotation_source_location_view(self):
896894
)
897895

898896

897+
class ArollaViewTest(parameterized.TestCase):
898+
899+
def setUp(self):
900+
super().setUp()
901+
arolla.abc.set_expr_view_for_registered_operator(op, view.ArollaView)
902+
903+
def tearDown(self):
904+
arolla.abc.set_expr_view_for_registered_operator(op, None)
905+
super().tearDown()
906+
907+
def test_has_arolla_view(self):
908+
self.assertTrue(view.has_arolla_view(op()))
909+
self.assertTrue(view.has_base_koda_view(op()))
910+
self.assertFalse(view.has_koda_view(op()))
911+
912+
def test_less_operator(self):
913+
arolla.testing.assert_expr_equal_by_fingerprint(
914+
op(C.x) < op(C.y),
915+
arolla.abc.unsafe_parse_sexpr((('core.less', (op, C.x), (op, C.y)))),
916+
)
917+
918+
899919
class KodaViewWithTracingTest(KodaViewTest):
900920

901921
def setUp(self):

py/koladata/operators/BUILD

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ py_library(
205205
name = "unified_binding_policy",
206206
srcs = ["unified_binding_policy.py"],
207207
deps = [
208-
":aux_policies",
209208
":clib",
210209
"//py:python_path",
211210
"//py/koladata/types:py_boxing",
@@ -218,6 +217,7 @@ py_test(
218217
srcs = ["unified_binding_policy_test.py"],
219218
deps = [
220219
":aux_policies",
220+
":clib",
221221
":op_repr",
222222
":unified_binding_policy",
223223
"//py:python_path",
@@ -260,7 +260,6 @@ py_library(
260260
"//py:python_path",
261261
"//py/koladata/expr:input_container",
262262
"//py/koladata/expr:tracing_mode",
263-
"//py/koladata/expr:view",
264263
"//py/koladata/types:py_boxing",
265264
"//py/koladata/types:qtypes",
266265
"@com_google_arolla//py/arolla",
@@ -271,12 +270,12 @@ py_test(
271270
name = "optools_test",
272271
srcs = ["optools_test.py"],
273272
deps = [
273+
":aux_policies",
274274
":comparison",
275275
":jagged_shape",
276276
":koda_internal",
277277
":math",
278278
":optools",
279-
":optools_test_utils",
280279
":qtype_utils",
281280
":tuple",
282281
"//py:python_path",
@@ -444,6 +443,7 @@ py_library(
444443
":slices",
445444
"//koladata/operators",
446445
"//py:python_path",
446+
"//py/koladata/expr:view",
447447
"//py/koladata/fstring",
448448
"//py/koladata/types:data_slice",
449449
"//py/koladata/types:py_boxing",
@@ -548,6 +548,7 @@ py_library(
548548
":qtype_utils",
549549
"//koladata/operators",
550550
"//py:python_path",
551+
"//py/koladata/expr:view",
551552
"//py/koladata/types:data_slice",
552553
"//py/koladata/types:py_boxing",
553554
"//py/koladata/types:qtypes",
@@ -647,6 +648,7 @@ py_library(
647648
":schema",
648649
":view_overloads",
649650
"//py:python_path",
651+
"//py/koladata/expr:view",
650652
"//py/koladata/types:data_item",
651653
"//py/koladata/types:data_slice",
652654
"//py/koladata/types:py_boxing",
@@ -799,6 +801,7 @@ py_library(
799801
":slices",
800802
":view_overloads",
801803
"//py:python_path",
804+
"//py/koladata/expr:view",
802805
"//py/koladata/types:data_slice",
803806
"//py/koladata/types:py_boxing",
804807
"//py/koladata/types:qtypes",
@@ -972,11 +975,24 @@ py_library(
972975
deps = [
973976
":clib",
974977
"//py:python_path",
978+
"//py/koladata/expr:view",
975979
"//py/koladata/types:py_boxing",
976980
"@com_google_arolla//py/arolla",
977981
],
978982
)
979983

984+
py_test(
985+
name = "aux_policies_test",
986+
srcs = ["aux_policies_test.py"],
987+
deps = [
988+
":aux_policies",
989+
"//py:python_path",
990+
"//py/koladata/expr:view",
991+
"@com_google_absl_py//absl/testing:absltest",
992+
"@com_google_absl_py//absl/testing:parameterized",
993+
],
994+
)
995+
980996
py_test(
981997
name = "eager_op_utils_test",
982998
srcs = ["eager_op_utils_test.py"],

py/koladata/operators/annotation.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@
1717
from arolla import arolla
1818
from koladata.operators import optools
1919

20-
2120
# NOTE: Implemented in C++.
2221
source_location = arolla.abc.lookup_operator('kd.annotation.source_location')
22+
with_name = arolla.abc.lookup_operator('kd.annotation.with_name')
2323

24-
25-
with_name = optools.add_to_registry(
26-
name='kd.annotation.with_name',
27-
aliases=['kd.with_name'],
28-
via_cc_operator_package=True,
29-
)(arolla.abc.lookup_operator('koda_internal.with_name'))
24+
optools.add_to_registry('kd.with_name', via_cc_operator_package=True)(with_name)

0 commit comments

Comments
 (0)