Skip to content

Commit e5bbf3a

Browse files
apronchenkovcopybara-github
authored andcommitted
Simplify operator implementations in expr_operators.cc.
Notably, instead of using a custom binding policy for the annotation.with_name operator, this changelist configures an existing binding policy to serve the same role. PiperOrigin-RevId: 860047013 Change-Id: Ib8b12da0241635251efdf662ee8ee2c7e1d622fd
1 parent b63bb85 commit e5bbf3a

File tree

6 files changed

+139
-171
lines changed

6 files changed

+139
-171
lines changed

koladata/expr/expr_operators.cc

Lines changed: 95 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "absl/strings/str_format.h"
2727
#include "absl/strings/string_view.h"
2828
#include "absl/types/span.h"
29+
#include "arolla/expr/annotation_expr_operators.h"
2930
#include "arolla/expr/basic_expr_operator.h"
3031
#include "arolla/expr/expr.h"
3132
#include "arolla/expr/expr_attributes.h"
@@ -68,6 +69,48 @@ absl::Status ValidateTextLiteral(const arolla::expr::ExprAttributes& attr,
6869
return absl::OkStatus();
6970
}
7071

72+
// Basic implementation of non-lowerable operators that convert a DataSlice to
73+
// an Arolla type T. Supports compile-time evaluation if the provided input is
74+
// a literal.
75+
class ToArollaValueOperator final
76+
: public arolla::expr::BackendExprOperatorTag,
77+
public arolla::expr::ExprOperatorWithFixedSignature {
78+
public:
79+
ToArollaValueOperator(absl::string_view name, absl::string_view doc,
80+
arolla::QTypePtr output_qtype)
81+
: ExprOperatorWithFixedSignature(
82+
name,
83+
arolla::expr::ExprOperatorSignature({{"x"}},
84+
"koladata_classic_aux_policy"),
85+
doc,
86+
arolla::FingerprintHasher("::koladata::expr::ToArollaValueOperator")
87+
.Combine(name, doc, output_qtype)
88+
.Finish()),
89+
output_qtype_(output_qtype) {}
90+
91+
absl::StatusOr<arolla::expr::ExprAttributes> InferAttributes(
92+
absl::Span<const arolla::expr::ExprAttributes> inputs) const final {
93+
RETURN_IF_ERROR(ValidateOpInputsCount(inputs));
94+
if (auto* x_qtype = inputs[0].qtype()) {
95+
if (x_qtype != arolla::GetQType<DataSlice>()) {
96+
return absl::InvalidArgumentError(
97+
absl::StrFormat("expected DATA_SLICE, got x: %s", x_qtype->name()));
98+
}
99+
} else {
100+
return arolla::expr::ExprAttributes{}; // Not ready.
101+
}
102+
if (const auto& x = inputs[0].qvalue()) {
103+
ASSIGN_OR_RETURN(auto result, arolla::InvokeOperator(display_name(), {*x},
104+
output_qtype_));
105+
return arolla::expr::ExprAttributes{std::move(result)};
106+
}
107+
return arolla::expr::ExprAttributes{output_qtype_}; // Can not evaluated.
108+
}
109+
110+
private:
111+
arolla::QTypePtr output_qtype_;
112+
};
113+
71114
} // namespace
72115

73116
InputOperator::InputOperator()
@@ -98,11 +141,14 @@ absl::StatusOr<arolla::expr::ExprAttributes> InputOperator::InferAttributes(
98141
return arolla::expr::ExprAttributes{};
99142
}
100143

