Skip to content

Commit 84c2ad0

Browse files
oerlingmeta-codesync[bot]
authored andcommitted
Add merge join support
Summary: This diff implements merge join as a new physical join algorithm in the Axiom optimizer. Merge join is a highly efficient join method for pre-sorted inputs, avoiding the hash table construction overhead of hash joins. When both join inputs are sorted on the join keys (or can be cheaply sorted), merge join streams through both inputs in lockstep, matching rows with equal keys. The applicability is limited to tables which are already copartitioned and where the join keys contain all the bucketing keys. For Hive partitioned tables, both sides must have a single hive partition or must have all hive partitioning columns on both sides in the join keys. Main changes: 1. **Added joinByMerge() method** (Optimization.cpp, ~300 lines): - Core merge join candidate generation logic - Validates merge join preconditions: - Checks all join keys are columns (not expressions) - merge join requires direct column comparisons - Verifies left input has both partitioning and ordering (from distribution) - Confirms ordering is ascending (merge join requires monotonic order) - Validates all partition columns are in join keys (ensures copartitioning) - Identifies merge columns from left input's orderKeys that match join keys - Plans right side with matching partition and ordering distributions - Adds shuffle and sort operators to right side if needed - Computes join using merge method - Returns NextJoin candidate for cost comparison 2. **Merge join precondition checking**: - **Column-only keys**: Rejects if any join key is an expression (e.g., `CAST(orderkey AS BIGINT)`) - **Partitioning requirement**: Left input must be partitioned (not gathered) - **Ordering requirement**: Left input must have orderKeys specified - **Ascending order**: Only kAscNullsFirst and kAscNullsLast supported - **Partition subset**: All partition columns must appear in join keys - **Matching merge columns**: At least one orderKey must match a join key 3. **Right side preparation algorithm**: - Constructs Distribution for right side matching left's partition and ordering - For each left partition column, finds corresponding right join key - For each left merge column (from orderKeys), finds corresponding right join key - Creates `forRight` distribution with: - Same DistributionType as left (enables copartitioning) - rightPartition: right keys corresponding to left partition columns - rightOrderKeys: right keys corresponding to left merge columns - rightOrderTypes: ascending order (kAscNullsLast) to match left - Calls `makePlan()` with forRight distribution - If `needsShuffle` is true, adds Repartition operator - Checks if right input needs sorting (orderKeys don't match expected) - Adds OrderBy operator if needed, after shuffle or directly 4. **Merge join cost model** (RelationOp.cpp): - Added `setMergeJoinCost()` method to Join class - Cost formula: `3 * kKeyCompareCost * numKeys * min(1, fanout) + rightSideBytes + kHashExtractColumnCost * numRightSideColumns` - Rationale: - Key comparisons: Merge join compares keys 3 times on average per match (binary search in merge) - Scales with number of keys and fanout (more comparisons for multiple matches) - Data copying: Transfers right side bytes to output - Column extraction: Extracts columns from right side vectors - Significantly cheaper than hash join for large inputs (no hash table construction) - Cost difference grows with build side size (hash table cost is O(n log n), merge is O(n)) 5. **Integration with join planning** (Optimization.cpp): - Modified `makeJoins()` to call `joinByMerge()` after `joinByIndex()` - Added `testingUseMergeJoin` option for testing: - `std::nullopt` (default): Normal cost-based selection among all join types - `true`: Prefer merge join - return immediately if joinByMerge produces a candidate - `false`: Disable merge join - skip calling joinByMerge entirely - If testing mode is off, merge join competes with hash join based on cost - If testing mode is on and merge join produced a candidate, skip hash join consideration 6. **Schema changes for lookup keys** (Schema.h, Schema.cpp): - Added `lookupColumns` field to ColumnGroup - Distinguished from `orderKeys` in Distribution: - `lookupColumns`: Columns used for index lookups (prefix of sort order) - `orderKeys`: Full sort order (may include additional sorting columns) - The key point is that sortedness does not in and of itself make a table lookup-compatible. - Modified `addIndex()` to accept both `columns` and `lookupColumns` - Updated `indexLookupCardinality()` to use lookupColumns for cardinality estimation - Enables accurate modeling of sorted table access patterns - Example: Table sorted on (orderkey, linenum) can be efficiently joined on just orderkey 7. **Velox plan translation** (ToVelox.cpp, ToVelox.h): - Added `makeMergeJoin()` method to create MergeJoinNode - Checks `join.method == JoinMethod::kMerge` to dispatch to merge join creation - Creates `velox::core::MergeJoinNode` with: - Join type (INNER, LEFT, RIGHT, FULL, SEMI, ANTI) - Left and right keys as field references - Filter expression (for non-equi join conditions) - Left and right child plan nodes - Output type from join columns - Registers prediction and history for cost feedback - MergeJoinNode relies on Velox runtime's merge join operator 8. **Bucketed sorted table creation** (ParquetTpchTest.cpp): - Added `makeBucketedSortedTables()` utility method - Creates `orders_bs`, `lineitem_bs`, `partsupp_bs`, `part_bs` tables - Uses 32 buckets (more buckets than `_b` versions for finer parallelism) - Specifies `sorted_by` property in addition to `bucketed_by` - Example: `orders_bs` is bucketed on `o_orderkey` and sorted on `o_orderkey` within each bucket - Parquet files maintain sort order within partitions - Used for testing merge join on realistic data 9. **Plan matcher support** (PlanMatcher.cpp, PlanMatcherGenerator.cpp): - Added `mergeJoin()` method to PlanMatcherBuilder - Signature: `mergeJoin(matcher, joinType)` similar to `hashJoin()` - Enables test assertions like: ```cpp auto matcher = PlanMatcherBuilder() .tableScan("orders_bs") .mergeJoin(rightMatcher, JoinType::kInner) .build(); ``` - Added merge join code generation in PlanMatcherGenerator - Generates proper `.mergeJoin()` calls when plan contains MergeJoinNode 10. **Testing infrastructure** (OptimizerOptions.h): - Added `testingUseMergeJoin` optional flag - Three modes for comprehensive testing: - `nullopt`: Production mode - cost-based selection - `true`: Force merge join - tests merge join implementation in isolation - `false`: Disable merge join - tests that hash join fallback works - Enables differential testing: run same query with and without merge join The merge join selection algorithm in joinByMerge(): ``` 1. Check preconditions: - All join keys are columns - Left input partitioned and ordered - Order is ascending - Partition columns ⊆ join keys 2. Extract merge columns: - For each left orderKey that matches a join key - Build leftMergeColumns vector 3. Plan right side: - Construct matching Distribution (partition + order) - Call makePlan() to get right input plan - Check if shuffle/sort needed via needsShuffle flag 4. Add shuffle/sort if needed: - If needsShuffle: - Add Repartition on rightPartition - Add OrderBy on rightOrderKeys - Else if ordering doesn't match: - Add OrderBy on rightOrderKeys 5. Create Join operator: - method = JoinMethod::kMerge - Compute cost using setMergeJoinCost() - Return as NextJoin candidate ``` Differential Revision: D89875337
1 parent d1a58be commit 84c2ad0

18 files changed

+1351
-92
lines changed

axiom/optimizer/Optimization.cpp

Lines changed: 495 additions & 0 deletions
Large diffs are not rendered by default.

axiom/optimizer/Optimization.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ class Optimization {
313313
PlanState& state,
314314
std::vector<NextJoin>& toTry);
315315

316+
// Adds 'candidate' on top of 'plan' as a merge join. Checks if the left
317+
// input (plan) is partitioned and ordered, and if join keys match the
318+
// ordering. Prepares the right side with appropriate partitioning and
319+
// ordering, adding shuffle and sort operators as needed.
320+
void joinByMerge(
321+
const RelationOpPtr& plan,
322+
const JoinCandidate& candidate,
323+
PlanState& state,
324+
std::vector<NextJoin>& toTry);
325+
316326
void crossJoin(
317327
const RelationOpPtr& plan,
318328
const JoinCandidate& candidate,

axiom/optimizer/OptimizerOptions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ struct OptimizerOptions {
7575
/// partial + final or not.
7676
bool alwaysPlanPartialAggregation = false;
7777

78+
/// For testing: control merge join behavior.
79+
/// - std::nullopt (default): normal cost-based selection among all join types
80+
/// - true: prefer merge joins - return immediately if joinByMerge produces a
81+
/// candidate
82+
/// - false: disable merge joins - skip calling joinByMerge
83+
std::optional<bool> testingUseMergeJoin{std::nullopt};
84+
7885
bool isMapAsStruct(std::string_view table, std::string_view column) const {
7986
if (allMapsAsStruct) {
8087
return true;

axiom/optimizer/RelationOp.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,16 @@ Join::Join(
403403
float fanout,
404404
float innerFanout,
405405
ColumnVector columns,
406-
PlanState& state)
406+
PlanState& state,
407+
ExprVector discreteJoinFilterExprs)
407408
: RelationOp{RelType::kJoin, std::move(lhs), std::move(columns)},
408409
method{method},
409410
joinType{joinType},
410411
right{std::move(rhs)},
411412
leftKeys{std::move(lhsKeys)},
412413
rightKeys{std::move(rhsKeys)},
413-
filter{std::move(filterExprs)} {
414+
filter{std::move(filterExprs)},
415+
discreteJoinFilter{std::move(discreteJoinFilterExprs)} {
414416
cost_.inputCardinality = inputCardinality();
415417

416418
// Determine optionality for each side
@@ -471,18 +473,24 @@ Join::Join(
471473
}
472474
}
473475

474-
const float buildSize = right->resultCardinality();
475-
const auto numKeys = leftKeys.size();
476-
const auto probeCost = Costs::hashTableCost(buildSize) +
477-
// Multiply by min(fanout, 1) because most misses will not compare and if
478-
// fanout > 1, there is still only one compare.
479-
(Costs::kKeyCompareCost * numKeys * std::min<float>(1, cost_.fanout)) +
480-
numKeys * Costs::kHashColumnCost;
476+
// Compute join cost based on method
477+
if (method == JoinMethod::kMerge) {
478+
setMergeJoinCost();
479+
} else {
480+
// Hash join costing
481+
const float buildSize = right->resultCardinality();
482+
const auto numKeys = leftKeys.size();
483+
const auto probeCost = Costs::hashTableCost(buildSize) +
484+
// Multiply by min(fanout, 1) because most misses will not compare and
485+
// if fanout > 1, there is still only one compare.
486+
(Costs::kKeyCompareCost * numKeys * std::min<float>(1, cost_.fanout)) +
487+
numKeys * Costs::kHashColumnCost;
481488

482-
const auto rowBytes = byteSize(right->columns());
483-
const auto rowCost = Costs::hashRowCost(buildSize, rowBytes);
489+
const auto rowBytes = byteSize(right->columns());
490+
const auto rowCost = Costs::hashRowCost(buildSize, rowBytes);
484491

485-
cost_.unitCost = probeCost + cost_.fanout * rowCost;
492+
cost_.unitCost = probeCost + cost_.fanout * rowCost;
493+
}
486494

487495
// Add constraints for non-key columns from the optional side of an outer join
488496
if (leftOptional || rightOptional) {
@@ -532,6 +540,23 @@ Join::Join(
532540
}
533541
}
534542

543+
void Join::setMergeJoinCost() {
544+
const auto numKeys = leftKeys.size();
545+
546+
// Get right side columns for byte size calculation
547+
const auto rightSideColumns = right->columns();
548+
const auto rightSideBytes = byteSize(rightSideColumns);
549+
const auto numRightSideColumns = rightSideColumns.size();
550+
551+
// Merge join cost formula:
552+
// 3 * key compare cost * number of keys * min(1, fanout) +
553+
// byteSize(rightSideColumns) +
554+
// kHashExtractColumnCost * numRightSideColumns
555+
cost_.unitCost =
556+
3 * Costs::kKeyCompareCost * numKeys * std::min<float>(1, cost_.fanout) +
557+
rightSideBytes + Costs::kHashExtractColumnCost * numRightSideColumns;
558+
}
559+
535560
namespace {
536561
std::pair<std::string, std::string> joinKeysString(
537562
const ExprVector& left,

axiom/optimizer/RelationOp.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ struct Join : public RelationOp {
491491
float fanout,
492492
float innerFanout, // The fanout if this were an inner join
493493
ColumnVector columns,
494-
PlanState& state);
494+
PlanState& state,
495+
ExprVector discreteJoinFilter = {});
495496

496497
static Join* makeCrossJoin(
497498
RelationOpPtr input,
@@ -505,6 +506,7 @@ struct Join : public RelationOp {
505506
const ExprVector leftKeys;
506507
const ExprVector rightKeys;
507508
const ExprVector filter;
509+
const ExprVector discreteJoinFilter;
508510

509511
const QGString& historyKey() const override;
510512

@@ -513,6 +515,9 @@ struct Join : public RelationOp {
513515
void accept(
514516
const RelationOpVisitor& visitor,
515517
RelationOpVisitorContext& context) const override;
518+
519+
private:
520+
void setMergeJoinCost();
516521
};
517522

518523
using JoinCP = const Join*;

axiom/optimizer/Schema.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,15 @@ std::string Value::toString() const {
101101
ColumnGroupCP SchemaTable::addIndex(
102102
const connector::TableLayout& layout,
103103
Distribution distribution,
104-
ColumnVector columns) {
104+
ColumnVector columns,
105+
ColumnVector lookupColumns) {
105106
return columnGroups.emplace_back(
106107
make<ColumnGroup>(
107-
*this, layout, std::move(distribution), std::move(columns)));
108+
*this,
109+
layout,
110+
std::move(distribution),
111+
std::move(columns),
112+
std::move(lookupColumns)));
108113
}
109114

110115
ColumnCP SchemaTable::findColumn(Name name) const {
@@ -187,7 +192,15 @@ SchemaTableCP Schema::findTable(
187192

188193
ColumnVector columns;
189194
appendColumns(layout->columns(), columns);
190-
schemaTable->addIndex(*layout, std::move(distribution), std::move(columns));
195+
196+
ColumnVector lookupColumns;
197+
appendColumns(layout->lookupKeys(), lookupColumns);
198+
199+
schemaTable->addIndex(
200+
*layout,
201+
std::move(distribution),
202+
std::move(columns),
203+
std::move(lookupColumns));
191204
}
192205
table = {std::move(connectorTable), schemaTable};
193206
return schemaTable;
@@ -281,26 +294,27 @@ IndexInfo SchemaTable::indexInfo(
281294

282295
const auto& distribution = index->distribution;
283296

284-
const auto numSorting = distribution.orderTypes.size();
297+
const auto numLookupKeys = index->lookupColumns.size();
285298
const auto numUnique = distribution.numKeysUnique;
286299

287300
PlanObjectSet covered;
288-
for (auto i = 0; i < numSorting || i < numUnique; ++i) {
289-
auto orderKey = distribution.orderKeys[i];
290-
auto part = findColumnByName(columnsSpan, orderKey->as<Column>()->name());
301+
for (auto i = 0; i < numLookupKeys || i < numUnique; ++i) {
302+
ExprCP lookupKey =
303+
i < numLookupKeys ? index->lookupColumns[i] : distribution.orderKeys[i];
304+
auto part = findColumnByName(columnsSpan, lookupKey->as<Column>()->name());
291305
if (!part) {
292306
break;
293307
}
294308

295309
covered.add(part);
296-
if (i < numSorting) {
310+
if (i < numLookupKeys) {
297311
info.scanCardinality =
298-
combine(info.scanCardinality, i, orderKey->value().cardinality);
312+
combine(info.scanCardinality, i, lookupKey->value().cardinality);
299313
info.lookupKeys.push_back(part);
300314
info.joinCardinality = info.scanCardinality;
301315
} else {
302316
info.joinCardinality =
303-
combine(info.joinCardinality, i, orderKey->value().cardinality);
317+
combine(info.joinCardinality, i, lookupKey->value().cardinality);
304318
}
305319
if (i == numUnique - 1) {
306320
info.unique = true;

axiom/optimizer/Schema.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,19 @@ struct ColumnGroup {
231231
const SchemaTable& table,
232232
const connector::TableLayout& layout,
233233
Distribution distribution,
234-
ColumnVector columns)
234+
ColumnVector columns,
235+
ColumnVector lookupColumns)
235236
: table{&table},
236237
layout{&layout},
237238
distribution{std::move(distribution)},
238-
columns{std::move(columns)} {}
239+
columns{std::move(columns)},
240+
lookupColumns{std::move(lookupColumns)} {}
239241

240242
SchemaTableCP table;
241243
const connector::TableLayout* layout;
242244
const Distribution distribution;
243245
const ColumnVector columns;
246+
const ColumnVector lookupColumns;
244247

245248
/// Returns cost of next lookup when the hit is within 'range' rows
246249
/// of the previous hit. If lookups are not batched or not ordered,
@@ -302,7 +305,8 @@ struct SchemaTable {
302305
ColumnGroupCP addIndex(
303306
const connector::TableLayout& layout,
304307
Distribution distribution,
305-
ColumnVector columns);
308+
ColumnVector columns,
309+
ColumnVector lookupColumns);
306310

307311
ColumnCP findColumn(Name name) const;
308312

axiom/optimizer/ToVelox.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,10 @@ velox::core::PlanNodePtr ToVelox::makeJoin(
11121112
nextId(), toAnd(join.filter), joinNode);
11131113
}
11141114

1115+
if (join.method == JoinMethod::kMerge) {
1116+
return makeMergeJoin(join, fragment, stages, left, right);
1117+
}
1118+
11151119
auto leftKeys = toFieldRefs(join.leftKeys);
11161120
auto rightKeys = toFieldRefs(join.rightKeys);
11171121

@@ -1130,6 +1134,29 @@ velox::core::PlanNodePtr ToVelox::makeJoin(
11301134
return joinNode;
11311135
}
11321136

1137+
velox::core::PlanNodePtr ToVelox::makeMergeJoin(
1138+
const Join& join,
1139+
runner::ExecutableFragment& fragment,
1140+
std::vector<runner::ExecutableFragment>& stages,
1141+
velox::core::PlanNodePtr left,
1142+
velox::core::PlanNodePtr right) {
1143+
auto leftKeys = toFieldRefs(join.leftKeys);
1144+
auto rightKeys = toFieldRefs(join.rightKeys);
1145+
1146+
auto joinNode = std::make_shared<velox::core::MergeJoinNode>(
1147+
nextId(),
1148+
join.joinType,
1149+
leftKeys,
1150+
rightKeys,
1151+
toAnd(join.filter),
1152+
left,
1153+
right,
1154+
makeOutputType(join.columns()));
1155+
1156+
makePredictionAndHistory(joinNode->id(), &join);
1157+
return joinNode;
1158+
}
1159+
11331160
velox::core::PlanNodePtr ToVelox::makeUnnest(
11341161
const Unnest& op,
11351162
runner::ExecutableFragment& fragment,

axiom/optimizer/ToVelox.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,13 @@ class ToVelox {
179179
runner::ExecutableFragment& fragment,
180180
std::vector<runner::ExecutableFragment>& stages);
181181

182+
velox::core::PlanNodePtr makeMergeJoin(
183+
const Join& join,
184+
runner::ExecutableFragment& fragment,
185+
std::vector<runner::ExecutableFragment>& stages,
186+
velox::core::PlanNodePtr left,
187+
velox::core::PlanNodePtr right);
188+
182189
velox::core::PlanNodePtr makeRepartition(
183190
const Repartition& repartition,
184191
runner::ExecutableFragment& fragment,

axiom/optimizer/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ add_executable(
120120
HiveLimitQueriesTest.cpp
121121
HiveQueriesTest.cpp
122122
JoinTest.cpp
123+
OrderedOpsTest.cpp
123124
ParquetTpchTest.cpp
124125
PrecomputeProjectionTest.cpp
125126
PlanTest.cpp

0 commit comments

Comments
 (0)