Skip to content

Commit 20c64b9

Browse files
oerlingmeta-codesync[bot]
authored andcommitted
feature: Add merge join support
Differential Revision: D89875337
1 parent 7827933 commit 20c64b9

15 files changed

+796
-25
lines changed

axiom/optimizer/Optimization.cpp

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,302 @@ void Optimization::joinByHashRight(
16741674
state.addNextJoin(&candidate, join, toTry);
16751675
}
16761676

1677+
void Optimization::joinByMerge(
1678+
const RelationOpPtr& plan,
1679+
const JoinCandidate& candidate,
1680+
PlanState& state,
1681+
std::vector<NextJoin>& toTry) {
1682+
checkTables(candidate);
1683+
1684+
// Merge join requires: [right, left] assignment
1685+
auto [right, left] = candidate.joinSides();
1686+
1687+
// Check if all join keys are columns (not expressions)
1688+
for (auto* key : left.keys) {
1689+
if (!key->isColumn()) {
1690+
return; // Cannot do merge join with non-column keys
1691+
}
1692+
}
1693+
for (auto* key : right.keys) {
1694+
if (!key->isColumn()) {
1695+
return; // Cannot do merge join with non-column keys
1696+
}
1697+
}
1698+
1699+
// Check if left side (plan) has partitioning and ordering
1700+
const auto& distribution = plan->distribution();
1701+
if (distribution.partition.empty() || distribution.orderKeys.empty()) {
1702+
return; // Need both partitioning and ordering for merge join
1703+
}
1704+
1705+
// Check that all ordering has ascending order
1706+
for (const auto& orderType : distribution.orderTypes) {
1707+
if (orderType != OrderType::kAscNullsFirst &&
1708+
orderType != OrderType::kAscNullsLast) {
1709+
return; // Merge join requires ascending order
1710+
}
1711+
}
1712+
1713+
// Check if all left partition columns are in left.keys
1714+
ExprVector leftKeyExprs;
1715+
for (auto* col : left.keys) {
1716+
leftKeyExprs.push_back(col);
1717+
}
1718+
1719+
PlanObjectSet leftKeySet;
1720+
leftKeySet.unionObjects(leftKeyExprs);
1721+
1722+
PlanObjectSet partitionSet;
1723+
partitionSet.unionObjects(distribution.partition);
1724+
1725+
if (!partitionSet.isSubset(leftKeySet)) {
1726+
return; // Not all partition columns are in join keys
1727+
}
1728+
1729+
// Find the left keys that correspond to ordering columns
1730+
ColumnVector leftMergeColumns;
1731+
for (const auto& orderKey : distribution.orderKeys) {
1732+
// Find matching column in left.keys
1733+
ColumnCP matchingKey = nullptr;
1734+
for (size_t i = 0; i < left.keys.size(); ++i) {
1735+
if (left.keys[i] == orderKey) {
1736+
matchingKey = dynamic_cast<ColumnCP>(left.keys[i]);
1737+
break;
1738+
}
1739+
}
1740+
1741+
if (!matchingKey) {
1742+
break; // Stop looking if we don't find a match
1743+
}
1744+
1745+
leftMergeColumns.push_back(matchingKey);
1746+
}
1747+
1748+
if (leftMergeColumns.empty()) {
1749+
return; // No ordering columns match join keys
1750+
}
1751+
1752+
// At this point, left input (plan) is a candidate for merge join
1753+
PlanStateSaver save(state, candidate);
1754+
1755+
// Make the right side plan similar to hash join build side
1756+
PlanObjectSet rightFilterColumns;
1757+
rightFilterColumns.unionColumns(candidate.join->filter());
1758+
rightFilterColumns.intersect(availableColumns(candidate.tables[0]));
1759+
1760+
PlanObjectSet rightTables;
1761+
PlanObjectSet rightColumns;
1762+
for (auto rightTable : candidate.tables) {
1763+
rightColumns.unionSet(availableColumns(rightTable));
1764+
rightTables.add(rightTable);
1765+
}
1766+
1767+
auto joinColumnMapping = makeJoinColumnMapping(candidate.join);
1768+
1769+
state.placed.unionSet(rightTables);
1770+
rightColumns.intersect(
1771+
translateToJoinInput(state.downstreamColumns(), joinColumnMapping));
1772+
1773+
rightColumns.unionColumns(right.keys);
1774+
rightColumns.unionSet(rightFilterColumns);
1775+
state.columns.unionSet(rightColumns);
1776+
1777+
MemoKey memoKey{
1778+
candidate.tables[0], rightColumns, rightTables, candidate.existences};
1779+
1780+
// Distribution for right side matching left partition and ordering
1781+
ExprVector rightPartition;
1782+
for (const auto& leftPartCol : distribution.partition) {
1783+
// Find corresponding right key
1784+
for (size_t i = 0; i < left.keys.size(); ++i) {
1785+
if (left.keys[i] == leftPartCol) {
1786+
rightPartition.push_back(right.keys[i]);
1787+
break;
1788+
}
1789+
}
1790+
}
1791+
1792+
ExprVector rightOrderKeys;
1793+
OrderTypeVector rightOrderTypes;
1794+
for (const auto* leftMergeCol : leftMergeColumns) {
1795+
// Find corresponding right key
1796+
for (size_t i = 0; i < left.keys.size(); ++i) {
1797+
if (left.keys[i] == leftMergeCol) {
1798+
rightOrderKeys.push_back(right.keys[i]);
1799+
// Use ascending order to match left side (default to kAscNullsLast)
1800+
rightOrderTypes.push_back(OrderType::kAscNullsLast);
1801+
break;
1802+
}
1803+
}
1804+
}
1805+
1806+
Distribution forRight{
1807+
distribution.distributionType,
1808+
rightPartition,
1809+
rightOrderKeys,
1810+
rightOrderTypes};
1811+
1812+
PlanObjectSet empty;
1813+
bool needsShuffle = false;
1814+
auto rightPlan = makePlan(
1815+
*state.dt,
1816+
memoKey,
1817+
forRight,
1818+
empty,
1819+
candidate.existsFanout,
1820+
needsShuffle);
1821+
1822+
PlanState rightState(state.optimization, state.dt, rightPlan);
1823+
RelationOpPtr rightInput = rightPlan->op;
1824+
RelationOpPtr probeInput = plan;
1825+
1826+
// Handle shuffle and sort for right side
1827+
if (!isSingleWorker_) {
1828+
if (needsShuffle) {
1829+
// Add shuffle
1830+
Distribution distribution{
1831+
plan->distribution().distributionType, rightPartition};
1832+
auto* repartition = make<Repartition>(
1833+
rightInput, std::move(distribution), rightInput->columns());
1834+
rightState.addCost(*repartition);
1835+
rightInput = repartition;
1836+
1837+
// Add sort after shuffle
1838+
PrecomputeProjection precomputeSort(rightInput, state.dt);
1839+
auto orderKeys = precomputeSort.toColumns(rightOrderKeys);
1840+
rightInput = std::move(precomputeSort).maybeProject(rightState);
1841+
1842+
auto* orderBy = make<OrderBy>(
1843+
rightInput,
1844+
std::move(orderKeys),
1845+
rightOrderTypes,
1846+
/*limit=*/-1,
1847+
/*offset=*/0);
1848+
rightState.addCost(*orderBy);
1849+
rightInput = orderBy;
1850+
} else {
1851+
// Check if ordering columns match expected
1852+
const auto& rightDist = rightInput->distribution();
1853+
bool needsSort = rightDist.orderKeys.size() != rightOrderKeys.size();
1854+
if (!needsSort) {
1855+
for (size_t i = 0; i < rightOrderKeys.size(); ++i) {
1856+
if (rightDist.orderKeys[i] != rightOrderKeys[i]) {
1857+
needsSort = true;
1858+
break;
1859+
}
1860+
}
1861+
}
1862+
1863+
if (needsSort) {
1864+
// Add sort without shuffle
1865+
PrecomputeProjection precomputeSort(rightInput, state.dt);
1866+
auto orderKeys = precomputeSort.toColumns(rightOrderKeys);
1867+
rightInput = std::move(precomputeSort).maybeProject(rightState);
1868+
1869+
auto* orderBy = make<OrderBy>(
1870+
rightInput,
1871+
std::move(orderKeys),
1872+
rightOrderTypes,
1873+
/*limit=*/-1,
1874+
/*offset=*/0);
1875+
rightState.addCost(*orderBy);
1876+
rightInput = orderBy;
1877+
}
1878+
}
1879+
}
1880+
1881+
PrecomputeProjection precomputeRight(rightInput, state.dt);
1882+
auto rightKeys = precomputeRight.toColumns(right.keys);
1883+
rightInput = std::move(precomputeRight).maybeProject(rightState);
1884+
1885+
auto joinType = right.leftJoinType();
1886+
const bool probeOnly = joinType == velox::core::JoinType::kLeftSemiFilter ||
1887+
joinType == velox::core::JoinType::kLeftSemiProject ||
1888+
joinType == velox::core::JoinType::kAnti;
1889+
1890+
PlanObjectSet probeColumns;
1891+
probeColumns.unionObjects(plan->columns());
1892+
1893+
ColumnCP mark = nullptr;
1894+
1895+
auto* joinEdge = candidate.join;
1896+
1897+
PlanObjectSet joinColumns;
1898+
joinColumns.unionObjects(joinEdge->leftColumns());
1899+
joinColumns.unionObjects(joinEdge->rightColumns());
1900+
1901+
ProjectionBuilder projectionBuilder;
1902+
bool needsProjection = false;
1903+
1904+
state.downstreamColumns().forEach<Column>([&](auto column) {
1905+
if (column == right.markColumn) {
1906+
mark = column;
1907+
return;
1908+
}
1909+
1910+
if (joinColumns.contains(column)) {
1911+
projectionBuilder.add(column, joinColumnMapping.at(column));
1912+
needsProjection = true;
1913+
return;
1914+
}
1915+
1916+
if ((probeOnly || !rightColumns.contains(column)) &&
1917+
!probeColumns.contains(column)) {
1918+
return;
1919+
}
1920+
1921+
projectionBuilder.add(column, column);
1922+
});
1923+
1924+
if (mark) {
1925+
setMarkTrueFraction(
1926+
mark, joinType, candidate.fanout, candidate.join->rlFanout());
1927+
}
1928+
1929+
tryOptimizeSemiProject(joinType, mark, state, negation_);
1930+
1931+
if (mark) {
1932+
projectionBuilder.add(mark, mark);
1933+
}
1934+
1935+
state.columns = projectionBuilder.outputColumns();
1936+
1937+
const auto fanout = fanoutJoinTypeLimit(
1938+
joinType,
1939+
candidate.fanout,
1940+
candidate.join->rlFanout(),
1941+
rightState.cost.cardinality / state.cost.cardinality);
1942+
1943+
PrecomputeProjection precomputeProbe(probeInput, state.dt);
1944+
auto probeKeys = precomputeProbe.toColumns(left.keys);
1945+
probeInput = std::move(precomputeProbe).maybeProject(state);
1946+
1947+
const auto innerFanout = candidate.fanout;
1948+
1949+
// Create merge join
1950+
RelationOp* join = make<Join>(
1951+
JoinMethod::kMerge,
1952+
joinType,
1953+
probeInput,
1954+
rightInput,
1955+
std::move(probeKeys),
1956+
std::move(rightKeys),
1957+
candidate.join->filter(),
1958+
fanout,
1959+
innerFanout,
1960+
projectionBuilder.inputColumns(),
1961+
state);
1962+
1963+
state.addCost(*join);
1964+
state.cost.cost += rightState.cost.cost;
1965+
1966+
if (needsProjection) {
1967+
join = projectionBuilder.build(join, state);
1968+
}
1969+
1970+
state.addNextJoin(&candidate, join, toTry);
1971+
}
1972+
16771973
void Optimization::crossJoin(
16781974
const RelationOpPtr& plan,
16791975
const JoinCandidate& candidate,
@@ -1746,6 +2042,21 @@ void Optimization::addJoin(
17462042
std::vector<NextJoin> toTry;
17472043
joinByIndex(plan, candidate, state, toTry);
17482044

2045+
// For testing: if testingUseMergeJoin is false, skip merge join entirely.
2046+
if (!options_.testingUseMergeJoin.has_value() ||
2047+
options_.testingUseMergeJoin.value()) {
2048+
const auto sizeBeforeMerge = toTry.size();
2049+
joinByMerge(plan, candidate, state, toTry);
2050+
2051+
// For testing: if testingUseMergeJoin is true and joinByMerge produced a
2052+
// result, return immediately without trying other join methods.
2053+
if (options_.testingUseMergeJoin.has_value() &&
2054+
options_.testingUseMergeJoin.value() && toTry.size() > sizeBeforeMerge) {
2055+
result.insert(result.end(), toTry.begin(), toTry.end());
2056+
return;
2057+
}
2058+
}
2059+
17492060
const auto sizeAfterIndex = toTry.size();
17502061
joinByHash(plan, candidate, state, toTry);
17512062

axiom/optimizer/Optimization.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,16 @@ class Optimization {
314314
PlanState& state,
315315
std::vector<NextJoin>& toTry);
316316

317+
// Adds 'candidate' on top of 'plan' as a merge join. Checks if the left
318+
// input (plan) is partitioned and ordered, and if join keys match the
319+
// ordering. Prepares the right side with appropriate partitioning and
320+
// ordering, adding shuffle and sort operators as needed.
321+
void joinByMerge(
322+
const RelationOpPtr& plan,
323+
const JoinCandidate& candidate,
324+
PlanState& state,
325+
std::vector<NextJoin>& toTry);
326+
317327
void crossJoin(
318328
const RelationOpPtr& plan,
319329
const JoinCandidate& candidate,

axiom/optimizer/OptimizerOptions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ struct OptimizerOptions {
7575
/// partial + final or not.
7676
bool alwaysPlanPartialAggregation = false;
7777

78+
/// For testing: control merge join behavior.
79+
/// - std::nullopt (default): normal cost-based selection among all join types
80+
/// - true: prefer merge joins - return immediately if joinByMerge produces a
81+
/// candidate
82+
/// - false: disable merge joins - skip calling joinByMerge
83+
std::optional<bool> testingUseMergeJoin{std::nullopt};
84+
7885
bool isMapAsStruct(std::string_view table, std::string_view column) const {
7986
if (allMapsAsStruct) {
8087
return true;

0 commit comments

Comments
 (0)