Skip to content

Commit e971b98

Browse files
zhli1142015facebook-github-bot
authored andcommitted
feat: Add Spark get_array_struct_fields function (facebookincubator#14292)
Summary: Extracts the ``ordinal``-th fields of all array elements from array(struct), and returns them as a new array. Spark source code: https://github.com/apache/spark/blob/800d3729c9c2c1b1bf2d4c326d1ade610a7f2ada/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala#L203 Pull Request resolved: facebookincubator#14292 Reviewed By: kKPulla Differential Revision: D80121658 Pulled By: pedroerp fbshipit-source-id: fcb79cab1eea13bc0b4841c6f6837820321bb5d0
1 parent 654a936 commit e971b98

File tree

7 files changed

+363
-0
lines changed

7 files changed

+363
-0
lines changed

velox/docs/functions/spark/misc.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ Miscellaneous Functions
1515
SELECT at_least_n_non_nulls(2, 0, 1.0, NULL); -- true
1616
SELECT at_least_n_non_nulls(2, 0, array(NULL), NULL); -- true
1717

18+
.. spark:function:: get_array_struct_fields(array, ordinal) -> array(T)
19+
20+
Extracts the ``ordinal``-th fields of all array elements, and returns them as a new array.
21+
The first input must be of array(strcut) type and nested complex type is allowed.
22+
The ``ordinal`` is 0-based, and if ``ordinal`` is negative or no less than
23+
the children size of strcut, exception is thrown. ::
24+
25+
SELECT items.col1 FROM VALUES (array(struct(100,'foo'), struct(200,'bar'))) AS t(items); -- array(100, 200)
26+
SELECT items.col2 FROM VALUES (array(struct(100,'foo'), struct(200,'bar'))) AS t(items); -- array('foo', 'bar')
27+
1828
.. spark:function:: get_struct_field(struct, ordinal) -> T
1929
2030
Returns the value of nested subfield at position ``ordinal`` in the input ``struct``.

velox/functions/sparksql/registration/RegisterSpecialForm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
2020
#include "velox/functions/sparksql/specialforms/DecimalRound.h"
2121
#include "velox/functions/sparksql/specialforms/FromJson.h"
22+
#include "velox/functions/sparksql/specialforms/GetArrayStructFields.h"
2223
#include "velox/functions/sparksql/specialforms/GetStructField.h"
2324
#include "velox/functions/sparksql/specialforms/MakeDecimal.h"
2425
#include "velox/functions/sparksql/specialforms/SparkCastExpr.h"
@@ -51,6 +52,9 @@ void registerSpecialFormGeneralFunctions(const std::string& prefix) {
5152
std::make_unique<FromJsonCallToSpecialForm>());
5253
registerFunctionCallToSpecialForm(
5354
"get_struct_field", std::make_unique<GetStructFieldCallToSpecialForm>());
55+
registerFunctionCallToSpecialForm(
56+
GetArrayStructFieldsCallToSpecialForm::kGetArrayStructFields,
57+
std::make_unique<GetArrayStructFieldsCallToSpecialForm>());
5458
}
5559
} // namespace sparksql
5660
} // namespace facebook::velox::functions

