Skip to content

Commit e744225

Browse files
timofey-stepanovcopybara-github
authored andcommitted
Enable parallel transform of kd.switch.
Also minor fixes in repr. PiperOrigin-RevId: 869902121 Change-Id: I2aebeeab3a330bc17ceef54b69d86ca6c9ea2e6f
1 parent a627648 commit e744225

13 files changed

+293
-52
lines changed

docs/api_reference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5338,9 +5338,9 @@ parallel version (DataSlice -> future[DataSlice]), this is work in progress.
53385338

53395339
Args:
53405340
fn: The functor to transform.
5341-
allow_runtime_transforms: Whether to allow sub-functors to be not literals,
5342-
but computed expressions, which will therefore have to be transformed at
5343-
runtime. This can be slow.
5341+
allow_runtime_transforms: Whether to allow sub-functors to be not fully
5342+
defined at transform time (i.e. to depend on functor inputs), which will
5343+
therefore have to be transformed at runtime. This can be slow.
53445344

53455345
Returns:
53465346
The transformed functor.</code></pre>

koladata/functor/parallel/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ cc_library(
8585
"//koladata/expr:non_determinism",
8686
"//koladata/functor",
8787
"//koladata/functor:auto_variables",
88+
"//koladata/functor:call",
8889
"//koladata/functor:signature_utils",
8990
"//koladata/internal:data_item",
9091
"//koladata/internal:dtype",

koladata/functor/parallel/transform.cc

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,16 @@
6161
#include "koladata/expr/expr_operators.h"
6262
#include "koladata/expr/non_determinism.h"
6363
#include "koladata/functor/auto_variables.h"
64+
#include "koladata/functor/call.h"
6465
#include "koladata/functor/functor.h"
65-
#include "koladata/functor/parallel/transform_config.pb.h"
6666
#include "koladata/functor/parallel/transform_config.h"
67+
#include "koladata/functor/parallel/transform_config.pb.h"
6768
#include "koladata/functor/signature_utils.h"
6869
#include "koladata/functor_storage.h"
6970
#include "koladata/internal/data_item.h"
7071
#include "koladata/internal/dtype.h"
7172
#include "koladata/object_factories.h"
73+
#include "koladata/operators/core.h"
7274
#include "koladata/operators/slices.h"
7375
#include "koladata/signature.h"
7476
#include "koladata/signature_storage.h"
@@ -278,24 +280,15 @@ class InnerTransformManager {
278280
}
279281
used_var_names_.insert(result.var_name);
280282
ASSIGN_OR_RETURN(DataSlice var, functor_.GetAttr(var_name));
281-
if (!var.item().holds_value<arolla::expr::ExprQuote>()) {
282-
ASSIGN_OR_RETURN(
283-
auto transformed_var,
284-
TransformEager(arolla::TypedValue::FromValue(std::move(var))));
285-
ASSIGN_OR_RETURN(auto transformed_var_slice,
286-
std::move(transformed_var).As<DataSlice>());
287-
new_vars_.emplace_back(result.var_name, std::move(transformed_var_slice));
288-
results_.emplace(var_name, result);
289-
return result;
283+
std::optional<arolla::TypedValue> var_qvalue;
284+
if (var.item().holds_value<arolla::expr::ExprQuote>()) {
285+
ASSIGN_OR_RETURN(var_qvalue, TryComputeLiteral(var));
286+
} else {
287+
var_qvalue = arolla::TypedValue::FromValue(std::move(var));
290288
}
291-
ASSIGN_OR_RETURN(auto var_expr,
292-
var.item().value<arolla::expr::ExprQuote>().expr());
293-
if (expr::IsLiteral(var_expr)) {
294-
if (!var_expr->qvalue()) {
295-
return absl::InternalError("Literal does not have a value");
296-
}
289+
if (var_qvalue.has_value()) {
297290
ASSIGN_OR_RETURN(auto transformed_var,
298-
TransformEager(*var_expr->qvalue()));
291+
TransformEager(*std::move(var_qvalue)));
299292
auto new_var_expr = arolla::expr::Literal(std::move(transformed_var));
300293
auto new_var = DataSlice::CreatePrimitive(
301294
arolla::expr::ExprQuote{std::move(new_var_expr)});
@@ -305,18 +298,17 @@ class InnerTransformManager {
305298
}
306299
if (!config_->allow_runtime_transforms()) {
307300
return absl::InvalidArgumentError(absl::StrCat(
308-
"The parallel transformation requires that all sub-functors being"
309-
" called are specified as literals, instead of being computed"
310-
" dynamically, so that we can transform them recursively. In case you"
311-
" are calling a sub-functor that is computed dynamically, but do not"
312-
" need to recursively transform it (evaluating this sub-functor"
313-
" single-threaded is fine), you can use"
314-
" kd.functor.call_fn_normally_when_parallel and its variants to call"
315-
" the sub-functor. In case you do need to evaluate a dynamically"
316-
" computed sub-functor in a parallel fashion, you can pass"
317-
" allow_runtime_transforms=True to kd.parallel.transform, but this"
318-
" will be slower and should not be used in production."
319-
"\nThe offending sub-functor ",
301+
"The parallel transformation requires that all sub-functors being "
302+
"called are fully defined at transformation time, so that we can "
303+
"transform them recursively. In case you are calling a sub-functor "
304+
"that is computed dynamically, but do not need to recursively "
305+
"transform it (evaluating this sub-functor single-threaded is fine), "
306+
"you can use kd.functor.call_fn_normally_when_parallel and its "
307+
"variants to call the sub-functor. In case you do need to evaluate a "
308+
"dynamically computed sub-functor in a parallel fashion, you can "
309+
"pass allow_runtime_transforms=True to kd.parallel.transform, but "
310+
"this will be slower and should not be used in production.\nThe "
311+
"offending sub-functor ",
320312
var_name, ": ", arolla::Repr(var),
321313
" . The whole functor: ", arolla::Repr(functor_)));
322314
}
@@ -355,6 +347,38 @@ class InnerTransformManager {
355347
std::move(sub_functor)});
356348
}
357349

350+
// If expr can be literal folded (with respect of its dependencies in
351+
// functor_), returns its value. Otherwise returns std::nullopt.
352+
absl::StatusOr<std::optional<arolla::TypedValue>> TryComputeLiteral(
353+
const DataSlice& expr) const {
354+
// TODO: b/477578091 - we are changing the functor in order to reuse
355+
// CallFunctorWithCompilationCache. It will be more efficient if we just
356+
// evaluate all the variables needed for `expr` explicitly.
357+
ASSIGN_OR_RETURN(
358+
DataBagPtr returns_db,
359+
ops::Attr(functor_,
360+
DataSlice::CreatePrimitive(arolla::Text(kReturnsAttrName)),
361+
expr,
362+
/*override_schema=*/DataSlice::CreatePrimitive(false)));
363+
ASSIGN_OR_RETURN(
364+
DataBagPtr signature_db,
365+
ops::Attr(functor_,
366+
DataSlice::CreatePrimitive(arolla::Text(kSignatureAttrName)),
367+
KodaEmptySignature(),
368+
/*override_schema=*/DataSlice::CreatePrimitive(false)));
369+
ASSIGN_OR_RETURN(
370+
auto db, DataBag::ImmutableEmptyWithFallbacks({std::move(returns_db),
371+
std::move(signature_db),
372+
functor_.GetBag()}));
373+
auto new_functor = functor_.WithBag(std::move(db));
374+
ASSIGN_OR_RETURN(
375+
auto res, functor::CallFunctorWithCompilationCache(new_functor, {}, {}),
376+
// We consider all functor evaluation errors as caused by being not
377+
// actually a literal.
378+
std::nullopt);
379+
return res;
380+
}
381+
358382
const ParallelTransformConfigPtr absl_nonnull& config_;
359383
const DataSlice& functor_;
360384
const expr::InputContainer& variable_container_;

koladata/functor/parallel/transform_config.proto

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ message ParallelTransformConfigProto {
8888

8989
repeated OperatorReplacement operator_replacements = 1;
9090

91-
// If true, the sub-functors that are not static but computed will
92-
// be transformed to the parallel form at runtime. If false, we will
93-
// raise an error in that case.
91+
// If true, the sub-functors that are not fully defined at transform time
92+
// (i.e. that depend on functor inputs) will be transformed to the parallel
93+
// form at runtime. If false, we will raise an error in that case.
9494
bool allow_runtime_transforms = 2;
9595
}

koladata/functor/signature_utils.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ const Signature& CppArgsKwargsSignature() {
280280
.value());
281281
return *val;
282282
}
283+
284+
const Signature& CppEmptySignature() {
285+
static absl::NoDestructor<Signature> val(Signature::Create({}).value());
286+
return *val;
287+
}
288+
283289
} // namespace
284290

