| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Facebook, Inc. and its affiliates.  | 
 | 3 | + *  | 
 | 4 | + * Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 5 | + * you may not use this file except in compliance with the License.  | 
 | 6 | + * You may obtain a copy of the License at  | 
 | 7 | + *  | 
 | 8 | + *     http://www.apache.org/licenses/LICENSE-2.0  | 
 | 9 | + *  | 
 | 10 | + * Unless required by applicable law or agreed to in writing, software  | 
 | 11 | + * distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 13 | + * See the License for the specific language governing permissions and  | 
 | 14 | + * limitations under the License.  | 
 | 15 | + */  | 
 | 16 | + | 
 | 17 | +#include "velox/functions/sparksql/specialforms/GetArrayStructFields.h"  | 
 | 18 | +#include "velox/expression/ConstantExpr.h"  | 
 | 19 | +#include "velox/vector/ComplexVector.h"  | 
 | 20 | + | 
 | 21 | +namespace facebook::velox::functions::sparksql {  | 
 | 22 | +namespace {  | 
 | 23 | + | 
 | 24 | +class GetArrayStructFieldsFunction : public exec::VectorFunction {  | 
 | 25 | + public:  | 
 | 26 | +  explicit GetArrayStructFieldsFunction(int32_t ordinal) : ordinal_(ordinal) {}  | 
 | 27 | + | 
 | 28 | +  void apply(  | 
 | 29 | +      const SelectivityVector& rows,  | 
 | 30 | +      std::vector<VectorPtr>& args,  | 
 | 31 | +      const TypePtr& /*resultType*/,  | 
 | 32 | +      exec::EvalCtx& context,  | 
 | 33 | +      VectorPtr& result) const override {  | 
 | 34 | +    // Decode input array vector.  | 
 | 35 | +    exec::LocalDecodedVector decoded(context, *args[0], rows);  | 
 | 36 | +    auto arrayVec = decoded->base()->as<ArrayVector>();  | 
 | 37 | + | 
 | 38 | +    DecodedVector decodedElement(*arrayVec->elements());  | 
 | 39 | +    auto elements = decodedElement.base()->as<RowVector>();  | 
 | 40 | +    auto fieldVector = elements->childAt(ordinal_);  | 
 | 41 | + | 
 | 42 | +    VectorPtr fieldResult = fieldVector;  | 
 | 43 | +    if (elements->mayHaveNulls()) {  | 
 | 44 | +      auto size = elements->size();  | 
 | 45 | +      auto indices =  | 
 | 46 | +          AlignedBuffer::allocate<vector_size_t>(size, context.pool());  | 
 | 47 | +      std::iota(  | 
 | 48 | +          indices->asMutable<vector_size_t>(),  | 
 | 49 | +          indices->asMutable<vector_size_t>() + size,  | 
 | 50 | +          0);  | 
 | 51 | + | 
 | 52 | +      // Wrap field with nulls from elements.  | 
 | 53 | +      fieldResult = BaseVector::wrapInDictionary(  | 
 | 54 | +          elements->nulls(), indices, size, fieldResult);  | 
 | 55 | +    }  | 
 | 56 | + | 
 | 57 | +    // Apply element decoding if needed.  | 
 | 58 | +    if (!decodedElement.isIdentityMapping()) {  | 
 | 59 | +      fieldResult = decodedElement.wrap(  | 
 | 60 | +          fieldResult, *arrayVec->elements(), decodedElement.size());  | 
 | 61 | +    }  | 
 | 62 | + | 
 | 63 | +    auto arrayResult = std::make_shared<ArrayVector>(  | 
 | 64 | +        context.pool(),  | 
 | 65 | +        ARRAY(fieldVector->type()),  | 
 | 66 | +        arrayVec->nulls(),  | 
 | 67 | +        arrayVec->size(),  | 
 | 68 | +        arrayVec->offsets(),  | 
 | 69 | +        arrayVec->sizes(),  | 
 | 70 | +        fieldResult);  | 
 | 71 | + | 
 | 72 | +    if (decoded->isIdentityMapping()) {  | 
 | 73 | +      result = arrayResult;  | 
 | 74 | +    } else {  | 
 | 75 | +      result = decoded->wrap(arrayResult, *args[0], decoded->size());  | 
 | 76 | +    }  | 
 | 77 | +  }  | 
 | 78 | + | 
 | 79 | + private:  | 
 | 80 | +  // The position to select subfield from the struct.  | 
 | 81 | +  const int32_t ordinal_;  | 
 | 82 | +};  | 
 | 83 | + | 
 | 84 | +} // namespace  | 
 | 85 | + | 
 | 86 | +TypePtr GetArrayStructFieldsCallToSpecialForm::resolveType(  | 
 | 87 | +    const std::vector<TypePtr>& /*argTypes*/) {  | 
 | 88 | +  VELOX_FAIL("GetArrayStructFields function does not support type resolution.");  | 
 | 89 | +}  | 
 | 90 | + | 
 | 91 | +exec::ExprPtr GetArrayStructFieldsCallToSpecialForm::constructSpecialForm(  | 
 | 92 | +    const TypePtr& type,  | 
 | 93 | +    std::vector<exec::ExprPtr>&& args,  | 
 | 94 | +    bool trackCpuUsage,  | 
 | 95 | +    const core::QueryConfig& /*config*/) {  | 
 | 96 | +  VELOX_USER_CHECK_EQ(  | 
 | 97 | +      args.size(), 2, "get_array_struct_fields expects two arguments.");  | 
 | 98 | + | 
 | 99 | +  VELOX_USER_CHECK(  | 
 | 100 | +      args[0]->type()->kind() == TypeKind::ARRAY &&  | 
 | 101 | +          args[0]->type()->asArray().elementType()->kind() == TypeKind::ROW,  | 
 | 102 | +      "The first argument of get_array_struct_fields should be of array(row) type.");  | 
 | 103 | + | 
 | 104 | +  VELOX_USER_CHECK_EQ(  | 
 | 105 | +      args[1]->type()->kind(),  | 
 | 106 | +      TypeKind::INTEGER,  | 
 | 107 | +      "The second argument of get_array_struct_fields should be of integer type.");  | 
 | 108 | + | 
 | 109 | +  auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[1]);  | 
 | 110 | +  VELOX_USER_CHECK_NOT_NULL(  | 
 | 111 | +      constantExpr,  | 
 | 112 | +      "The second argument of get_array_struct_fields should be constant expression.");  | 
 | 113 | +  VELOX_USER_CHECK(  | 
 | 114 | +      constantExpr->value()->isConstantEncoding(),  | 
 | 115 | +      "The second argument of get_array_struct_fields should be wrapped in constant vector.");  | 
 | 116 | +  auto constantVector =  | 
 | 117 | +      constantExpr->value()->asUnchecked<ConstantVector<int32_t>>();  | 
 | 118 | +  VELOX_USER_CHECK(  | 
 | 119 | +      !constantVector->isNullAt(0),  | 
 | 120 | +      "The second argument of get_array_struct_fields should be non-nullable.");  | 
 | 121 | + | 
 | 122 | +  auto ordinal = constantVector->valueAt(0);  | 
 | 123 | + | 
 | 124 | +  VELOX_USER_CHECK_GE(  | 
 | 125 | +      ordinal, 0, "Invalid ordinal. Should be greater than or equal to 0.");  | 
 | 126 | +  auto numFields = args[0]->type()->asArray().elementType()->asRow().size();  | 
 | 127 | +  VELOX_USER_CHECK_LT(  | 
 | 128 | +      ordinal,  | 
 | 129 | +      numFields,  | 
 | 130 | +      "Invalid ordinal {} for struct with {} fields.",  | 
 | 131 | +      ordinal,  | 
 | 132 | +      numFields);  | 
 | 133 | +  auto getArrayStructFieldsFunction =  | 
 | 134 | +      std::make_shared<GetArrayStructFieldsFunction>(ordinal);  | 
 | 135 | + | 
 | 136 | +  return std::make_shared<exec::Expr>(  | 
 | 137 | +      type,  | 
 | 138 | +      std::move(args),  | 
 | 139 | +      std::move(getArrayStructFieldsFunction),  | 
 | 140 | +      exec::VectorFunctionMetadata{},  | 
 | 141 | +      kGetArrayStructFields,  | 
 | 142 | +      trackCpuUsage);  | 
 | 143 | +}  | 
 | 144 | +} // namespace facebook::velox::functions::sparksql  | 
0 commit comments