1919#include < iostream>
2020#include < utility>
2121#include " axiom/optimizer/DerivedTablePrinter.h"
22+ #include " axiom/optimizer/Filters.h"
2223#include " axiom/optimizer/Plan.h"
2324#include " axiom/optimizer/PrecomputeProjection.h"
2425#include " axiom/optimizer/VeloxHistory.h"
@@ -820,7 +821,7 @@ void Optimization::addPostprocess(
820821 VELOX_DCHECK (!dt->hasLimit ());
821822 PrecomputeProjection precompute{plan, dt, /* projectAllInputs=*/ false };
822823 auto writeColumns = precompute.toColumns (dt->write ->columnExprs ());
823- plan = std::move (precompute).maybeProject ();
824+ plan = std::move (precompute).maybeProject (state );
824825 state.addCost (*plan);
825826
826827 plan = repartitionForWrite (plan, state);
@@ -868,7 +869,8 @@ void Optimization::addPostprocess(
868869 maybeDropProject (plan),
869870 usedExprs,
870871 usedColumns,
871- Project::isRedundant (plan, usedExprs, usedColumns));
872+ Project::isRedundant (plan, usedExprs, usedColumns),
873+ state);
872874 }
873875
874876 if (!dt->hasOrderBy () && dt->limit > kMaxLimitBeforeProject ) {
@@ -915,7 +917,8 @@ AggregateVector flattenAggregates(
915917// static
916918RelationOpPtr Optimization::planSingleAggregation (
917919 DerivedTableCP dt,
918- RelationOpPtr& input) {
920+ RelationOpPtr& input,
921+ PlanState& state) {
919922 const auto * aggPlan = dt->aggregation ;
920923
921924 PrecomputeProjection precompute (input, dt, /* projectAllInputs=*/ false );
@@ -924,11 +927,12 @@ RelationOpPtr Optimization::planSingleAggregation(
924927 auto aggregates = flattenAggregates (aggPlan->aggregates (), precompute);
925928
926929 return make<Aggregation>(
927- std::move (precompute).maybeProject (),
930+ std::move (precompute).maybeProject (state ),
928931 std::move (groupingKeys),
929932 std::move (aggregates),
930933 velox::core::AggregationNode::Step::kSingle ,
931- aggPlan->columns ());
934+ aggPlan->columns (),
935+ state);
932936}
933937
934938void Optimization::addAggregation (
@@ -942,7 +946,7 @@ void Optimization::addAggregation(
942946 precompute.toColumns (aggPlan->groupingKeys (), &aggPlan->columns ());
943947 auto aggregates = flattenAggregates (aggPlan->aggregates (), precompute);
944948
945- plan = std::move (precompute).maybeProject ();
949+ plan = std::move (precompute).maybeProject (state );
946950 state.placed .add (aggPlan);
947951
948952 if (isSingleWorker_ && isSingleDriver_) {
@@ -951,7 +955,8 @@ void Optimization::addAggregation(
951955 std::move (groupingKeys),
952956 std::move (aggregates),
953957 velox::core::AggregationNode::Step::kSingle ,
954- aggPlan->columns ());
958+ aggPlan->columns (),
959+ state);
955960
956961 state.addCost (*singleAgg);
957962 plan = singleAgg;
@@ -972,7 +977,8 @@ void Optimization::addAggregation(
972977 groupingKeys,
973978 aggregates,
974979 velox::core::AggregationNode::Step::kPartial ,
975- aggPlan->intermediateColumns ());
980+ aggPlan->intermediateColumns (),
981+ state);
976982
977983 PlanCost splitAggCost;
978984 splitAggCost.add (*partialAgg);
@@ -991,7 +997,8 @@ void Optimization::addAggregation(
991997 std::move (finalGroupingKeys),
992998 aggregates,
993999 velox::core::AggregationNode::Step::kFinal ,
994- aggPlan->columns ());
1000+ aggPlan->columns (),
1001+ state);
9951002 splitAggCost.add (*splitAggPlan);
9961003
9971004 if (numKeys == 0 || options_.alwaysPlanPartialAggregation ) {
@@ -1009,7 +1016,8 @@ void Optimization::addAggregation(
10091016 groupingKeys,
10101017 aggregates,
10111018 velox::core::AggregationNode::Step::kSingle ,
1012- aggPlan->columns ());
1019+ aggPlan->columns (),
1020+ state);
10131021 singleAggCost.add (*singleAgg);
10141022
10151023 if (singleAggCost.cost < splitAggCost.cost ) {
@@ -1043,7 +1051,7 @@ void Optimization::addOrderBy(
10431051 }
10441052
10451053 auto * orderBy = make<OrderBy>(
1046- std::move (precompute).maybeProject (),
1054+ std::move (precompute).maybeProject (state ),
10471055 std::move (orderKeys),
10481056 dt->orderTypes ,
10491057 dt->limit ,
@@ -1136,9 +1144,9 @@ struct ProjectionBuilder {
11361144 exprs.emplace_back (expr);
11371145 }
11381146
1139- RelationOp* build (RelationOp* input) {
1147+ RelationOp* build (RelationOp* input, PlanState& state ) {
11401148 return make<Project>(
1141- input, exprs, columns, Project::isRedundant (input, exprs, columns));
1149+ input, exprs, columns, Project::isRedundant (input, exprs, columns), state );
11421150 }
11431151
11441152 ColumnVector inputColumns () const {
@@ -1229,6 +1237,42 @@ void tryOptimizeSemiProject(
12291237}
12301238} // namespace
12311239
1240+ void Optimization::addJoinConstraint (
1241+ ExprCP left,
1242+ ExprCP right,
1243+ bool leftOptional,
1244+ bool rightOptional,
1245+ PlanState& state) {
1246+ // Get the values for left and right expressions
1247+ Value leftValue = value (state, left);
1248+ Value rightValue = value (state, right);
1249+
1250+ // Set nullability based on optionality flags
1251+ if (leftOptional) {
1252+ leftValue.nullFraction = std::max (leftValue.nullFraction , 0 .01f );
1253+ }
1254+ if (rightOptional) {
1255+ rightValue.nullFraction = std::max (rightValue.nullFraction , 0 .01f );
1256+ }
1257+
1258+ // Call columnComparisonSelectivity with updateConstraints=true
1259+ columnComparisonSelectivity (
1260+ left, right, leftValue, rightValue, toName (" eq" ), true , state.constraints );
1261+ }
1262+
1263+ void Optimization::addJoinConstraints (
1264+ const ExprVector& left,
1265+ const ExprVector& right,
1266+ bool leftOptional,
1267+ bool rightOptional,
1268+ PlanState& state) {
1269+ VELOX_CHECK_EQ (left.size (), right.size (), " Join key vectors must have same size" );
1270+
1271+ for (size_t i = 0 ; i < left.size (); ++i) {
1272+ addJoinConstraint (left[i], right[i], leftOptional, rightOptional, state);
1273+ }
1274+ }
1275+
12321276void Optimization::joinByHash (
12331277 const RelationOpPtr& plan,
12341278 const JoinCandidate& candidate,
@@ -1330,7 +1374,7 @@ void Optimization::joinByHash(
13301374
13311375 PrecomputeProjection precomputeBuild (buildInput, state.dt );
13321376 auto buildKeys = precomputeBuild.toColumns (build.keys );
1333- buildInput = std::move (precomputeBuild).maybeProject ();
1377+ buildInput = std::move (precomputeBuild).maybeProject (buildState );
13341378
13351379 auto * buildOp = make<HashBuild>(buildInput, build.keys , buildPlan);
13361380 buildState.addCost (*buildOp);
@@ -1396,7 +1440,13 @@ void Optimization::joinByHash(
13961440
13971441 PrecomputeProjection precomputeProbe (probeInput, state.dt );
13981442 auto probeKeys = precomputeProbe.toColumns (probe.keys );
1399- probeInput = std::move (precomputeProbe).maybeProject ();
1443+ probeInput = std::move (precomputeProbe).maybeProject (state);
1444+
1445+ // Add join constraints for equi-join keys (except for anti joins)
1446+ if (joinType != velox::core::JoinType::kAnti ) {
1447+ bool rightOptional = (joinType == velox::core::JoinType::kLeft );
1448+ addJoinConstraints (probe.keys , build.keys , false , rightOptional, state);
1449+ }
14001450
14011451 RelationOp* join = make<Join>(
14021452 JoinMethod::kHash ,
@@ -1413,7 +1463,7 @@ void Optimization::joinByHash(
14131463 state.cost .cost += buildState.cost .cost ;
14141464
14151465 if (needsProjection) {
1416- join = projectionBuilder.build (join);
1466+ join = projectionBuilder.build (join, state );
14171467 }
14181468
14191469 state.addNextJoin (&candidate, join, toTry);
@@ -1475,7 +1525,7 @@ void Optimization::joinByHashRight(
14751525
14761526 PrecomputeProjection precomputeBuild (buildInput, state.dt );
14771527 auto buildKeys = precomputeBuild.toColumns (build.keys );
1478- buildInput = std::move (precomputeBuild).maybeProject ();
1528+ buildInput = std::move (precomputeBuild).maybeProject (state );
14791529
14801530 auto * buildOp = make<HashBuild>(buildInput, build.keys , nullptr );
14811531 state.addCost (*buildOp);
@@ -1554,7 +1604,12 @@ void Optimization::joinByHashRight(
15541604
15551605 PrecomputeProjection precomputeProbe (probeInput, state.dt );
15561606 auto probeKeys = precomputeProbe.toColumns (probe.keys );
1557- probeInput = std::move (precomputeProbe).maybeProject ();
1607+ probeInput = std::move (precomputeProbe).maybeProject (state);
1608+
1609+ // Add join constraints for equi-join keys
1610+ // In joinByHashRight, if join type is right outer, leftOptional is true
1611+ bool leftOptional = (rightJoinType == velox::core::JoinType::kRight );
1612+ addJoinConstraints (probe.keys , build.keys , leftOptional, false , state);
15581613
15591614 RelationOp* join = make<Join>(
15601615 JoinMethod::kHash ,
@@ -1569,7 +1624,7 @@ void Optimization::joinByHashRight(
15691624 state.addCost (*join);
15701625
15711626 if (needsProjection) {
1572- join = projectionBuilder.build (join);
1627+ join = projectionBuilder.build (join, state );
15731628 }
15741629
15751630 state.addNextJoin (&candidate, join, toTry);
@@ -1613,7 +1668,7 @@ void Optimization::crossJoinUnnest(
16131668 // because we can have multiple unnest joins in single JoinCandidate.
16141669
16151670 auto unnestColumns = precompute.toColumns (unnestExprs);
1616- plan = std::move (precompute).maybeProject ();
1671+ plan = std::move (precompute).maybeProject (state );
16171672
16181673 plan = make<Unnest>(
16191674 std::move (plan),
@@ -1853,7 +1908,9 @@ ColumnVector indexColumns(
18531908 return result;
18541909}
18551910
1856- RelationOpPtr makeDistinct (const RelationOpPtr& input) {
1911+ RelationOpPtr makeDistinct (
1912+ const RelationOpPtr& input,
1913+ const PlanState& state) {
18571914 ExprVector groupingKeys;
18581915 for (const auto & column : input->columns ()) {
18591916 groupingKeys.push_back (column);
@@ -1864,7 +1921,8 @@ RelationOpPtr makeDistinct(const RelationOpPtr& input) {
18641921 groupingKeys,
18651922 AggregateVector{},
18661923 velox::core::AggregationNode::Step::kSingle ,
1867- input->columns ());
1924+ input->columns (),
1925+ state);
18681926}
18691927
18701928Distribution somePartition (const RelationOpPtrVector& inputs) {
@@ -2104,7 +2162,7 @@ PlanP Optimization::makeUnionPlan(
21042162 RelationOpPtr result = make<UnionAll>(inputs);
21052163 Aggregation* distinct = nullptr ;
21062164 if (isDistinct) {
2107- result = makeDistinct (result);
2165+ result = makeDistinct (result, inputStates[ 0 ] );
21082166 distinct = result->as <Aggregation>();
21092167 }
21102168 return unionPlan (inputStates, result, distinct);
@@ -2135,7 +2193,7 @@ PlanP Optimization::makeUnionPlan(
21352193 RelationOpPtr result = make<UnionAll>(inputs);
21362194 Aggregation* distinct = nullptr ;
21372195 if (isDistinct) {
2138- result = makeDistinct (result);
2196+ result = makeDistinct (result, inputStates[ 0 ] );
21392197 distinct = result->as <Aggregation>();
21402198 }
21412199 return unionPlan (inputStates, result, distinct);
0 commit comments