Commit e12b4ff

Krishna Pai authored and facebook-github-bot committed
Add support for multiarg lambdas in array_sort (facebookincubator#13387)
Summary:
Pull Request resolved: facebookincubator#13387

This change adds support for array_sort with a lambda that takes multiple arguments. Currently Velox tries to convert any multi-argument lambda to a single-argument one using SimpleMatcher. This works for the large majority of cases and is fast, but a significant number of cases still require true multi-argument lambda support. This preliminary PR supports that use case.

We do this by evaluating the lambda over all the argument vectors in each iteration. In every iteration we keep one of the arguments constant for an array and compare it to every element of that array. This gives us the position of the constant element relative to the other elements in the array. Doing this N times sorts the array.

To illustrate, say the lambda takes two arguments, x and y. We map y to the elements of the array A and assign x = A[i] in iteration i. Thus in the ith iteration we compare x (that is, A[i]) against every y (i.e. every element in A). At the end of each iteration, the number of -1, 0, and 1 results tells us the position of the element at index i.

NOTE:
1. This implementation requires that the lambda always return -1, 0, or 1. This is enforced by Presto Java too.
2. This implementation has the further restriction that, for the lambda f, f(a, a) returns 0, i.e. comparing an element with itself yields 0.
3. This implementation first tries to rewrite the lambda to the single-argument case and uses the multi-argument lambda only if that fails.
4. This implementation supports only one lambda entry, not multiple lambda entries.

Differential Revision: D74852355
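The counting idea described in the summary can be sketched on a plain `std::vector` as follows. This is only an illustration under stated assumptions, not the Velox implementation: `rankSort` is a hypothetical helper, it works on scalars rather than vectors of arrays, and breaking ties between equal elements by original index is a choice made here that the commit message does not specify.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not Velox code): sorts `a` by computing, for each
// element, how many other elements must precede it.
// Assumes cmp returns only -1, 0, or 1 and that cmp(v, v) == 0.
std::vector<int> rankSort(
    const std::vector<int>& a,
    const std::function<int(int, int)>& cmp) {
  std::vector<int> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Iteration i: x is held constant at a[i], y ranges over every element.
    std::size_t pos = 0;
    for (std::size_t j = 0; j < a.size(); ++j) {
      const int r = cmp(a[i], a[j]);
      if (r < -1 || r > 1) {
        throw std::invalid_argument("comparator must return -1, 0, or 1");
      }
      // A result of 1 means a[j] sorts before a[i]. Equal elements (0) are
      // ordered by original index; this tie-break is an assumption.
      if (r == 1 || (r == 0 && j < i)) {
        ++pos;
      }
    }
    assert(pos < out.size());
    out[pos] = a[i];
  }
  return out;
}
```

Each element needs one pass over the whole array, so the sketch performs N * N lambda evaluations for an array of N elements, which suggests why a cap on array size is useful for this path.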
1 parent 4a55e54 commit e12b4ff

6 files changed, +415 -46 lines changed

velox/core/QueryConfig.h

Lines changed: 11 additions & 0 deletions
@@ -598,6 +598,13 @@ class QueryConfig {
   static constexpr const char* kFieldNamesInJsonCastEnabled =
       "field_names_in_json_cast_enabled";
 
+  /// Represents the maximum size of the array that can be sorted by the
+  /// array_sort() function, if the lambda takes more than one argument
+  /// and cannot be reduced by the SimpleMatcher. If the value is 0, then
+  /// this is disabled.
+  static constexpr const char* kArraySortMaxIterations =
+      "array_sort_max_iterations";
+
   bool selectiveNimbleReaderEnabled() const {
     return get<bool>(kSelectiveNimbleReaderEnabled, false);
   }
@@ -1089,6 +1096,10 @@ class QueryConfig {
     return get<bool>(kFieldNamesInJsonCastEnabled, false);
   }
 
+  uint32_t getArraySortMaxIterations() const {
+    return get<uint32_t>(kArraySortMaxIterations, 1000);
+  }
+
   template <typename T>
   T get(const std::string& key, const T& defaultValue) const {
     return config_->get<T>(key, defaultValue);
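
The getter above follows a read-with-default pattern: the key is looked up and 1000 is used when it is absent. A minimal standalone sketch of that pattern is shown here; `MiniConfig` and its string-backed map are hypothetical stand-ins for illustration and do not reflect how Velox's `QueryConfig` stores values.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a string-backed config; not the Velox class.
class MiniConfig {
 public:
  explicit MiniConfig(std::unordered_map<std::string, std::string> values)
      : values_(std::move(values)) {}

  // Returns the parsed value for `key`, or `defaultValue` if the key is unset.
  uint32_t getUint32(const std::string& key, uint32_t defaultValue) const {
    const auto it = values_.find(key);
    if (it == values_.end()) {
      return defaultValue;
    }
    return static_cast<uint32_t>(std::stoul(it->second));
  }

  // Mirrors the key name and default of 1000 shown in the diff.
  uint32_t arraySortMaxIterations() const {
    return getUint32("array_sort_max_iterations", 1000);
  }

 private:
  std::unordered_map<std::string, std::string> values_;
};
```

With no entry set, `arraySortMaxIterations()` returns the default; setting the key to "0" yields 0, which the doc comment in the diff describes as disabling the feature.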

velox/expression/LambdaExpr.cpp

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@ class ExprCallable : public Callable {
     lambdaCtx.swapErrors(elementErrors);
   }
 
+  const RowTypePtr getFunctionSignatures() override {
+    return signature_;
+  }
+
  private:
   void resetSharedExprs() {
     for (auto& expr : sharedExprsToReset_) {

velox/expression/tests/ExprCompilerTest.cpp

Lines changed: 5 additions & 0 deletions
@@ -309,6 +309,11 @@ TEST_F(ExprCompilerTest, rewrites) {
       exprSet->expr(0)->toString(),
       "plus(1:BIGINT, array_sum_propagate_element_null(transform(c0, (x) -> multiply(x, 2:BIGINT))))");
 
+  exprSet = compile(makeTypedExpr(
+      "ARRAY_SORT(c0, (a, b) -> IF(POW(a, 3) < POW(b, 3), -1, IF(a = b, 0, 1)))",
+      ROW({"c0", "c1"}, {ARRAY(BIGINT()), BIGINT()})));
+  ASSERT_EQ(exprSet->size(), 1);
+
   exprSet = compile(makeTypedExpr(
       "reduce(c0, 1, (s, x) -> (s + 2) - x, s -> s)",
       ROW({"c0"}, {ARRAY(BIGINT())})));