101-
absl::StatusOr<arolla::expr::ExprNodePtr> MakeLiteral(
102-
arolla::TypedValue value) {
103-
ASSIGN_OR_RETURN(auto op, expr::LiteralOperator::Make(value));
104-
return arolla::expr::ExprNode::UnsafeMakeOperatorNode(
105-
std::move(op), {}, arolla::expr::ExprAttributes(std::move(value)));
144+
bool IsInput(const arolla::expr::ExprNodePtr& node) {
145+
if (!node->is_op()) {
146+
return false;
147+
}
148+
return nullptr != arolla::fast_dynamic_downcast_final<const InputOperator*>(
149+
arolla::expr::DecayRegisteredOperator(node->op())
150+
.value_or(nullptr)
151+
.get());
106152
}
107153

108154
absl::StatusOr<std::shared_ptr<LiteralOperator>> LiteralOperator::Make(
@@ -132,68 +178,18 @@ absl::string_view LiteralOperator::py_qvalue_specialization_key() const {
132178
return "::koladata::expr::LiteralOperator";
133179
}
134180

135-
const arolla::TypedValue& LiteralOperator::value() const { return value_; }
136-
137-
absl::StatusOr<arolla::expr::ExprAttributes>
138-
ToArollaValueOperator::InferAttributes(
139-
absl::Span<const arolla::expr::ExprAttributes> inputs) const {
140-
RETURN_IF_ERROR(ValidateOpInputsCount(inputs));
141-
if (!inputs[0].qtype()) {
142-
return arolla::expr::ExprAttributes{}; // Not ready yet.
143-
}
144-
if (inputs[0].qtype() != arolla::GetQType<DataSlice>()) {
145-
return absl::InvalidArgumentError(absl::StrFormat(
146-
"expected DATA_SLICE, got x: %s", inputs[0].qtype()->name()));
147-
}
148-
// Eval if possible.
149-
if (inputs[0].qvalue()) {
150-
ASSIGN_OR_RETURN(
151-
auto casted_value,
152-
arolla::InvokeOperator(backend_operator_name_, {*inputs[0].qvalue()},
153-
output_qtype_));
154-
return arolla::expr::ExprAttributes{std::move(casted_value)};
155-
}
156-
// Otherwise, return the output qtype.
157-
return arolla::expr::ExprAttributes{output_qtype_};
181+
absl::StatusOr<arolla::expr::ExprNodePtr> MakeLiteral(
182+
arolla::TypedValue value) {
183+
ASSIGN_OR_RETURN(auto op, expr::LiteralOperator::Make(value));
184+
return arolla::expr::ExprNode::UnsafeMakeOperatorNode(
185+
std::move(op), {}, arolla::expr::ExprAttributes(std::move(value)));
158186
}
159187

160-
ToArollaInt64Operator::ToArollaInt64Operator()
161-
: ToArollaValueOperator(
162-
"koda_internal.to_arolla_int64",
163-
arolla::expr::ExprOperatorSignature({{"x"}},
164-
"koladata_classic_aux_policy"),
165-
"Returns `x` converted into an arolla int64 value.\n"
166-
"\n"
167-
"Note that `x` must adhere to the following requirements:\n"
168-
"* `rank = 0`.\n"
169-
"* Have one of the following schemas: NONE, INT32, INT64, OBJECT.\n"
170-
"* Have a present value with type INT32 or INT64.\n"
171-
"\n"
172-
"In all other cases, an exception is raised.\n\n"
173-
"Args:\n"
174-
" x: A DataItem to be converted into an arolla int64 value.",
175-
arolla::FingerprintHasher("::koladata::expr::ToArollaInt64Operator")
176-
.Finish(),
177-
"koda_internal.to_arolla_int64", arolla::GetQType<int64_t>()) {}
178-
179-
ToArollaTextOperator::ToArollaTextOperator()
180-
: ToArollaValueOperator(
181-
"koda_internal.to_arolla_text",
182-
arolla::expr::ExprOperatorSignature({{"x"}},
183-
"koladata_classic_aux_policy"),
184-
"Returns `x` converted into an arolla text value.\n"
185-
"\n"
186-
"Note that `x` must adhere to the following requirements:\n"
187-
"* `rank = 0`.\n"
188-
"* Have one of the following schemas: NONE, STRING, OBJECT, ANY.\n"
189-
"* Have a present value with type TEXT.\n"
190-
"\n"
191-
"In all other cases, an exception is raised.\n\n"
192-
"Args:\n"
193-
" x: A DataItem to be converted into an arolla text value.",
194-
arolla::FingerprintHasher("::koladata::expr::ToArollaTextOperator")
195-
.Finish(),
196-
"koda_internal.to_arolla_text", arolla::GetQType<arolla::Text>()) {}
188+
bool IsLiteral(const arolla::expr::ExprNodePtr& node) {
189+
return node->is_literal() ||
190+
(arolla::fast_dynamic_downcast_final<const LiteralOperator*>(
191+
node->op().get()) != nullptr);
192+
}
197193

198194
NonDeterministicOperator::NonDeterministicOperator()
199195
: arolla::expr::ExprOperatorWithFixedSignature(
@@ -211,20 +207,42 @@ NonDeterministicOperator::InferAttributes(
211207
arolla::GetQType<internal::NonDeterministicToken>());
212208
}
213209

214-
bool IsInput(const arolla::expr::ExprNodePtr& node) {
215-
if (!node->is_op()) {
216-
return false;
217-
}
218-
return nullptr != arolla::fast_dynamic_downcast_final<const InputOperator*>(
219-
arolla::expr::DecayRegisteredOperator(node->op())
220-
.value_or(nullptr)
221-
.get());
210+
arolla::expr::ExprOperatorPtr MakeToArollaInt64Operator() {
211+
return std::make_shared<ToArollaValueOperator>(
212+
"koda_internal.to_arolla_int64",
213+
"Returns `x` converted into an arolla int64 value.\n"
214+
"\n"
215+
"Note that `x` must adhere to the following requirements:\n"
216+
"* `rank = 0`.\n"
217+
"* Have one of the following schemas: NONE, INT32, INT64, OBJECT.\n"
218+
"* Have a present value with type INT32 or INT64.\n"
219+
"\n"
220+
"In all other cases, an exception is raised.\n\n"
221+
"Args:\n"
222+
" x: A DataItem to be converted into an arolla int64 value.",
223+
arolla::GetQType<int64_t>());
222224
}
223225

224-
bool IsLiteral(const arolla::expr::ExprNodePtr& node) {
225-
return node->is_literal() ||
226-
(arolla::fast_dynamic_downcast_final<const LiteralOperator*>(
227-
node->op().get()) != nullptr);
226+
arolla::expr::ExprOperatorPtr MakeToArollaTextOperator() {
227+
return std::make_shared<ToArollaValueOperator>(
228+
"koda_internal.to_arolla_text",
229+
"Returns `x` converted into an arolla text value.\n"
230+
"\n"
231+
"Note that `x` must adhere to the following requirements:\n"
232+
"* `rank = 0`.\n"
233+
"* Have one of the following schemas: NONE, STRING, OBJECT, ANY.\n"
234+
"* Have a present value with type TEXT.\n"
235+
"\n"
236+
"In all other cases, an exception is raised.\n\n"
237+
"Args:\n"
238+
" x: A DataItem to be converted into an arolla text value.",
239+
arolla::GetQType<arolla::Text>());
240+
}
241+
242+
arolla::expr::ExprOperatorPtr MakeNameAnnotationOperator() {
243+
// Use aux-policy with arolla boxing rules for the second parameter.
244+
return std::make_shared<arolla::expr::NameAnnotation>(
245+
"koladata_unified_aux_policy:_p:_arolla_as_qvalue_or_expr;01");
228246
}
229247

230248
absl::StatusOr<InputContainer> InputContainer::Create(

koladata/expr/expr_operators.h

Lines changed: 19 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@
2828
#include "arolla/expr/expr_attributes.h"
2929
#include "arolla/expr/expr_node.h"
3030
#include "arolla/expr/expr_operator.h"
31-
#include "arolla/expr/expr_operator_signature.h"
32-
#include "arolla/qtype/qtype.h"
3331
#include "arolla/qtype/typed_value.h"
34-
#include "arolla/util/fingerprint.h"
3532

3633
namespace koladata::expr {
3734

@@ -48,6 +45,9 @@ class InputOperator final
4845
absl::Span<const arolla::expr::ExprAttributes> inputs) const final;
4946
};
5047

48+
// Returns true if `node` is an operator node with InputOperator.
49+
bool IsInput(const arolla::expr::ExprNodePtr& node);
50+
5151
// Non-lowerable stateful operator `koda_internal.literal()` that wraps a
5252
// TypedValue. This operator allows us to attach a view to non-DataSlice
5353
// literals.
@@ -66,57 +66,19 @@ class LiteralOperator final
6666

6767
absl::string_view py_qvalue_specialization_key() const final;
6868

69-
const arolla::TypedValue& value() const;
69+
const arolla::TypedValue& value() const { return value_; }
7070

7171
private:
7272
arolla::TypedValue value_;
7373
};
7474

75-
// Base class for non-lowerable operators that converts a DataSlice to an Arolla
76-
// type T. Supports evaluation at operator binding time if the provided input is
77-
// a literal. Dispatches the actual conversion to a corresponding
78-
// backend-operator.
79-
class ToArollaValueOperator
80-
: public arolla::expr::BackendExprOperatorTag,
81-
public arolla::expr::ExprOperatorWithFixedSignature {
82-
public:
83-
ToArollaValueOperator(absl::string_view name,
84-
arolla::expr::ExprOperatorSignature signature,
85-
absl::string_view doc, arolla::Fingerprint fingerprint,
86-
std::string backend_operator_name,
87-
arolla::QTypePtr output_qtype)
88-
: ExprOperatorWithFixedSignature(name, std::move(signature), doc,
89-
fingerprint),
90-
backend_operator_name_(std::move(backend_operator_name)),
91-
output_qtype_(output_qtype) {}
92-
93-
absl::StatusOr<arolla::expr::ExprAttributes> InferAttributes(
94-
absl::Span<const arolla::expr::ExprAttributes> inputs) const final;
95-
96-
private:
97-
std::string backend_operator_name_;
98-
arolla::QTypePtr output_qtype_;
99-
};
100-
101-
// Non-lowerable operator `koda_internal.to_arolla_int64(x)` that converts
102-
// DataSlice to int64_t. Supports evaluation at operator binding time if the
103-
// provided input is a literal.
104-
class ToArollaInt64Operator final : public ToArollaValueOperator {
105-
public:
106-
ToArollaInt64Operator();
107-
};
108-
109-
// Non-lowerable operator `koda_internal.to_arolla_text(x)` that converts
110-
// DataSlice to Text. Supports evaluation at operator binding time if the
111-
// provided input is a literal.
112-
class ToArollaTextOperator final : public ToArollaValueOperator {
113-
public:
114-
ToArollaTextOperator();
115-
};
116-
11775
// Return a literal Expr node.
11876
absl::StatusOr<arolla::expr::ExprNodePtr> MakeLiteral(arolla::TypedValue value);
11977

78+
// Returns true if `node` is either an arolla literal or an operator node with
79+
// LiteralOperator.
80+
bool IsLiteral(const arolla::expr::ExprNodePtr& node);
81+
12082
// Non-lowerable operator
12183
// ```
12284
// koda_internal.non_deterministic(
@@ -136,12 +98,18 @@ class NonDeterministicOperator final
13698
absl::Span<const arolla::expr::ExprAttributes> inputs) const final;
13799
};
138100

139-
// Returns true if `node` is an operator node with InputOperator.
140-
bool IsInput(const arolla::expr::ExprNodePtr& node);
101+
// Non-lowerable operator `koda_internal.to_arolla_int64(x)` that converts
102+
// DataSlice to int64_t. Supports evaluation at operator binding time if the
103+
// provided input is a literal.
104+
arolla::expr::ExprOperatorPtr MakeToArollaInt64Operator();
141105

142-
// Returns true if `node` is either an arolla literal or an operator node with
143-
// LiteralOperator.
144-
bool IsLiteral(const arolla::expr::ExprNodePtr& node);
106+
// Non-lowerable operator `koda_internal.to_arolla_text(x)` that converts
107+
// DataSlice to Text. Supports evaluation at operator binding time if the
108+
// provided input is a literal.
109+
arolla::expr::ExprOperatorPtr MakeToArollaTextOperator();
110+
111+
// Non-lowerable operator `kd.annotation.with_name(x, name)`.
112+
arolla::expr::ExprOperatorPtr MakeNameAnnotationOperator();
145113

146114
// Helper container to create Koda specific inputs
147115
class InputContainer {

koladata/expr/init_expr_operators.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
//
15-
#include <memory>
16-
1715
#include "absl/status/status.h"
1816
#include "absl/status/statusor.h"
1917
#include "absl/strings/string_view.h"
@@ -32,6 +30,7 @@ namespace koladata::expr {
3230

3331
AROLLA_INITIALIZER(
3432
.name = "arolla_operators/koda_internal",
33+
.deps = {"arolla_operators/standard"},
3534
.reverse_deps = {arolla::initializer_dep::kOperators},
3635
.init_fn = []() -> absl::Status {
3736
RETURN_IF_ERROR(arolla::expr::RegisterOperator<InputOperator>(
@@ -44,26 +43,25 @@ AROLLA_INITIALIZER(
4443
arolla::expr::Literal(internal::Ellipsis{})))
4544
.status());
4645
RETURN_IF_ERROR(
47-
arolla::expr::RegisterOperator(
48-
"koda_internal.with_name",
49-
std::make_shared<arolla::expr::NameAnnotation>(
50-
/*aux_policy=*/"_koladata_annotation_with_name"))
46+
arolla::expr::RegisterOperator("koda_internal.with_name",
47+
MakeNameAnnotationOperator())
48+
.status());
49+
RETURN_IF_ERROR(
50+
arolla::expr::RegisterOperator("koda_internal.to_arolla_int64",
51+
MakeToArollaInt64Operator())
52+
.status());
53+
RETURN_IF_ERROR(
54+
arolla::expr::RegisterOperator("koda_internal.to_arolla_text",
55+
MakeToArollaTextOperator())
5156
.status());
52-
RETURN_IF_ERROR(arolla::expr::RegisterOperator(
53-
"koda_internal.to_arolla_int64",
54-
std::make_shared<ToArollaInt64Operator>())
55-
.status());
56-
RETURN_IF_ERROR(arolla::expr::RegisterOperator(
57-
"koda_internal.to_arolla_text",
58-
std::make_shared<ToArollaTextOperator>())
59-
.status());
6057
RETURN_IF_ERROR(
6158
arolla::expr::RegisterOperator<NonDeterministicOperator>(
62-
"koda_internal.non_deterministic").status());
63-
RETURN_IF_ERROR(arolla::expr::RegisterOperator(
64-
"kd.annotation.source_location",
65-
arolla::expr::SourceLocationAnnotation::Make())
66-
.status());
59+
"koda_internal.non_deterministic")
60+
.status());
61+
RETURN_IF_ERROR(
62+
arolla::expr::RegisterOperatorAlias(
63+
"kd.annotation.source_location", "annotation.source_location")
64+
.status());
6765
return absl::OkStatus();
6866
})
6967

py/koladata/operators/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,6 @@ py_library(
747747
deps = [
748748
":optools",
749749
"//py:python_path",
750-
"//py/koladata/types:py_boxing",
751750
"@com_google_arolla//py/arolla",
752751
],
753752
)

0 commit comments

Comments
 (0)