-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix: RESPECT NULLS for Spark collect_list function #16933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -762,10 +762,12 @@ class SimpleAggregateAdapter : public Aggregate { | |
| } | ||
| } | ||
|
|
||
| protected: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Making `fn_` protected (rather than private) exposes the member directly to every subclass. Not a blocker — just a suggestion for encapsulation. |
||
| std::unique_ptr<FUNC> fn_; | ||
|
|
||
| private: | ||
| std::vector<DecodedVector> inputDecoded_; | ||
| DecodedVector intermediateDecoded_; | ||
|
|
||
| std::unique_ptr<FUNC> fn_; | ||
| }; | ||
|
|
||
| } // namespace facebook::velox::exec | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
|
|
||
| #include "velox/exec/SimpleAggregateAdapter.h" | ||
| #include "velox/functions/lib/aggregates/ValueList.h" | ||
| #include "velox/vector/ConstantVector.h" | ||
|
|
||
| using namespace facebook::velox::aggregate; | ||
| using namespace facebook::velox::exec; | ||
|
|
@@ -44,14 +45,6 @@ | |
| // aggregation uses the accumulator path, which correctly respects the config. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This comment references "config" twice, but `ignoreNulls_` is no longer read from the query config. Suggested wording:
// NOTE: toIntermediate() was intentionally removed because it is static and
// cannot access the runtime ignoreNulls_ flag. Without it, partial
// aggregation uses the accumulator path, which correctly respects the flag. |
||
| bool ignoreNulls_{true}; | ||
|
|
||
| void initialize( | ||
| core::AggregationNode::Step /*step*/, | ||
| const std::vector<TypePtr>& /*argTypes*/, | ||
| const TypePtr& /*resultType*/, | ||
| const core::QueryConfig& config) { | ||
| ignoreNulls_ = config.sparkCollectListIgnoreNulls(); | ||
| } | ||
|
|
||
| struct AccumulatorType { | ||
| ValueList elements_; | ||
|
|
||
|
|
@@ -114,16 +107,40 @@ | |
| }; | ||
| }; | ||
|
|
||
| // Adapter that overrides setConstantInputs to read the ignoreNulls flag. | ||
| class CollectListAdapter : public SimpleAggregateAdapter<CollectListAggregate> { | ||
| public: | ||
| using SimpleAggregateAdapter<CollectListAggregate>::SimpleAggregateAdapter; | ||
|
|
||
| void setConstantInputs( | ||
| const std::vector<VectorPtr>& constantInputs) override { | ||
|
Check warning on line 116 in velox/functions/sparksql/aggregates/CollectListAggregate.cpp
|
||
| if (constantInputs.size() >= 2 && constantInputs[1] != nullptr && | ||
| !constantInputs[1]->isNullAt(0)) { | ||
| fn_->ignoreNulls_ = | ||
| constantInputs[1]->as<ConstantVector<bool>>()->valueAt(0); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| AggregateRegistrationResult registerCollectList( | ||
|
Check warning on line 125 in velox/functions/sparksql/aggregates/CollectListAggregate.cpp
|
||
| const std::string& name, | ||
| bool withCompanionFunctions, | ||
| bool overwrite) { | ||
| std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{ | ||
| // collect_list(E) -> array(E): default ignoreNulls=true. | ||
| exec::AggregateFunctionSignatureBuilder() | ||
| .typeVariable("E") | ||
| .returnType("array(E)") | ||
| .intermediateType("array(E)") | ||
| .argumentType("E") | ||
| .build(), | ||
| // collect_list(E, ignoreNulls) -> array(E): explicit flag. | ||
| exec::AggregateFunctionSignatureBuilder() | ||
| .typeVariable("E") | ||
| .returnType("array(E)") | ||
| .intermediateType("array(E)") | ||
| .argumentType("E") | ||
| .constantArgumentType("boolean") | ||
| .build()}; | ||
| return exec::registerAggregateFunction( | ||
| name, | ||
|
|
@@ -133,9 +150,7 @@ | |
| const std::vector<TypePtr>& argTypes, | ||
| const TypePtr& resultType, | ||
| const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> { | ||
| VELOX_CHECK_EQ( | ||
| argTypes.size(), 1, "{} takes at most one argument", name); | ||
| return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>( | ||
| return std::make_unique<CollectListAdapter>( | ||
| step, argTypes, resultType, &config); | ||
| }, | ||
| withCompanionFunctions, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -124,24 +124,28 @@ TEST_F(CollectListAggregateTest, allNullsInput) { | |
| {}); | ||
| } | ||
|
|
||
| std::unordered_map<std::string, std::string> makeConfig(bool ignoreNulls) { | ||
| return {{"spark.collect_list.ignore_nulls", ignoreNulls ? "true" : "false"}}; | ||
| TEST_F(CollectListAggregateTest, explicitIgnoreNullsTrue) { | ||
| // 2-arg form with ignoreNulls=true should behave same as 1-arg. | ||
| auto input = makeRowVector({makeNullableFlatVector<int32_t>( | ||
| {1, 2, std::nullopt, 4, std::nullopt, 6})}); | ||
| auto expected = | ||
| makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 6]"})}); | ||
| testAggregations( | ||
| {input}, | ||
| {}, | ||
| {"spark_collect_list(c0, true)"}, | ||
| {"array_sort(a0)"}, | ||
| {expected}); | ||
| } | ||
|
|
||
| TEST_F(CollectListAggregateTest, respectNulls) { | ||
| // When ignoreNulls is false (RESPECT NULLS), nulls should be included. | ||
| // 2-arg form with ignoreNulls=false (RESPECT NULLS). | ||
| auto input = makeRowVector({makeNullableFlatVector<int32_t>( | ||
| {1, 2, std::nullopt, 4, std::nullopt, 6})}); | ||
| auto expected = makeRowVector({makeNullableArrayVector<int32_t>( | ||
| std::vector<std::vector<std::optional<int32_t>>>{ | ||
| {1, 2, std::nullopt, 4, std::nullopt, 6}})}); | ||
| std::vector<RowVectorPtr> expectedResult{expected}; | ||
| testAggregations( | ||
| {input}, | ||
| {}, | ||
| {"spark_collect_list(c0)"}, | ||
| expectedResult, | ||
| makeConfig(false)); | ||
| testAggregations({input}, {}, {"spark_collect_list(c0, false)"}, {expected}); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding a test that verifies the constant boolean |
||
| } | ||
|
|
||
| TEST_F(CollectListAggregateTest, respectNullsGroupBy) { | ||
|
|
@@ -153,30 +157,20 @@ TEST_F(CollectListAggregateTest, respectNullsGroupBy) { | |
| makeNullableArrayVector<int64_t>( | ||
| std::vector<std::vector<std::optional<int64_t>>>{ | ||
| {std::nullopt, 1}, {2, std::nullopt, 3}})}); | ||
| std::vector<RowVectorPtr> expectedResult{expected}; | ||
| testAggregations( | ||
| {data}, | ||
| {"c0"}, | ||
| {"spark_collect_list(c1)"}, | ||
| {"spark_collect_list(c1, false)"}, | ||
| {"c0", "a0"}, | ||
| expectedResult, | ||
| makeConfig(false)); | ||
| {expected}); | ||
| } | ||
|
|
||
| TEST_F(CollectListAggregateTest, respectNullsAllNulls) { | ||
| // When all inputs are null and ignoreNulls is false, output should be an | ||
| // array of nulls (not an empty array). | ||
| auto input = makeRowVector({makeAllNullFlatVector<int32_t>(3)}); | ||
| auto expected = makeRowVector({makeNullableArrayVector<int32_t>( | ||
| std::vector<std::vector<std::optional<int32_t>>>{ | ||
| {std::nullopt, std::nullopt, std::nullopt}})}); | ||
| std::vector<RowVectorPtr> expectedResult{expected}; | ||
| testAggregations( | ||
| {input}, | ||
| {}, | ||
| {"spark_collect_list(c0)"}, | ||
| expectedResult, | ||
| makeConfig(false)); | ||
| testAggregations({input}, {}, {"spark_collect_list(c0, false)"}, {expected}); | ||
| } | ||
| } // namespace | ||
| } // namespace facebook::velox::functions::aggregate::sparksql::test | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,6 +87,8 @@ int main(int argc, char** argv) { | |
| // Velox registers a 2-arg collect_set(T, boolean) signature that Spark | ||
| // doesn't support. The fuzzer may pick this signature and fail. | ||
| "collect_set", | ||
| // Same as collect_set — 2-arg signature not supported by Spark. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: "2-arg signature not supported by Spark" is slightly misleading — Spark 4.0+ does support RESPECT NULLS for `collect_list`. Suggested wording:
// Fuzzer may pick the 2-arg (T, boolean) signature which requires
// a constant boolean that the fuzzer cannot generate.
"collect_list",
Same applies to the `collect_set` comment above. |
||
| "collect_list", | ||
| "first_ignore_null", | ||
| "last_ignore_null", | ||
| "regr_replacement", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The member order should be protected and then private