285291
const DataSlice& KodaArgsKwargsSignature() {
@@ -288,4 +294,10 @@ const DataSlice& KodaArgsKwargsSignature() {
288294
return *val;
289295
}
290296

297+
const DataSlice& KodaEmptySignature() {
298+
static absl::NoDestructor<DataSlice> val{
299+
CppSignatureToKodaSignature(CppEmptySignature()).value()};
300+
return *val;
301+
}
302+
291303
} // namespace koladata::functor

koladata/functor/signature_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ const DataSlice& NoDefaultValueMarker();
5757
// Returns functor signature for the *args and **kwargs parameters.
5858
const DataSlice& KodaArgsKwargsSignature();
5959

60+
// Returns a functor signature with no parameters.
61+
const DataSlice& KodaEmptySignature();
6062

6163
// Converts a C++ Signature object to a Koda DataItem storing the signature.
6264
// The returned DataItem will have a new DataBag created to store the triples.

koladata/functor/signature_utils_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "koladata/internal/dtype.h"
3232
#include "koladata/internal/object_id.h"
3333
#include "koladata/signature.h"
34+
#include "koladata/signature_storage.h"
3435
#include "koladata/test_utils.h"
3536
#include "koladata/testing/matchers.h"
3637

@@ -42,6 +43,7 @@ using ::absl_testing::IsOkAndHolds;
4243
using ::absl_testing::StatusIs;
4344
using ::koladata::testing::IsEquivalentTo;
4445
using ::testing::HasSubstr;
46+
using ::testing::IsEmpty;
4547