velox/functions/sparksql/specialforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ velox_add_library(
1717
AtLeastNNonNulls.cpp
1818
DecimalRound.cpp
1919
FromJson.cpp
20+
GetArrayStructFields.cpp
2021
GetStructField.cpp
2122
MakeDecimal.cpp
2223
SparkCastExpr.cpp
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "velox/functions/sparksql/specialforms/GetArrayStructFields.h"
18+
#include "velox/expression/ConstantExpr.h"
19+
#include "velox/vector/ComplexVector.h"
20+
21+
namespace facebook::velox::functions::sparksql {
22+
namespace {
23+
24+
class GetArrayStructFieldsFunction : public exec::VectorFunction {
25+
public:
26+
explicit GetArrayStructFieldsFunction(int32_t ordinal) : ordinal_(ordinal) {}
27+
28+
void apply(
29+
const SelectivityVector& rows,
30+
std::vector<VectorPtr>& args,
31+
const TypePtr& /*resultType*/,
32+
exec::EvalCtx& context,
33+
VectorPtr& result) const override {
34+
// Decode input array vector.
35+
exec::LocalDecodedVector decoded(context, *args[0], rows);
36+
auto arrayVec = decoded->base()->as<ArrayVector>();
37+
38+
DecodedVector decodedElement(*arrayVec->elements());
39+
auto elements = decodedElement.base()->as<RowVector>();
40+
auto fieldVector = elements->childAt(ordinal_);
41+
42+
VectorPtr fieldResult = fieldVector;
43+
if (elements->mayHaveNulls()) {
44+
auto size = elements->size();
45+
auto indices =
46+
AlignedBuffer::allocate<vector_size_t>(size, context.pool());
47+
std::iota(
48+
indices->asMutable<vector_size_t>(),
49+
indices->asMutable<vector_size_t>() + size,
50+
0);
51+
52+
// Wrap field with nulls from elements.
53+
fieldResult = BaseVector::wrapInDictionary(
54+
elements->nulls(), indices, size, fieldResult);
55+
}
56+
57+
// Apply element decoding if needed.
58+
if (!decodedElement.isIdentityMapping()) {
59+
fieldResult = decodedElement.wrap(
60+
fieldResult, *arrayVec->elements(), decodedElement.size());
61+
}
62+
63+
auto arrayResult = std::make_shared<ArrayVector>(
64+
context.pool(),
65+
ARRAY(fieldVector->type()),
66+
arrayVec->nulls(),
67+
arrayVec->size(),
68+
arrayVec->offsets(),
69+
arrayVec->sizes(),
70+
fieldResult);
71+
72+
if (decoded->isIdentityMapping()) {
73+
result = arrayResult;
74+
} else {
75+
result = decoded->wrap(arrayResult, *args[0], decoded->size());
76+
}
77+
}
78+
79+
private:
80+
// The position to select subfield from the struct.
81+
const int32_t ordinal_;
82+
};
83+
84+
} // namespace
85+
86+
TypePtr GetArrayStructFieldsCallToSpecialForm::resolveType(
87+
const std::vector<TypePtr>& /*argTypes*/) {
88+
VELOX_FAIL("GetArrayStructFields function does not support type resolution.");
89+
}
90+
91+
exec::ExprPtr GetArrayStructFieldsCallToSpecialForm::constructSpecialForm(
92+
const TypePtr& type,
93+
std::vector<exec::ExprPtr>&& args,
94+
bool trackCpuUsage,
95+
const core::QueryConfig& /*config*/) {
96+
VELOX_USER_CHECK_EQ(
97+
args.size(), 2, "get_array_struct_fields expects two arguments.");
98+
99+
VELOX_USER_CHECK(
100+
args[0]->type()->kind() == TypeKind::ARRAY &&
101+
args[0]->type()->asArray().elementType()->kind() == TypeKind::ROW,
102+
"The first argument of get_array_struct_fields should be of array(row) type.");
103+
104+
VELOX_USER_CHECK_EQ(
105+
args[1]->type()->kind(),
106+
TypeKind::INTEGER,
107+
"The second argument of get_array_struct_fields should be of integer type.");
108+
109+
auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[1]);
110+
VELOX_USER_CHECK_NOT_NULL(
111+
constantExpr,
112+
"The second argument of get_array_struct_fields should be constant expression.");
113+
VELOX_USER_CHECK(
114+
constantExpr->value()->isConstantEncoding(),
115+
"The second argument of get_array_struct_fields should be wrapped in constant vector.");
116+
auto constantVector =
117+
constantExpr->value()->asUnchecked<ConstantVector<int32_t>>();
118+
VELOX_USER_CHECK(
119+
!constantVector->isNullAt(0),
120+
"The second argument of get_array_struct_fields should be non-nullable.");
121+
122+
auto ordinal = constantVector->valueAt(0);
123+
124+
VELOX_USER_CHECK_GE(
125+
ordinal, 0, "Invalid ordinal. Should be greater than or equal to 0.");
126+
auto numFields = args[0]->type()->asArray().elementType()->asRow().size();
127+
VELOX_USER_CHECK_LT(
128+
ordinal,
129+
numFields,
130+
"Invalid ordinal {} for struct with {} fields.",
131+
ordinal,
132+
numFields);
133+
auto getArrayStructFieldsFunction =
134+
std::make_shared<GetArrayStructFieldsFunction>(ordinal);
135+
136+
return std::make_shared<exec::Expr>(
137+
type,
138+
std::move(args),
139+
std::move(getArrayStructFieldsFunction),
140+
exec::VectorFunctionMetadata{},
141+
kGetArrayStructFields,
142+
trackCpuUsage);
143+
}
144+
} // namespace facebook::velox::functions::sparksql
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/expression/FunctionCallToSpecialForm.h"
19+
20+
namespace facebook::velox::functions::sparksql {
21+
22+
class GetArrayStructFieldsCallToSpecialForm
23+
: public exec::FunctionCallToSpecialForm {
24+
public:
25+
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;
26+
27+
/// Returns an expression for get_array_struct_fields special form. The
28+
/// expression is a regular expression based on a custom VectorFunction
29+
/// implementation.
30+
exec::ExprPtr constructSpecialForm(
31+
const TypePtr& type,
32+
std::vector<exec::ExprPtr>&& args,
33+
bool trackCpuUsage,
34+
const core::QueryConfig& config) override;
35+
36+
static constexpr const char* kGetArrayStructFields =
37+
"get_array_struct_fields";
38+
};
39+
40+
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ add_executable(
4242
FactorialTest.cpp
4343
FromJsonTest.cpp
4444
FromToJsonRoundTripTest.cpp
45+
GetArrayStructFieldsTest.cpp
4546
GetJsonObjectTest.cpp
4647
GetStructFieldTest.cpp
4748
HashTest.cpp

0 commit comments

Comments
 (0)