2626#include " absl/status/status_matchers.h"
2727#include " absl/status/statusor.h"
2828#include " absl/strings/string_view.h"
29+ #include " absl/time/time.h"
2930#include " koladata/data_slice.h"
31+ #include " koladata/expr/expr_eval.h"
3032#include " koladata/functor/functor.h"
3133#include " koladata/functor/signature.h"
3234#include " koladata/functor/signature_storage.h"
3840#include " arolla/expr/quote.h"
3941#include " arolla/qtype/typed_ref.h"
4042#include " arolla/qtype/typed_value.h"
43+ #include " arolla/util/cancellation_context.h"
4144#include " arolla/util/text.h"
4245#include " arolla/util/status_macros_backport.h"
4346
@@ -113,23 +116,27 @@ TEST(CallTest, VariableRhombus) {
113116 arolla::TypedValue::FromValue (3 ),
114117 arolla::TypedValue::FromValue (4 ),
115118 };
116- ASSERT_OK_AND_ASSIGN (auto result,
117- CallFunctorWithCompilationCache (fn,
118- {
119- inputs[0 ].AsRef (),
120- inputs[1 ].AsRef (),
121- inputs[2 ].AsRef (),
122- },
123- {}));
119+ ASSERT_OK_AND_ASSIGN (auto result, CallFunctorWithCompilationCache (
120+ fn,
121+ /* args=*/
122+ {
123+ inputs[0 ].AsRef (),
124+ inputs[1 ].AsRef (),
125+ inputs[2 ].AsRef (),
126+ },
127+ /* kwargs=*/ {}, /* eval_options=*/ {}));
124128 EXPECT_THAT (result.As <int32_t >(), IsOkAndHolds (2 * ((3 + 5 ) + (3 + 4 ))));
125129
126130 ASSERT_OK_AND_ASSIGN (
127- result, CallFunctorWithCompilationCache (
128- fn,
129- {
130- inputs[0 ].AsRef (),
131- },
132- {{" c" , inputs[1 ].AsRef ()}, {" b" , inputs[2 ].AsRef ()}}));
131+ result,
132+ CallFunctorWithCompilationCache (
133+ fn,
134+ /* args=*/
135+ {
136+ inputs[0 ].AsRef (),
137+ },
138+ /* kwargs=*/ {{" c" , inputs[1 ].AsRef ()}, {" b" , inputs[2 ].AsRef ()}},
139+ /* eval_options=*/ {}));
133140 EXPECT_THAT (result.As <int32_t >(), IsOkAndHolds (2 * ((4 + 5 ) + (4 + 3 ))));
134141}
135142
@@ -143,7 +150,8 @@ TEST(CallTest, VariableCycle) {
143150 ASSERT_OK_AND_ASSIGN (auto fn,
144151 CreateFunctor (returns_expr, koda_signature,
145152 {{" a" , var_a_expr}, {" b" , var_b_expr}}));
146- EXPECT_THAT (CallFunctorWithCompilationCache (fn, {}, {}),
153+ EXPECT_THAT (CallFunctorWithCompilationCache (fn, /* args=*/ {}, /* kwargs=*/ {},
154+ /* eval_options=*/ {}),
147155 StatusIs (absl::StatusCode::kInvalidArgument ,
148156 " variable [a] has a dependency cycle" ));
149157}
@@ -155,8 +163,9 @@ TEST(CallTest, JustLiteral) {
155163 ASSERT_OK_AND_ASSIGN (auto returns_expr, WrapExpr (arolla::expr::Literal (57 )));
156164 ASSERT_OK_AND_ASSIGN (auto fn,
157165 CreateFunctor (returns_expr, koda_signature, {}));
158- ASSERT_OK_AND_ASSIGN (auto result,
159- CallFunctorWithCompilationCache (fn, {}, {}));
166+ ASSERT_OK_AND_ASSIGN (
167+ auto result, CallFunctorWithCompilationCache (
168+ fn, /* args=*/ {}, /* kwargs=*/ {}, /* eval_options=*/ {}));
160169 EXPECT_THAT (result.As <int32_t >(), IsOkAndHolds (57 ));
161170}
162171
@@ -168,7 +177,8 @@ TEST(CallTest, MustBeScalar) {
168177 ASSERT_OK_AND_ASSIGN (auto fn,
169178 CreateFunctor (returns_expr, koda_signature, {}));
170179 ASSERT_OK_AND_ASSIGN (fn, fn.Reshape (DataSlice::JaggedShape::FlatFromSize (1 )));
171- EXPECT_THAT (CallFunctorWithCompilationCache (fn, {}, {}),
180+ EXPECT_THAT (CallFunctorWithCompilationCache (fn, /* args=*/ {}, /* kwargs=*/ {},
181+ /* eval_options=*/ {}),
172182 StatusIs (absl::StatusCode::kInvalidArgument ,
173183 " the first argument of kd.call must be a functor" ));
174184}
@@ -180,9 +190,11 @@ TEST(CallTest, NoBag) {
180190 ASSERT_OK_AND_ASSIGN (auto returns_expr, WrapExpr (arolla::expr::Literal (57 )));
181191 ASSERT_OK_AND_ASSIGN (auto fn,
182192 CreateFunctor (returns_expr, koda_signature, {}));
183- EXPECT_THAT (CallFunctorWithCompilationCache (fn.WithBag (nullptr ), {}, {}),
184- StatusIs (absl::StatusCode::kInvalidArgument ,
185- " the first argument of kd.call must be a functor" ));
193+ EXPECT_THAT (
194+ CallFunctorWithCompilationCache (fn.WithBag (nullptr ), /* args=*/ {},
195+ /* kwargs=*/ {}, /* eval_options=*/ {}),
196+ StatusIs (absl::StatusCode::kInvalidArgument ,
197+ " the first argument of kd.call must be a functor" ));
186198}
187199
188200TEST (CallTest, DataSliceVariable) {
@@ -195,8 +207,9 @@ TEST(CallTest, DataSliceVariable) {
195207 internal::DataItem (schema::kInt32 )));
196208 ASSERT_OK_AND_ASSIGN (
197209 auto fn, CreateFunctor (returns_expr, koda_signature, {{" a" , var_a}}));
198- ASSERT_OK_AND_ASSIGN (auto result,
199- CallFunctorWithCompilationCache (fn, {}, {}));
210+ ASSERT_OK_AND_ASSIGN (auto result, CallFunctorWithCompilationCache (
211+ fn, /* args=*/ {},
212+ /* kwargs=*/ {}, /* eval_options=*/ {}));
200213 EXPECT_THAT (result.As <DataSlice>(),
201214 IsOkAndHolds (IsEquivalentTo (var_a.WithBag (fn.GetBag ()))));
202215}
@@ -224,18 +237,66 @@ TEST(CallTest, EvalError) {
224237 // It is OK to only improve this on the Python side, the C++ error is not
225238 // so important.
226239 EXPECT_THAT (
227- CallFunctorWithCompilationCache (fn,
228- {
229- arolla::TypedRef::FromValue (input),
230- },
231- {}),
240+ CallFunctorWithCompilationCache (
241+ fn,
242+ /* args=*/ {arolla::TypedRef::FromValue (input)},
243+ /* kwargs=*/ {}, /* eval_options=*/ {}),
232244 StatusIs (absl::StatusCode::kInvalidArgument ,
233245 " expected numerics, got x: DATA_SLICE; while calling math.add "
234246 " with args {annotation.qtype(L['I.a'], DATA_SLICE), 57}; while "
235247 " transforming M.math.add(L['I.a'], 57); while compiling the "
236248 " expression" ));
237249}
238250
251+ TEST (CallTest, Cancellation) {
252+ ASSERT_OK_AND_ASSIGN (
253+ auto signature,
254+ Signature::Create ({
255+ {.name = " a" ,
256+ .kind = Signature::Parameter::Kind::kPositionalOrKeyword },
257+ }));
258+ ASSERT_OK_AND_ASSIGN (auto koda_signature,
259+ CppSignatureToKodaSignature (signature));
260+ // returns_expr = I.x + I.x + I.x
261+ ASSERT_OK_AND_ASSIGN (
262+ auto returns_expr,
263+ WrapExpr (arolla::expr::CallOp (
264+ " math.add" , {arolla::expr::CallOp (
265+ " math.add" , {CreateInput (" a" ), CreateInput (" a" )}),
266+ CreateInput (" a" )})));
267+ ASSERT_OK_AND_ASSIGN (auto fn,
268+ CreateFunctor (returns_expr, koda_signature, {}));
269+
270+ {
271+ int op_count = 2 ; // Stop after the second operator.
272+ auto cancel_ctx = arolla::CancellationContext::Make (
273+ /* no cooldown*/ absl::Nanoseconds (-1 ),
274+ /* no countdown*/ -1 , [&op_count] {
275+ return --op_count > 0 ? absl::OkStatus () : absl::CancelledError (" " );
276+ });
277+ expr::EvalOptions eval_options{.cancellation_context = cancel_ctx.get ()};
278+ EXPECT_THAT (CallFunctorWithCompilationCache (
279+ fn, /* args=*/ {arolla::TypedRef::FromValue (1 )},
280+ /* kwargs=*/ {}, eval_options),
281+ StatusIs (absl::StatusCode::kCancelled ));
282+ }
283+ {
284+ int op_count = 3 ; // Should stop after the third operator;
285+ // however, there are only two operators.
286+ auto cancel_ctx = arolla::CancellationContext::Make (
287+ /* no cooldown*/ absl::Nanoseconds (-1 ),
288+ /* no countdown*/ -1 , [&op_count] {
289+ return --op_count > 0 ? absl::OkStatus () : absl::CancelledError (" " );
290+ });
291+ expr::EvalOptions eval_options{.cancellation_context = cancel_ctx.get ()};
292+ ASSERT_OK_AND_ASSIGN (auto result,
293+ CallFunctorWithCompilationCache (
294+ fn, /* args=*/ {arolla::TypedRef::FromValue (1 )},
295+ /* kwargs=*/ {}, eval_options));
296+ EXPECT_THAT (result.As <int >(), IsOkAndHolds (3 ));
297+ }
298+ }
299+
239300} // namespace
240301
241302} // namespace koladata::functor
0 commit comments