Skip to content

Commit fce985f

Browse files
apronchenkovcopybara-github
authored andcommitted
Support EvaluationContext::Options in kd.functor.call
PiperOrigin-RevId: 720251621 Change-Id: If3168d52f5d88e38db9604da33c7973a34416b51
1 parent cbf6da3 commit fce985f

File tree

10 files changed

+174
-51
lines changed

10 files changed

+174
-51
lines changed

koladata/functor/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ cc_test(
212212
":signature",
213213
":signature_storage",
214214
"//koladata:data_slice",
215+
"//koladata/expr:expr_eval",
215216
"//koladata/internal:data_item",
216217
"//koladata/internal:dtype",
217218
"//koladata/s11n",
@@ -221,6 +222,7 @@ cc_test(
221222
"@com_google_absl//absl/status:status_matchers",
222223
"@com_google_absl//absl/status:statusor",
223224
"@com_google_absl//absl/strings:string_view",
225+
"@com_google_absl//absl/time",
224226
"@com_google_arolla//arolla/expr",
225227
"@com_google_arolla//arolla/expr/operators/all",
226228
"@com_google_arolla//arolla/qexpr/operators/all",

koladata/functor/call.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ absl::StatusOr<std::vector<std::string>> GetVariableEvaluationOrder(
107107

108108
absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
109109
const DataSlice& functor, absl::Span<const arolla::TypedRef> args,
110-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs) {
110+
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs,
111+
const expr::EvalOptions& eval_options) {
111112
ASSIGN_OR_RETURN(bool is_functor, IsFunctor(functor));
112113
if (!is_functor) {
113114
return absl::InvalidArgumentError(
@@ -142,8 +143,9 @@ absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
142143
// This passes all variables computed so far, even those not used, and
143144
// EvalExprWithCompilationCache will traverse all provided variables,
144145
// so this is O(num_variables**2). We can optimize this later if needed.
145-
ASSIGN_OR_RETURN(auto variable_value, expr::EvalExprWithCompilationCache(
146-
expr, inputs, variables));
146+
ASSIGN_OR_RETURN(auto variable_value,
147+
expr::EvalExprWithCompilationCache(
148+
expr, inputs, variables, eval_options));
147149
computed_variable_holder.push_back(std::move(variable_value));
148150
} else {
149151
computed_variable_holder.push_back(

koladata/functor/call.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "absl/status/statusor.h"
2222
#include "absl/types/span.h"
2323
#include "koladata/data_slice.h"
24+
#include "koladata/expr/expr_eval.h"
2425
#include "arolla/qtype/typed_ref.h"
2526
#include "arolla/qtype/typed_value.h"
2627

@@ -35,9 +36,15 @@ namespace koladata::functor {
3536
// via V.foo, in which case the variable expression will be evaluated before
3637
// evaluating the expression that refers to it. In case of a cycle in variables,
3738
// an error will be returned.
39+
//
40+
// The `eval_options` parameter provides a default buffer factory
41+
// (typically either the default allocator or an arena allocator)
42+
// and a cancellation checker, which allows the computation to be
43+
// interrupted midway if needed.
3844
absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
3945
const DataSlice& functor, absl::Span<const arolla::TypedRef> args,
40-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs);
46+
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs,
47+
const expr::EvalOptions& eval_options);
4148

4249
} // namespace koladata::functor
4350

koladata/functor/call_operator.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class CallOperator : public arolla::QExprOperator {
5050
explicit CallOperator(absl::Span<const arolla::QTypePtr> input_types,
5151
arolla::QTypePtr output_type)
5252
: QExprOperator("kd.functor.call", arolla::QExprOperatorSignature::Get(
53-
input_types, output_type)) {}
53+
input_types, output_type)) {}
5454

5555
absl::StatusOr<std::unique_ptr<arolla::BoundOperator>> DoBind(
5656
absl::Span<const arolla::TypedSlot> input_slots,
@@ -77,10 +77,11 @@ class CallOperator : public arolla::QExprOperator {
7777
{kwarg_names[i],
7878
arolla::TypedRef::FromSlot(kwargs_slot.SubSlot(i), frame)});
7979
}
80-
ASSIGN_OR_RETURN(auto result,
81-
functor::CallFunctorWithCompilationCache(
82-
fn_data_slice, arg_refs, kwarg_refs),
83-
ctx->set_status(std::move(_)));
80+
ASSIGN_OR_RETURN(
81+
auto result,
82+
functor::CallFunctorWithCompilationCache(
83+
fn_data_slice, arg_refs, kwarg_refs, ctx->options()),
84+
ctx->set_status(std::move(_)));
8485
if (result.GetType() != output_slot.GetType()) {
8586
ctx->set_status(absl::InvalidArgumentError(absl::StrFormat(
8687
"the functor was called with `%s` as the output type, but the"
@@ -121,13 +122,15 @@ absl::StatusOr<arolla::OperatorPtr> CallOperatorFamily::DoGetOperator(
121122
output_type);
122123
}
123124

124-
absl::StatusOr<DataSlice> MaybeCall(const DataSlice& maybe_fn,
125+
absl::StatusOr<DataSlice> MaybeCall(arolla::EvaluationContext* ctx,
126+
const DataSlice& maybe_fn,
125127
const DataSlice& arg) {
126128
ASSIGN_OR_RETURN(bool is_functor, IsFunctor(maybe_fn));
127129
if (is_functor) {
128130
ASSIGN_OR_RETURN(auto result,
129131
functor::CallFunctorWithCompilationCache(
130-
maybe_fn, {arolla::TypedRef::FromValue(arg)}, {}));
132+
maybe_fn, /*args=*/{arolla::TypedRef::FromValue(arg)},
133+
/*kwargs=*/{}, /*eval_options=*/ctx->options()));
131134
if (result.GetType() != arolla::GetQType<DataSlice>()) {
132135
return absl::InternalError(absl::StrFormat(
133136
"the functor is expected to be evaluated to a DataSlice"

koladata/functor/call_operator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class CallOperatorFamily : public arolla::OperatorFamily {
3434
// kd.functor._maybe_call operator.
3535
// If the first argument is a functor, calls it on the second argument.
3636
// Otherwise, returns the first argument.
37-
absl::StatusOr<DataSlice> MaybeCall(const DataSlice& maybe_fn,
37+
absl::StatusOr<DataSlice> MaybeCall(arolla::EvaluationContext* ctx,
38+
const DataSlice& maybe_fn,
3839
const DataSlice& arg);
3940

4041
} // namespace koladata::functor

koladata/functor/call_test.cc

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
#include "absl/status/status_matchers.h"
2727
#include "absl/status/statusor.h"
2828
#include "absl/strings/string_view.h"
29+
#include "absl/time/time.h"
2930
#include "koladata/data_slice.h"
31+
#include "koladata/expr/expr_eval.h"
3032
#include "koladata/functor/functor.h"
3133
#include "koladata/functor/signature.h"
3234
#include "koladata/functor/signature_storage.h"
@@ -38,6 +40,7 @@
3840
#include "arolla/expr/quote.h"
3941
#include "arolla/qtype/typed_ref.h"
4042
#include "arolla/qtype/typed_value.h"
43+
#include "arolla/util/cancellation_context.h"
4144
#include "arolla/util/text.h"
4245
#include "arolla/util/status_macros_backport.h"
4346

@@ -113,23 +116,27 @@ TEST(CallTest, VariableRhombus) {
113116
arolla::TypedValue::FromValue(3),
114117
arolla::TypedValue::FromValue(4),
115118
};
116-
ASSERT_OK_AND_ASSIGN(auto result,
117-
CallFunctorWithCompilationCache(fn,
118-
{
119-
inputs[0].AsRef(),
120-
inputs[1].AsRef(),
121-
inputs[2].AsRef(),
122-
},
123-
{}));
119+
ASSERT_OK_AND_ASSIGN(auto result, CallFunctorWithCompilationCache(
120+
fn,
121+
/*args=*/
122+
{
123+
inputs[0].AsRef(),
124+
inputs[1].AsRef(),
125+
inputs[2].AsRef(),
126+
},
127+
/*kwargs=*/{}, /*eval_options=*/{}));
124128
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(2 * ((3 + 5) + (3 + 4))));
125129

126130
ASSERT_OK_AND_ASSIGN(
127-
result, CallFunctorWithCompilationCache(
128-
fn,
129-
{
130-
inputs[0].AsRef(),
131-
},
132-
{{"c", inputs[1].AsRef()}, {"b", inputs[2].AsRef()}}));
131+
result,
132+
CallFunctorWithCompilationCache(
133+
fn,
134+
/*args=*/
135+
{
136+
inputs[0].AsRef(),
137+
},
138+
/*kwargs=*/{{"c", inputs[1].AsRef()}, {"b", inputs[2].AsRef()}},
139+
/*eval_options=*/{}));
133140
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(2 * ((4 + 5) + (4 + 3))));
134141
}
135142

@@ -143,7 +150,8 @@ TEST(CallTest, VariableCycle) {
143150
ASSERT_OK_AND_ASSIGN(auto fn,
144151
CreateFunctor(returns_expr, koda_signature,
145152
{{"a", var_a_expr}, {"b", var_b_expr}}));
146-
EXPECT_THAT(CallFunctorWithCompilationCache(fn, {}, {}),
153+
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwargs=*/{},
154+
/*eval_options=*/{}),
147155
StatusIs(absl::StatusCode::kInvalidArgument,
148156
"variable [a] has a dependency cycle"));
149157
}
@@ -155,8 +163,9 @@ TEST(CallTest, JustLiteral) {
155163
ASSERT_OK_AND_ASSIGN(auto returns_expr, WrapExpr(arolla::expr::Literal(57)));
156164
ASSERT_OK_AND_ASSIGN(auto fn,
157165
CreateFunctor(returns_expr, koda_signature, {}));
158-
ASSERT_OK_AND_ASSIGN(auto result,
159-
CallFunctorWithCompilationCache(fn, {}, {}));
166+
ASSERT_OK_AND_ASSIGN(
167+
auto result, CallFunctorWithCompilationCache(
168+
fn, /*args=*/{}, /*kwargs=*/{}, /*eval_options=*/{}));
160169
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(57));
161170
}
162171

@@ -168,7 +177,8 @@ TEST(CallTest, MustBeScalar) {
168177
ASSERT_OK_AND_ASSIGN(auto fn,
169178
CreateFunctor(returns_expr, koda_signature, {}));
170179
ASSERT_OK_AND_ASSIGN(fn, fn.Reshape(DataSlice::JaggedShape::FlatFromSize(1)));
171-
EXPECT_THAT(CallFunctorWithCompilationCache(fn, {}, {}),
180+
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwargs=*/{},
181+
/*eval_options=*/{}),
172182
StatusIs(absl::StatusCode::kInvalidArgument,
173183
"the first argument of kd.call must be a functor"));
174184
}
@@ -180,9 +190,11 @@ TEST(CallTest, NoBag) {
180190
ASSERT_OK_AND_ASSIGN(auto returns_expr, WrapExpr(arolla::expr::Literal(57)));
181191
ASSERT_OK_AND_ASSIGN(auto fn,
182192
CreateFunctor(returns_expr, koda_signature, {}));
183-
EXPECT_THAT(CallFunctorWithCompilationCache(fn.WithBag(nullptr), {}, {}),
184-
StatusIs(absl::StatusCode::kInvalidArgument,
185-
"the first argument of kd.call must be a functor"));
193+
EXPECT_THAT(
194+
CallFunctorWithCompilationCache(fn.WithBag(nullptr), /*args=*/{},
195+
/*kwargs=*/{}, /*eval_options=*/{}),
196+
StatusIs(absl::StatusCode::kInvalidArgument,
197+
"the first argument of kd.call must be a functor"));
186198
}
187199

188200
TEST(CallTest, DataSliceVariable) {
@@ -195,8 +207,9 @@ TEST(CallTest, DataSliceVariable) {
195207
internal::DataItem(schema::kInt32)));
196208
ASSERT_OK_AND_ASSIGN(
197209
auto fn, CreateFunctor(returns_expr, koda_signature, {{"a", var_a}}));
198-
ASSERT_OK_AND_ASSIGN(auto result,
199-
CallFunctorWithCompilationCache(fn, {}, {}));
210+
ASSERT_OK_AND_ASSIGN(auto result, CallFunctorWithCompilationCache(
211+
fn, /*args=*/{},
212+
/*kwargs=*/{}, /*eval_options=*/{}));
200213
EXPECT_THAT(result.As<DataSlice>(),
201214
IsOkAndHolds(IsEquivalentTo(var_a.WithBag(fn.GetBag()))));
202215
}
@@ -224,18 +237,66 @@ TEST(CallTest, EvalError) {
224237
// It is OK to only improve this on the Python side, the C++ error is not
225238
// so important.
226239
EXPECT_THAT(
227-
CallFunctorWithCompilationCache(fn,
228-
{
229-
arolla::TypedRef::FromValue(input),
230-
},
231-
{}),
240+
CallFunctorWithCompilationCache(
241+
fn,
242+
/*args=*/{arolla::TypedRef::FromValue(input)},
243+
/*kwargs=*/{}, /*eval_options=*/{}),
232244
StatusIs(absl::StatusCode::kInvalidArgument,
233245
"expected numerics, got x: DATA_SLICE; while calling math.add "
234246
"with args {annotation.qtype(L['I.a'], DATA_SLICE), 57}; while "
235247
"transforming M.math.add(L['I.a'], 57); while compiling the "
236248
"expression"));
237249
}
238250

251+
TEST(CallTest, Cancellation) {
252+
ASSERT_OK_AND_ASSIGN(
253+
auto signature,
254+
Signature::Create({
255+
{.name = "a",
256+
.kind = Signature::Parameter::Kind::kPositionalOrKeyword},
257+
}));
258+
ASSERT_OK_AND_ASSIGN(auto koda_signature,
259+
CppSignatureToKodaSignature(signature));
260+
// returns_expr = I.x + I.x + I.x
261+
ASSERT_OK_AND_ASSIGN(
262+
auto returns_expr,
263+
WrapExpr(arolla::expr::CallOp(
264+
"math.add", {arolla::expr::CallOp(
265+
"math.add", {CreateInput("a"), CreateInput("a")}),
266+
CreateInput("a")})));
267+
ASSERT_OK_AND_ASSIGN(auto fn,
268+
CreateFunctor(returns_expr, koda_signature, {}));
269+
270+
{
271+
int op_count = 2; // Stop after the second operator.
272+
auto cancel_ctx = arolla::CancellationContext::Make(
273+
/*no cooldown*/ absl::Nanoseconds(-1),
274+
/*no countdown*/ -1, [&op_count] {
275+
return --op_count > 0 ? absl::OkStatus() : absl::CancelledError("");
276+
});
277+
expr::EvalOptions eval_options{.cancellation_context = cancel_ctx.get()};
278+
EXPECT_THAT(CallFunctorWithCompilationCache(
279+
fn, /*args=*/{arolla::TypedRef::FromValue(1)},
280+
/*kwargs=*/{}, eval_options),
281+
StatusIs(absl::StatusCode::kCancelled));
282+
}
283+
{
284+
int op_count = 3; // Should stop after the third operator;
285+
// however, there are only two operators.
286+
auto cancel_ctx = arolla::CancellationContext::Make(
287+
/*no cooldown*/ absl::Nanoseconds(-1),
288+
/*no countdown*/ -1, [&op_count] {
289+
return --op_count > 0 ? absl::OkStatus() : absl::CancelledError("");
290+
});
291+
expr::EvalOptions eval_options{.cancellation_context = cancel_ctx.get()};
292+
ASSERT_OK_AND_ASSIGN(auto result,
293+
CallFunctorWithCompilationCache(
294+
fn, /*args=*/{arolla::TypedRef::FromValue(1)},
295+
/*kwargs=*/{}, eval_options));
296+
EXPECT_THAT(result.As<int>(), IsOkAndHolds(3));
297+
}
298+
}
299+
239300
} // namespace
240301

241302
} // namespace koladata::functor

py/koladata/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ cc_test(
129129
"//koladata:data_bag",
130130
"//koladata:data_slice",
131131
"//koladata:data_slice_qtype",
132+
"//koladata/expr:expr_eval",
132133
"//koladata/functor",
133134
"//koladata/functor:call",
134135
"//koladata/internal:data_item",

py/koladata/cc_benchmarks.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "koladata/data_bag.h"
2525
#include "koladata/data_slice.h"
2626
#include "koladata/data_slice_qtype.h"
27+
#include "koladata/expr/expr_eval.h"
2728
#include "koladata/functor/call.h"
2829
#include "koladata/functor/functor.h"
2930
#include "koladata/internal/data_item.h"
@@ -60,10 +61,10 @@ void BM_Add(benchmark::State& state) {
6061

6162
auto db = DataBag::Empty();
6263
auto ds = DataSlice::CreateWithSchemaFromData(
63-
internal::DataSliceImpl::Create(
64-
arolla::CreateFullDenseArray<int>(values)),
65-
DataSlice::JaggedShape::FlatFromSize(slice_size), db)
66-
.value();
64+
internal::DataSliceImpl::Create(
65+
arolla::CreateFullDenseArray<int>(values)),
66+
DataSlice::JaggedShape::FlatFromSize(slice_size), db)
67+
.value();
6768

6869
auto expr = Leaf("x");
6970
for (size_t i = 0; i < num_operators; ++i) {
@@ -157,10 +158,10 @@ void BM_Equal_Int32_Int64(benchmark::State& state) {
157158

158159
auto db = DataBag::Empty();
159160
auto int32_ds = DataSlice::CreateWithSchemaFromData(
160-
internal::DataSliceImpl::Create(
161-
arolla::CreateFullDenseArray<int>(int32_values)),
162-
DataSlice::JaggedShape::FlatFromSize(slice_size), db)
163-
.value();
161+
internal::DataSliceImpl::Create(
162+
arolla::CreateFullDenseArray<int>(int32_values)),
163+
DataSlice::JaggedShape::FlatFromSize(slice_size), db)
164+
.value();
164165
auto int64_ds = DataSlice::CreateWithSchemaFromData(
165166
internal::DataSliceImpl::Create(
166167
arolla::CreateFullDenseArray<int64_t>(int64_values)),
@@ -282,9 +283,11 @@ void BM_AddViaFunctor(benchmark::State& state) {
282283
.value();
283284
auto functor = functor::CreateFunctor(expr_slice, std::nullopt, {}).value();
284285

285-
auto fn = [&functor](const auto& ds) {
286+
expr::EvalOptions eval_options;
287+
288+
auto fn = [&functor, &eval_options](const auto& ds) {
286289
return functor::CallFunctorWithCompilationCache(
287-
functor, {arolla::TypedRef::FromValue(ds)}, {});
290+
functor, {arolla::TypedRef::FromValue(ds)}, {}, eval_options);
288291
};
289292

290293
{

0 commit comments

Comments
 (0)