Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion axiom/logical_plan/PlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "axiom/logical_plan/PlanBuilder.h"
#include <folly/container/Foreach.h>
#include <velox/common/base/Exceptions.h>
#include <vector>
#include "axiom/connectors/ConnectorMetadata.h"
Expand Down Expand Up @@ -813,13 +814,90 @@ PlanBuilder& PlanBuilder::setOperation(
SetOperation op,
const std::vector<PlanBuilder>& inputs) {
VELOX_USER_CHECK_NULL(node_, "setOperation must be a leaf");
VELOX_USER_CHECK_GE(
inputs.size(), 2, "Set operation requires at least 2 inputs");

outputMapping_ = inputs.front().outputMapping_;

std::vector<LogicalPlanNodePtr> nodes;
nodes.reserve(inputs.size());
for (auto& builder : inputs) {
for (const auto& builder : inputs) {
VELOX_CHECK_NOT_NULL(builder.node_);
nodes.push_back(builder.node_);
}

// Apply type coercion: find common supertype for each column.
const auto firstRowType = nodes[0]->outputType();
auto targetTypes = firstRowType->children();

for (size_t i = 1; i < nodes.size(); ++i) {
const auto& rowType = nodes[i]->outputType();

VELOX_USER_CHECK_EQ(
firstRowType->size(),
rowType->size(),
"Output schemas of all inputs to a Set operation must have same number of columns");

for (uint32_t j = 0; j < firstRowType->size(); ++j) {
const auto& currentType = targetTypes[j];
const auto& nextType = rowType->childAt(j);

if (currentType->equivalent(*nextType)) {
continue;
}

auto commonType =
velox::TypeCoercer::leastCommonSuperType(currentType, nextType);
VELOX_USER_CHECK_NOT_NULL(
commonType,
"Output schemas of all inputs to a Set operation must match: {} vs. {} at {}.{}",
currentType->toSummaryString(),
nextType->toSummaryString(),
j,
firstRowType->nameOf(j));

targetTypes[j] = commonType;
}
}

auto targetRowType =
velox::ROW(folly::copy(firstRowType->names()), std::move(targetTypes));

// Add cast projections where needed.
for (auto& node : nodes) {
const auto& inputType = node->outputType();
std::vector<uint32_t> indicesToCast;
for (uint32_t i = 0; i < inputType->size(); ++i) {
if (*inputType->childAt(i) != *targetRowType->childAt(i)) {
indicesToCast.push_back(i);
}
}

if (!indicesToCast.empty()) {
std::vector<ExprPtr> exprs;
exprs.reserve(inputType->size());

size_t castIdx = 0;
for (uint32_t i = 0; i < inputType->size(); ++i) {
const auto& inputColType = inputType->childAt(i);
const auto& name = inputType->nameOf(i);

auto inputRef =
std::make_shared<InputReferenceExpr>(inputColType, name);

if (castIdx < indicesToCast.size() && indicesToCast[castIdx] == i) {
exprs.push_back(applyCoercion(inputRef, targetRowType->childAt(i)));
++castIdx;
} else {
exprs.push_back(inputRef);
}
}

node = std::make_shared<ProjectNode>(
nextId(), std::move(node), inputType->names(), std::move(exprs));
}
}

node_ = std::make_shared<SetNode>(nextId(), std::move(nodes), op);
return *this;
}
Expand Down
63 changes: 63 additions & 0 deletions axiom/logical_plan/tests/PlanBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "axiom/logical_plan/PlanBuilder.h"
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

using namespace facebook::velox;
Expand All @@ -32,6 +33,19 @@ class PlanBuilderTest : public testing::Test {
// Test fixture set-up hook: calls
// functions::prestosql::registerAllScalarFunctions() so the tests in this
// fixture run with those scalar functions registered.
void SetUp() override {
functions::prestosql::registerAllScalarFunctions();
}

protected:
// Returns a PlanBuilder created from 'context' whose root is produced by
// values() with an empty ValuesNode::Variants and a row type made of 'types'.
// Columns are named by position: "c0", "c1", ...
PlanBuilder makeEmptyValues(
    PlanBuilder::Context& context,
    const std::vector<TypePtr>& types) {
  const size_t numColumns = types.size();

  std::vector<std::string> columnNames;
  columnNames.reserve(numColumns);
  for (size_t column = 0; column < numColumns; ++column) {
    columnNames.emplace_back(fmt::format("c{}", column));
  }

  return PlanBuilder(context).values(
      ROW(std::move(columnNames), types), ValuesNode::Variants{});
}
};

TEST_F(PlanBuilderTest, outputNames) {
Expand All @@ -56,5 +70,54 @@ TEST_F(PlanBuilderTest, outputNames) {
EXPECT_EQ("expr_0", outputNames[2]);
}

// Checks how a UNION ALL built with PlanBuilder::setOperation reconciles the
// column types of its two inputs. Only the resulting output row type (or the
// failure message) is asserted.
TEST_F(PlanBuilderTest, setOperationTypeCoercion) {
  // Builds UNION ALL over two empty Values inputs with the given column types.
  // The caller owns 'context' so it stays alive while the plan is inspected.
  const auto unionAll = [this](
                            PlanBuilder::Context& context,
                            const std::vector<TypePtr>& leftTypes,
                            const std::vector<TypePtr>& rightTypes) {
    return PlanBuilder(context)
        .setOperation(
            SetOperation::kUnionAll,
            {
                makeEmptyValues(context, leftTypes),
                makeEmptyValues(context, rightTypes),
            })
        .build();
  };

  // (INTEGER, REAL) + (BIGINT, DOUBLE) -> (BIGINT, DOUBLE): the narrower first
  // input is widened to the types of the second one.
  {
    PlanBuilder::Context context;
    auto plan =
        unionAll(context, {INTEGER(), REAL()}, {BIGINT(), DOUBLE()});

    EXPECT_EQ(*plan->outputType(), *ROW({"c0", "c1"}, {BIGINT(), DOUBLE()}));
  }

  // Identical input types are kept as they are.
  {
    PlanBuilder::Context context;
    auto plan = unionAll(context, {BIGINT()}, {BIGINT()});

    EXPECT_EQ(*plan->outputType(), *ROW({"c0"}, {BIGINT()}));
  }

  // Types without a common supertype are rejected.
  {
    PlanBuilder::Context context;
    VELOX_ASSERT_THROW(
        unionAll(context, {VARCHAR()}, {INTEGER()}),
        "Output schemas of all inputs to a Set operation must match");
  }
}

} // namespace
} // namespace facebook::axiom::logical_plan
Loading