Skip to content

Commit d03c49c

Browse files
oerlingmeta-codesync[bot]
authored andcommitted
Track Value constraints through plan construction (facebookincubator#708)
Summary: Pull Request resolved: facebookincubator#708 Differential Revision: D89130357
1 parent 7f2a96a commit d03c49c

File tree

11 files changed

+272
-50
lines changed

11 files changed

+272
-50
lines changed

axiom/optimizer/Filters.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ float combineSelectivities(
127127
}
128128

129129
const Value& value(const PlanState& state, ExprCP expr) {
130-
return expr->value();
130+
auto it = state.constraints.find(expr->id());
131+
if (it != state.constraints.end()) {
132+
return it->second;
133+
}
134+
return expr->value();
131135
}
132136

133137
Selectivity comparisonSelectivity(

axiom/optimizer/FunctionRegistry.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
namespace facebook::axiom::optimizer {
2222

23+
struct PlanState;
24+
struct Value;
25+
2326
/// A bit set that qualifies an Expr. Represents which functions/kinds
2427
/// of functions are found inside the children of an Expr.
2528
class FunctionSet {
@@ -182,6 +185,10 @@ struct FunctionMetadata {
182185
const logical_plan::CallExpr* call,
183186
std::vector<PathCP>& paths)>
184187
explode;
188+
189+
/// Function to compute derived constraints for function calls.
190+
std::function<std::optional<Value>(ExprCP, PlanState& state)>
191+
functionConstraint;
185192
};
186193

187194
using FunctionMetadataCP = const FunctionMetadata*;

axiom/optimizer/Optimization.cpp

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <iostream>
2020
#include <utility>
2121
#include "axiom/optimizer/DerivedTablePrinter.h"
22+
#include "axiom/optimizer/Filters.h"
2223
#include "axiom/optimizer/Plan.h"
2324
#include "axiom/optimizer/PrecomputeProjection.h"
2425
#include "axiom/optimizer/VeloxHistory.h"
@@ -820,7 +821,7 @@ void Optimization::addPostprocess(
820821
VELOX_DCHECK(!dt->hasLimit());
821822
PrecomputeProjection precompute{plan, dt, /*projectAllInputs=*/false};
822823
auto writeColumns = precompute.toColumns(dt->write->columnExprs());
823-
plan = std::move(precompute).maybeProject();
824+
plan = std::move(precompute).maybeProject(state);
824825
state.addCost(*plan);
825826

826827
plan = repartitionForWrite(plan, state);
@@ -868,7 +869,8 @@ void Optimization::addPostprocess(
868869
maybeDropProject(plan),
869870
usedExprs,
870871
usedColumns,
871-
Project::isRedundant(plan, usedExprs, usedColumns));
872+
Project::isRedundant(plan, usedExprs, usedColumns),
873+
state);
872874
}
873875

874876
if (!dt->hasOrderBy() && dt->limit > kMaxLimitBeforeProject) {
@@ -915,7 +917,8 @@ AggregateVector flattenAggregates(
915917
// static
916918
RelationOpPtr Optimization::planSingleAggregation(
917919
DerivedTableCP dt,
918-
RelationOpPtr& input) {
920+
RelationOpPtr& input,
921+
PlanState& state) {
919922
const auto* aggPlan = dt->aggregation;
920923

921924
PrecomputeProjection precompute(input, dt, /*projectAllInputs=*/false);
@@ -924,11 +927,12 @@ RelationOpPtr Optimization::planSingleAggregation(
924927
auto aggregates = flattenAggregates(aggPlan->aggregates(), precompute);
925928

926929
return make<Aggregation>(
927-
std::move(precompute).maybeProject(),
930+
std::move(precompute).maybeProject(state),
928931
std::move(groupingKeys),
929932
std::move(aggregates),
930933
velox::core::AggregationNode::Step::kSingle,
931-
aggPlan->columns());
934+
aggPlan->columns(),
935+
state);
932936
}
933937

934938
void Optimization::addAggregation(
@@ -942,7 +946,7 @@ void Optimization::addAggregation(
942946
precompute.toColumns(aggPlan->groupingKeys(), &aggPlan->columns());
943947
auto aggregates = flattenAggregates(aggPlan->aggregates(), precompute);
944948

945-
plan = std::move(precompute).maybeProject();
949+
plan = std::move(precompute).maybeProject(state);
946950
state.placed.add(aggPlan);
947951

948952
if (isSingleWorker_ && isSingleDriver_) {
@@ -951,7 +955,8 @@ void Optimization::addAggregation(
951955
std::move(groupingKeys),
952956
std::move(aggregates),
953957
velox::core::AggregationNode::Step::kSingle,
954-
aggPlan->columns());
958+
aggPlan->columns(),
959+
state);
955960

956961
state.addCost(*singleAgg);
957962
plan = singleAgg;
@@ -972,7 +977,8 @@ void Optimization::addAggregation(
972977
groupingKeys,
973978
aggregates,
974979
velox::core::AggregationNode::Step::kPartial,
975-
aggPlan->intermediateColumns());
980+
aggPlan->intermediateColumns(),
981+
state);
976982

977983
PlanCost splitAggCost;
978984
splitAggCost.add(*partialAgg);
@@ -991,7 +997,8 @@ void Optimization::addAggregation(
991997
std::move(finalGroupingKeys),
992998
aggregates,
993999
velox::core::AggregationNode::Step::kFinal,
994-
aggPlan->columns());
1000+
aggPlan->columns(),
1001+
state);
9951002
splitAggCost.add(*splitAggPlan);
9961003

9971004
if (numKeys == 0 || options_.alwaysPlanPartialAggregation) {
@@ -1009,7 +1016,8 @@ void Optimization::addAggregation(
10091016
groupingKeys,
10101017
aggregates,
10111018
velox::core::AggregationNode::Step::kSingle,
1012-
aggPlan->columns());
1019+
aggPlan->columns(),
1020+
state);
10131021
singleAggCost.add(*singleAgg);
10141022

10151023
if (singleAggCost.cost < splitAggCost.cost) {
@@ -1043,7 +1051,7 @@ void Optimization::addOrderBy(
10431051
}
10441052

10451053
auto* orderBy = make<OrderBy>(
1046-
std::move(precompute).maybeProject(),
1054+
std::move(precompute).maybeProject(state),
10471055
std::move(orderKeys),
10481056
dt->orderTypes,
10491057
dt->limit,
@@ -1136,9 +1144,9 @@ struct ProjectionBuilder {
11361144
exprs.emplace_back(expr);
11371145
}
11381146

1139-
RelationOp* build(RelationOp* input) {
1147+
RelationOp* build(RelationOp* input, PlanState& state) {
11401148
return make<Project>(
1141-
input, exprs, columns, Project::isRedundant(input, exprs, columns));
1149+
input, exprs, columns, Project::isRedundant(input, exprs, columns), state);
11421150
}
11431151

11441152
ColumnVector inputColumns() const {
@@ -1229,6 +1237,42 @@ void tryOptimizeSemiProject(
12291237
}
12301238
} // namespace
12311239

1240+
void Optimization::addJoinConstraint(
1241+
ExprCP left,
1242+
ExprCP right,
1243+
bool leftOptional,
1244+
bool rightOptional,
1245+
PlanState& state) {
1246+
// Get the values for left and right expressions
1247+
Value leftValue = value(state, left);
1248+
Value rightValue = value(state, right);
1249+
1250+
// Set nullability based on optionality flags
1251+
if (leftOptional) {
1252+
leftValue.nullFraction = std::max(leftValue.nullFraction, 0.01f);
1253+
}
1254+
if (rightOptional) {
1255+
rightValue.nullFraction = std::max(rightValue.nullFraction, 0.01f);
1256+
}
1257+
1258+
// Call columnComparisonSelectivity with updateConstraints=true
1259+
columnComparisonSelectivity(
1260+
left, right, leftValue, rightValue, toName("eq"), true, state.constraints);
1261+
}
1262+
1263+
void Optimization::addJoinConstraints(
1264+
const ExprVector& left,
1265+
const ExprVector& right,
1266+
bool leftOptional,
1267+
bool rightOptional,
1268+
PlanState& state) {
1269+
VELOX_CHECK_EQ(left.size(), right.size(), "Join key vectors must have same size");
1270+
1271+
for (size_t i = 0; i < left.size(); ++i) {
1272+
addJoinConstraint(left[i], right[i], leftOptional, rightOptional, state);
1273+
}
1274+
}
1275+
12321276
void Optimization::joinByHash(
12331277
const RelationOpPtr& plan,
12341278
const JoinCandidate& candidate,
@@ -1330,7 +1374,7 @@ void Optimization::joinByHash(
13301374

13311375
PrecomputeProjection precomputeBuild(buildInput, state.dt);
13321376
auto buildKeys = precomputeBuild.toColumns(build.keys);
1333-
buildInput = std::move(precomputeBuild).maybeProject();
1377+
buildInput = std::move(precomputeBuild).maybeProject(buildState);
13341378

13351379
auto* buildOp = make<HashBuild>(buildInput, build.keys, buildPlan);
13361380
buildState.addCost(*buildOp);
@@ -1396,7 +1440,13 @@ void Optimization::joinByHash(
13961440

13971441
PrecomputeProjection precomputeProbe(probeInput, state.dt);
13981442
auto probeKeys = precomputeProbe.toColumns(probe.keys);
1399-
probeInput = std::move(precomputeProbe).maybeProject();
1443+
probeInput = std::move(precomputeProbe).maybeProject(state);
1444+
1445+
// Add join constraints for equi-join keys (except for anti joins)
1446+
if (joinType != velox::core::JoinType::kAnti) {
1447+
bool rightOptional = (joinType == velox::core::JoinType::kLeft);
1448+
addJoinConstraints(probe.keys, build.keys, false, rightOptional, state);
1449+
}
14001450

14011451
RelationOp* join = make<Join>(
14021452
JoinMethod::kHash,
@@ -1413,7 +1463,7 @@ void Optimization::joinByHash(
14131463
state.cost.cost += buildState.cost.cost;
14141464

14151465
if (needsProjection) {
1416-
join = projectionBuilder.build(join);
1466+
join = projectionBuilder.build(join, state);
14171467
}
14181468

14191469
state.addNextJoin(&candidate, join, toTry);
@@ -1475,7 +1525,7 @@ void Optimization::joinByHashRight(
14751525

14761526
PrecomputeProjection precomputeBuild(buildInput, state.dt);
14771527
auto buildKeys = precomputeBuild.toColumns(build.keys);
1478-
buildInput = std::move(precomputeBuild).maybeProject();
1528+
buildInput = std::move(precomputeBuild).maybeProject(state);
14791529

14801530
auto* buildOp = make<HashBuild>(buildInput, build.keys, nullptr);
14811531
state.addCost(*buildOp);
@@ -1554,7 +1604,12 @@ void Optimization::joinByHashRight(
15541604

15551605
PrecomputeProjection precomputeProbe(probeInput, state.dt);
15561606
auto probeKeys = precomputeProbe.toColumns(probe.keys);
1557-
probeInput = std::move(precomputeProbe).maybeProject();
1607+
probeInput = std::move(precomputeProbe).maybeProject(state);
1608+
1609+
// Add join constraints for equi-join keys
1610+
// In joinByHashRight, if join type is right outer, leftOptional is true
1611+
bool leftOptional = (rightJoinType == velox::core::JoinType::kRight);
1612+
addJoinConstraints(probe.keys, build.keys, leftOptional, false, state);
15581613

15591614
RelationOp* join = make<Join>(
15601615
JoinMethod::kHash,
@@ -1569,7 +1624,7 @@ void Optimization::joinByHashRight(
15691624
state.addCost(*join);
15701625

15711626
if (needsProjection) {
1572-
join = projectionBuilder.build(join);
1627+
join = projectionBuilder.build(join, state);
15731628
}
15741629

15751630
state.addNextJoin(&candidate, join, toTry);
@@ -1613,7 +1668,7 @@ void Optimization::crossJoinUnnest(
16131668
// because we can have multiple unnest joins in single JoinCandidate.
16141669

16151670
auto unnestColumns = precompute.toColumns(unnestExprs);
1616-
plan = std::move(precompute).maybeProject();
1671+
plan = std::move(precompute).maybeProject(state);
16171672

16181673
plan = make<Unnest>(
16191674
std::move(plan),
@@ -1853,7 +1908,9 @@ ColumnVector indexColumns(
18531908
return result;
18541909
}
18551910

1856-
RelationOpPtr makeDistinct(const RelationOpPtr& input) {
1911+
RelationOpPtr makeDistinct(
1912+
const RelationOpPtr& input,
1913+
const PlanState& state) {
18571914
ExprVector groupingKeys;
18581915
for (const auto& column : input->columns()) {
18591916
groupingKeys.push_back(column);
@@ -1864,7 +1921,8 @@ RelationOpPtr makeDistinct(const RelationOpPtr& input) {
18641921
groupingKeys,
18651922
AggregateVector{},
18661923
velox::core::AggregationNode::Step::kSingle,
1867-
input->columns());
1924+
input->columns(),
1925+
state);
18681926
}
18691927

18701928
Distribution somePartition(const RelationOpPtrVector& inputs) {
@@ -2104,7 +2162,7 @@ PlanP Optimization::makeUnionPlan(
21042162
RelationOpPtr result = make<UnionAll>(inputs);
21052163
Aggregation* distinct = nullptr;
21062164
if (isDistinct) {
2107-
result = makeDistinct(result);
2165+
result = makeDistinct(result, inputStates[0]);
21082166
distinct = result->as<Aggregation>();
21092167
}
21102168
return unionPlan(inputStates, result, distinct);
@@ -2135,7 +2193,7 @@ PlanP Optimization::makeUnionPlan(
21352193
RelationOpPtr result = make<UnionAll>(inputs);
21362194
Aggregation* distinct = nullptr;
21372195
if (isDistinct) {
2138-
result = makeDistinct(result);
2196+
result = makeDistinct(result, inputStates[0]);
21392197
distinct = result->as<Aggregation>();
21402198
}
21412199
return unionPlan(inputStates, result, distinct);

axiom/optimizer/Optimization.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ class Optimization {
121121
/// @param dt Derived table with an aggregation.
122122
static RelationOpPtr planSingleAggregation(
123123
DerivedTableCP dt,
124-
RelationOpPtr& input);
124+
RelationOpPtr& input,
125+
PlanState& state);
125126

126127
const std::shared_ptr<velox::core::QueryCtx>& veloxQueryCtx() const {
127128
return veloxQueryCtx_;
@@ -302,6 +303,22 @@ class Optimization {
302303
PlanState& state,
303304
std::vector<NextJoin>& toTry);
304305

306+
// Adds join constraints for a single pair of join keys.
307+
void addJoinConstraint(
308+
ExprCP left,
309+
ExprCP right,
310+
bool leftOptional,
311+
bool rightOptional,
312+
PlanState& state);
313+
314+
// Adds join constraints for vectors of join keys.
315+
void addJoinConstraints(
316+
const ExprVector& left,
317+
const ExprVector& right,
318+
bool leftOptional,
319+
bool rightOptional,
320+
PlanState& state);
321+
305322
void crossJoin(
306323
const RelationOpPtr& plan,
307324
const JoinCandidate& candidate,

0 commit comments

Comments
 (0)