4648
TEST(BindArgumentsTest, Basic) {
4749
Signature::Parameter p1 = {
@@ -380,6 +382,12 @@ TEST(CppSignatureToKodaSignatureTest, Basic) {
380382
EXPECT_THAT(arolla::Repr(koda_signature), HasSubstr("no_default_value"));
381383
}
382384

385+
TEST(SignatureUtilsTest, KodaEmptySignature) {
386+
ASSERT_OK_AND_ASSIGN(auto signature,
387+
KodaSignatureToCppSignature(KodaEmptySignature()));
388+
EXPECT_THAT(signature.parameters(), IsEmpty());
389+
}
390+
383391
} // namespace
384392

385393
} // namespace koladata::functor

py/koladata/functions/parallel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,14 @@ def transform(
177177
178178
Args:
179179
fn: The functor to transform.
180-
allow_runtime_transforms: Whether to allow sub-functors to be not literals,
181-
but computed expressions, which will therefore have to be transformed at
182-
runtime. This can be slow.
180+
allow_runtime_transforms: Whether to allow sub-functors to be not fully
181+
defined at transform time (i.e. to depend on functor inputs), which will
182+
therefore have to be transformed at runtime. This can be slow.
183183
184184
Returns:
185185
The transformed functor.
186186
"""
187+
187188
fn = py_boxing.as_qvalue(fn)
188189
config = koda_internal_parallel.create_transform_config(
189190
koda_internal_parallel.get_default_transform_config_src().with_attrs(

py/koladata/operators/functor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,13 @@ def _switch_repr(
373373
# unlike str(k).
374374
k_repr = repr(k.to_py()) if k != SWITCH_DEFAULT else 'kd.SWITCH_DEFAULT'
375375
case_reprs[k_repr] = str(f)
376-
cases_repr = ',\n'.join(f'{k}: {v}' for k, v in case_reprs.items())
376+
cases_repr = (
377+
'{\n'
378+
+ textwrap.indent(
379+
'\n'.join(f'{k}: {v},' for k, v in case_reprs.items()), ' '
380+
)
381+
+ '\n}'
382+
)
377383
else:
378384
cases_repr = f'kd.dict({tokens[case_keys].text}, {tokens[case_fns].text})'
379385

@@ -382,9 +388,11 @@ def _switch_repr(
382388
res = op_repr.ReprToken()
383389
res.text = (
384390
f'{node.op.display_name}(\n'
385-
f'{textwrap.indent(key_repr, " ")},\n'
386-
f'{textwrap.indent(cases_repr, " ")},\n'
387-
f'{textwrap.indent(args_kwargs, " ")})'
391+
+ textwrap.indent(
392+
',\n'.join(filter(None, (key_repr, cases_repr, args_kwargs))),
393+
' ',
394+
)
395+
+ ')'
388396
)
389397
return res
390398

py/koladata/operators/koda_internal_parallel.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,69 @@ def _parallel_if(
22462246
cond,
22472247
transformed_yes_fn,
22482248
transformed_no_fn,
2249+
# async_eval waits until all the future args are ready, but here we
2250+
# hide the futures inside a tuple in order to propagate to the
2251+
# downstream operators.
2252+
(args, return_type_as, kwargs),
2253+
optools.unified_non_deterministic_arg(),
2254+
)
2255+
)
2256+
2257+
2258+
@optools.as_lambda_operator(
2259+
'koda_internal.parallel._parallel_switch_impl',
2260+
)
2261+
def _parallel_switch_impl(
2262+
executor, key, case_keys, transformed_case_fns, parallel_args
2263+
):
2264+
"""Implementation helper for _parallel_switch."""
2265+
args = parallel_args[0]
2266+
return_type_as = parallel_args[1]
2267+
kwargs = parallel_args[2]
2268+
return arolla.abc.bind_op( # pytype: disable=wrong-arg-types
2269+
functor.switch,
2270+
key,
2271+
case_keys,
2272+
transformed_case_fns,
2273+
return_type_as=return_type_as,
2274+
args=M.core.concat_tuples(M.core.make_tuple(executor), args),
2275+
kwargs=kwargs,
2276+
non_deterministic=optools.unified_non_deterministic_arg(),
2277+
)
2278+
2279+
2280+
@optools.add_to_registry(via_cc_operator_package=True)
2281+
@optools.as_lambda_operator(
2282+
'koda_internal.parallel._parallel_switch',
2283+
qtype_constraints=[
2284+
qtype_utils.expect_executor(P.executor),
2285+
qtype_utils.expect_future(P.key),
2286+
qtype_utils.expect_data_slice(P.transformed_case_fns),
2287+
qtype_utils.expect_tuple(P.args),
2288+
qtype_utils.expect_namedtuple(P.kwargs),
2289+
],
2290+
)
2291+
def _parallel_switch(
2292+
executor,
2293+
key,
2294+
case_keys,
2295+
transformed_case_fns,
2296+
return_type_as,
2297+
args,
2298+
kwargs,
2299+
):
2300+
"""The parallel version of kd.switch."""
2301+
return unwrap_future_to_parallel(
2302+
async_eval(
2303+
executor,
2304+
_parallel_switch_impl,
2305+
executor,
2306+
key,
2307+
case_keys,
2308+
transformed_case_fns,
2309+
# async_eval waits until all the future args are ready, but here we
2310+
# hide the futures inside a tuple in order to propagate to the
2311+
# downstream operators.
22492312
(args, return_type_as, kwargs),
22502313
optools.unified_non_deterministic_arg(),
22512314
)
@@ -3323,6 +3386,15 @@ def _parallel_with_assertion(executor, x, condition, message_or_fn, args):
33233386
functor_argument_indices: 2
33243387
}
33253388
}
3389+
operator_replacements {
3390+
from_op: "kd.functor.switch"
3391+
to_op: "koda_internal.parallel._parallel_switch"
3392+
argument_transformation {
3393+
arguments: EXECUTOR
3394+
arguments: ORIGINAL_ARGUMENTS
3395+
functor_argument_indices: 2
3396+
}
3397+
}
33263398
operator_replacements {
33273399
from_op: "kd.iterables.make"
33283400
to_op: "koda_internal.parallel._parallel_stream_make"

0 commit comments

Comments
 (0)