Skip to content

Commit 2d38cd3

Browse files
apronchenkovcopybara-github
authored andcommitted
Change BindArguments and CallFunctorWithCompilationCache to use args and knames
This change simplifies the `kd.functor.map` implementation (and removes the need to copy argument names in `kd.functor.call`). PiperOrigin-RevId: 720559898 Change-Id: Ie8a49c1cd66681828180514b89d5f8c12bb3bd89
1 parent 5b3f421 commit 2d38cd3

File tree

9 files changed

+123
-123
lines changed

9 files changed

+123
-123
lines changed

koladata/functor/call.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ absl::StatusOr<std::vector<std::string>> GetVariableEvaluationOrder(
107107

108108
absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
109109
const DataSlice& functor, absl::Span<const arolla::TypedRef> args,
110-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs,
110+
absl::Span<const std::string> kwnames,
111111
const expr::EvalOptions& eval_options) {
112112
ASSIGN_OR_RETURN(bool is_functor, IsFunctor(functor));
113113
if (!is_functor) {
@@ -117,7 +117,7 @@ absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
117117
ASSIGN_OR_RETURN(auto signature_item, functor.GetAttr(kSignatureAttrName));
118118
ASSIGN_OR_RETURN(auto signature, KodaSignatureToCppSignature(signature_item));
119119
ASSIGN_OR_RETURN(auto bound_arguments,
120-
BindArguments(signature, args, kwargs));
120+
BindArguments(signature, args, kwnames));
121121
ASSIGN_OR_RETURN(auto variable_evaluation_order,
122122
GetVariableEvaluationOrder(functor));
123123
if (variable_evaluation_order.empty() ||

koladata/functor/call.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,22 @@ namespace koladata::functor {
3030
// Calls the given functor with the provided arguments and keyword arguments.
3131
// The functor would typically be created by the CreateFunctor method,
3232
// and consists of a returns expression, a signature, and a set of variables.
33-
// The passed args and kwargs will be bound to the parameters of the signature
34-
// to produce the values for the inputs (I.foo) in the provided expression and
35-
// in the variable expressions. The expressions can also refer to the variables
36-
// via V.foo, in which case the variable expression will be evaluated before
37-
// evaluating the expression that refers to it. In case of a cycle in variables,
38-
// an error will be returned.
3933
//
40-
// The `eval_options` parameter provides a default buffer factory
41-
// (typically either the default allocator or an arena allocator)
42-
// and a cancellation checker, which allows the computation to be
43-
// interrupted midway if needed.
34+
// `args` must contain values for positional arguments followed by values for
35+
// keyword arguments; `kwnames` must contain the names of the keyword arguments,
36+
// so `kwnames` corresponds to a suffix of `args`. The passed arguments will be
37+
// bound to the parameters of the signature to produce the values for the inputs
38+
// (I.foo) in the provided expression and in the variable expressions. The
39+
// expressions can also refer to the variables via V.foo, in which case the
40+
// variable expression will be evaluated before evaluating the expression that
41+
// refers to it. In case of a cycle in variables, an error will be returned.
42+
//
43+
// `eval_options` parameter provides a default buffer factory (typically either
44+
// the default allocator or an arena allocator) and a cancellation context,
45+
// which allows the computation to be interrupted midway if needed.
4446
absl::StatusOr<arolla::TypedValue> CallFunctorWithCompilationCache(
4547
const DataSlice& functor, absl::Span<const arolla::TypedRef> args,
46-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs,
48+
absl::Span<const std::string> kwnames,
4749
const expr::EvalOptions& eval_options);
4850

4951
} // namespace koladata::functor

koladata/functor/call_operator.cc

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//
1515
#include "koladata/functor/call_operator.h"
1616

17+
#include <cstdint>
1718
#include <memory>
1819
#include <string>
1920
#include <utility>
@@ -62,25 +63,21 @@ class CallOperator : public arolla::QExprOperator {
6263
const auto& fn_data_slice = frame.Get(fn_slot);
6364

6465
std::vector<arolla::TypedRef> arg_refs;
65-
arg_refs.reserve(args_slot.SubSlotCount());
66-
for (int i = 0; i < args_slot.SubSlotCount(); ++i) {
66+
arg_refs.reserve(args_slot.SubSlotCount() +
67+
kwargs_slot.SubSlotCount());
68+
for (int64_t i = 0; i < args_slot.SubSlotCount(); ++i) {
6769
arg_refs.push_back(
6870
arolla::TypedRef::FromSlot(args_slot.SubSlot(i), frame));
6971
}
70-
71-
auto kwargs_qtype = kwargs_slot.GetType();
72-
auto kwarg_names = arolla::GetFieldNames(kwargs_qtype);
73-
std::vector<std::pair<std::string, arolla::TypedRef>> kwarg_refs;
74-
kwarg_refs.reserve(kwargs_slot.SubSlotCount());
75-
for (int i = 0; i < kwargs_slot.SubSlotCount(); ++i) {
76-
kwarg_refs.push_back(
77-
{kwarg_names[i],
78-
arolla::TypedRef::FromSlot(kwargs_slot.SubSlot(i), frame)});
72+
for (int64_t i = 0; i < kwargs_slot.SubSlotCount(); ++i) {
73+
arg_refs.push_back(
74+
arolla::TypedRef::FromSlot(kwargs_slot.SubSlot(i), frame));
7975
}
76+
auto kwnames = arolla::GetFieldNames(kwargs_slot.GetType());
8077
ASSIGN_OR_RETURN(
8178
auto result,
82-
functor::CallFunctorWithCompilationCache(
83-
fn_data_slice, arg_refs, kwarg_refs, ctx->options()),
79+
functor::CallFunctorWithCompilationCache(fn_data_slice, arg_refs,
80+
kwnames, ctx->options()),
8481
ctx->set_status(std::move(_)));
8582
if (result.GetType() != output_slot.GetType()) {
8683
ctx->set_status(absl::InvalidArgumentError(absl::StrFormat(

koladata/functor/call_test.cc

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -116,26 +116,20 @@ TEST(CallTest, VariableRhombus) {
116116
arolla::TypedValue::FromValue(3),
117117
arolla::TypedValue::FromValue(4),
118118
};
119-
ASSERT_OK_AND_ASSIGN(auto result, CallFunctorWithCompilationCache(
120-
fn,
121-
/*args=*/
122-
{
123-
inputs[0].AsRef(),
124-
inputs[1].AsRef(),
125-
inputs[2].AsRef(),
126-
},
127-
/*kwargs=*/{}, /*eval_options=*/{}));
119+
ASSERT_OK_AND_ASSIGN(
120+
auto result,
121+
CallFunctorWithCompilationCache(
122+
fn,
123+
/*args=*/{inputs[0].AsRef(), inputs[1].AsRef(), inputs[2].AsRef()},
124+
/*kwnames=*/{}, /*eval_options=*/{}));
128125
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(2 * ((3 + 5) + (3 + 4))));
129126

130127
ASSERT_OK_AND_ASSIGN(
131128
result,
132129
CallFunctorWithCompilationCache(
133130
fn,
134-
/*args=*/
135-
{
136-
inputs[0].AsRef(),
137-
},
138-
/*kwargs=*/{{"c", inputs[1].AsRef()}, {"b", inputs[2].AsRef()}},
131+
/*args=*/{inputs[0].AsRef(), inputs[1].AsRef(), inputs[2].AsRef()},
132+
/*kwnames=*/{"c", "b"},
139133
/*eval_options=*/{}));
140134
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(2 * ((4 + 5) + (4 + 3))));
141135
}
@@ -150,7 +144,7 @@ TEST(CallTest, VariableCycle) {
150144
ASSERT_OK_AND_ASSIGN(auto fn,
151145
CreateFunctor(returns_expr, koda_signature,
152146
{{"a", var_a_expr}, {"b", var_b_expr}}));
153-
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwargs=*/{},
147+
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwnames=*/{},
154148
/*eval_options=*/{}),
155149
StatusIs(absl::StatusCode::kInvalidArgument,
156150
"variable [a] has a dependency cycle"));
@@ -165,7 +159,7 @@ TEST(CallTest, JustLiteral) {
165159
CreateFunctor(returns_expr, koda_signature, {}));
166160
ASSERT_OK_AND_ASSIGN(
167161
auto result, CallFunctorWithCompilationCache(
168-
fn, /*args=*/{}, /*kwargs=*/{}, /*eval_options=*/{}));
162+
fn, /*args=*/{}, /*kwnames=*/{}, /*eval_options=*/{}));
169163
EXPECT_THAT(result.As<int32_t>(), IsOkAndHolds(57));
170164
}
171165

@@ -177,7 +171,7 @@ TEST(CallTest, MustBeScalar) {
177171
ASSERT_OK_AND_ASSIGN(auto fn,
178172
CreateFunctor(returns_expr, koda_signature, {}));
179173
ASSERT_OK_AND_ASSIGN(fn, fn.Reshape(DataSlice::JaggedShape::FlatFromSize(1)));
180-
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwargs=*/{},
174+
EXPECT_THAT(CallFunctorWithCompilationCache(fn, /*args=*/{}, /*kwnames=*/{},
181175
/*eval_options=*/{}),
182176
StatusIs(absl::StatusCode::kInvalidArgument,
183177
"the first argument of kd.call must be a functor"));
@@ -192,7 +186,7 @@ TEST(CallTest, NoBag) {
192186
CreateFunctor(returns_expr, koda_signature, {}));
193187
EXPECT_THAT(
194188
CallFunctorWithCompilationCache(fn.WithBag(nullptr), /*args=*/{},
195-
/*kwargs=*/{}, /*eval_options=*/{}),
189+
/*kwnames=*/{}, /*eval_options=*/{}),
196190
StatusIs(absl::StatusCode::kInvalidArgument,
197191
"the first argument of kd.call must be a functor"));
198192
}
@@ -209,7 +203,7 @@ TEST(CallTest, DataSliceVariable) {
209203
auto fn, CreateFunctor(returns_expr, koda_signature, {{"a", var_a}}));
210204
ASSERT_OK_AND_ASSIGN(auto result, CallFunctorWithCompilationCache(
211205
fn, /*args=*/{},
212-
/*kwargs=*/{}, /*eval_options=*/{}));
206+
/*kwnames=*/{}, /*eval_options=*/{}));
213207
EXPECT_THAT(result.As<DataSlice>(),
214208
IsOkAndHolds(IsEquivalentTo(var_a.WithBag(fn.GetBag()))));
215209
}
@@ -240,7 +234,7 @@ TEST(CallTest, EvalError) {
240234
CallFunctorWithCompilationCache(
241235
fn,
242236
/*args=*/{arolla::TypedRef::FromValue(input)},
243-
/*kwargs=*/{}, /*eval_options=*/{}),
237+
/*kwnames=*/{}, /*eval_options=*/{}),
244238
StatusIs(absl::StatusCode::kInvalidArgument,
245239
"expected numerics, got x: DATA_SLICE; while calling math.add "
246240
"with args {annotation.qtype(L['I.a'], DATA_SLICE), 57}; while "
@@ -277,7 +271,7 @@ TEST(CallTest, Cancellation) {
277271
expr::EvalOptions eval_options{.cancellation_context = cancel_ctx.get()};
278272
EXPECT_THAT(CallFunctorWithCompilationCache(
279273
fn, /*args=*/{arolla::TypedRef::FromValue(1)},
280-
/*kwargs=*/{}, eval_options),
274+
/*kwnames=*/{}, eval_options),
281275
StatusIs(absl::StatusCode::kCancelled));
282276
}
283277
{
@@ -292,7 +286,7 @@ TEST(CallTest, Cancellation) {
292286
ASSERT_OK_AND_ASSIGN(auto result,
293287
CallFunctorWithCompilationCache(
294288
fn, /*args=*/{arolla::TypedRef::FromValue(1)},
295-
/*kwargs=*/{}, eval_options));
289+
/*kwnames=*/{}, eval_options));
296290
EXPECT_THAT(result.As<int>(), IsOkAndHolds(3));
297291
}
298292
}

koladata/functor/map.cc

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,49 +48,35 @@ absl::StatusOr<DataSlice> MapFunctorWithCompilationCache(
4848
ASSIGN_OR_RETURN(auto aligned_args, shape::Align(std::move(args)));
4949
DataSlice aligned_functors = std::move(aligned_args.back());
5050
aligned_args.pop_back();
51-
int64_t num_positional = aligned_args.size() - kwnames.size();
5251
// Pre-allocate the vectors to avoid reallocations in the loop.
5352
std::vector<arolla::TypedRef> arg_refs;
54-
arg_refs.reserve(num_positional);
55-
std::vector<std::pair<std::string, arolla::TypedRef>> kwarg_refs;
56-
kwarg_refs.reserve(kwnames.size());
53+
arg_refs.reserve(aligned_args.size());
5754
std::vector<DataSlice> result_slices;
5855
ASSIGN_OR_RETURN(DataSlice missing,
5956
DataSlice::Create(internal::DataItem(std::nullopt),
6057
internal::DataItem(schema::kNone)));
6158

62-
auto call_on_items = [&kwnames, &missing, &arg_refs, &kwarg_refs,
63-
&num_positional, &include_missing,
59+
auto call_on_items = [&kwnames, &missing, &arg_refs, &include_missing,
6460
&eval_options](const DataSlice& functor,
6561
const std::vector<DataSlice>& arg_slices)
6662
-> absl::StatusOr<DataSlice> {
6763
if (!functor.item().has_value()) {
6864
return missing;
6965
}
7066
arg_refs.clear();
71-
for (int64_t i = 0; i < num_positional; ++i) {
72-
const auto& value = arg_slices[i];
67+
for (const auto& value : arg_slices) {
7368
DCHECK(value.is_item());
7469
if (!include_missing && !value.item().has_value()) {
7570
return missing;
7671
}
7772
arg_refs.push_back(arolla::TypedRef::FromValue(value));
7873
}
79-
kwarg_refs.clear();
80-
for (int64_t i = 0; i < kwnames.size(); ++i) {
81-
const auto& value = arg_slices[num_positional + i];
82-
DCHECK(value.is_item());
83-
if (!include_missing && !value.item().has_value()) {
84-
return missing;
85-
}
86-
kwarg_refs.emplace_back(kwnames[i], arolla::TypedRef::FromValue(value));
87-
}
8874
// We can improve the performance a lot if "functor" is the same for many
8975
// items, as a lot of the work inside CallFunctorWithCompilationCache could
9076
// be reused between items then.
9177
ASSIGN_OR_RETURN(auto result,
92-
CallFunctorWithCompilationCache(functor, arg_refs,
93-
kwarg_refs, eval_options));
78+
CallFunctorWithCompilationCache(functor, arg_refs, kwnames,
79+
eval_options));
9480
if (result.GetType() != arolla::GetQType<DataSlice>()) {
9581
return absl::InvalidArgumentError(absl::StrFormat(
9682
"the functor is expected to be evaluated to a DataItem"

koladata/functor/map.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@ namespace koladata::functor {
2828
// Calls the given functors pointwise on items from the provided arguments and
2929
// keyword arguments, expects them to return DataItems, and stacks
3030
// those into a single DataSlice.
31-
// `args` must contain first all positional arguments, then all keyword
32-
// arguments, so `kwnames` corresponds to a suffix of `args`.
31+
//
32+
// `args` must contain values for positional arguments followed by values for
33+
// keyword arguments; `kwnames` must contain the names of the keyword arguments,
34+
// so `kwnames` corresponds to a suffix of `args`.
35+
//
3336
// `include_missing` controls whether to call the functors on missing items of
3437
// `args` and `kwargs`.
3538
//
36-
// The `eval_options` parameter provides a default buffer factory
37-
// (typically either the default allocator or an arena allocator)
38-
// and a cancellation checker, which allows the computation to be
39-
// interrupted midway if needed.
39+
// `eval_options` parameter provides a default buffer factory (typically either
40+
// the default allocator or an arena allocator) and a cancellation context,
41+
// which allows the computation to be interrupted midway if needed.
4042
absl::StatusOr<DataSlice> MapFunctorWithCompilationCache(
4143
const DataSlice& functors, std::vector<DataSlice> args,
4244
absl::Span<const std::string> kwnames, bool include_missing,

koladata/functor/signature.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ Signature::Signature(absl::Span<const Parameter> parameters)
107107

108108
absl::StatusOr<std::vector<arolla::TypedValue>> BindArguments(
109109
const Signature& signature, absl::Span<const arolla::TypedRef> args,
110-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs) {
110+
absl::Span<const std::string> kwnames) {
111+
if (args.size() < kwnames.size()) {
112+
return absl::InvalidArgumentError("args.size < kwnames.size()");
113+
}
114+
const size_t kwargs_offset = args.size() - kwnames.size();
115+
111116
const auto& parameters = signature.parameters();
112117
const auto& keyword_parameter_index = signature.keyword_parameter_index();
113118
std::vector<arolla::TypedValue> bound_arguments(
@@ -119,7 +124,7 @@ absl::StatusOr<std::vector<arolla::TypedValue>> BindArguments(
119124
std::vector<arolla::TypedRef> unknown_kwarg_values;
120125

121126
// Process positional arguments.
122-
for (size_t i = 0; i < args.size(); ++i) {
127+
for (size_t i = 0; i < kwargs_offset; ++i) {
123128
if (i >= parameters.size() ||
124129
(parameters[i].kind != Signature::Parameter::Kind::kPositionalOnly &&
125130
parameters[i].kind !=
@@ -131,15 +136,17 @@ absl::StatusOr<std::vector<arolla::TypedValue>> BindArguments(
131136
}
132137

133138
// Process keyword arguments.
134-
for (const auto& [name, value] : kwargs) {
139+
for (size_t i = kwargs_offset; i < args.size(); ++i) {
140+
const auto& name = kwnames[i - kwargs_offset];
141+
const auto& value = args[i];
135142
auto it = keyword_parameter_index.find(name);
136143
if (it == keyword_parameter_index.end()) {
137144
unknown_kwarg_names.push_back(name);
138145
unknown_kwarg_values.push_back(value);
139146
} else {
140147
if (bound_arguments[it->second].GetType() != arolla::GetNothingQType()) {
141-
return absl::InvalidArgumentError(absl::StrFormat(
142-
"parameter [%s] specified twice", parameters[it->second].name));
148+
return absl::InvalidArgumentError(
149+
absl::StrFormat("parameter [%s] specified twice", name));
143150
}
144151
bound_arguments[it->second] = arolla::TypedValue(value);
145152
}

koladata/functor/signature.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <cstddef>
1919
#include <optional>
2020
#include <string>
21-
#include <utility>
2221
#include <vector>
2322

2423
#include "absl/container/flat_hash_map.h"
@@ -96,7 +95,7 @@ class Signature {
9695
// must own it.
9796
absl::StatusOr<std::vector<arolla::TypedValue>> BindArguments(
9897
const Signature& signature, absl::Span<const arolla::TypedRef> args,
99-
absl::Span<const std::pair<std::string, arolla::TypedRef>> kwargs);
98+
absl::Span<const std::string> kwnames);
10099

101100
// Makes it possible to use absl::StrCat/absl::StrFormat (with %v) with kinds.
102101
template <typename Sink>

0 commit comments

Comments
 (0)