@@ -1674,6 +1674,302 @@ void Optimization::joinByHashRight(
16741674 state.addNextJoin (&candidate, join, toTry);
16751675}
16761676
1677+ void Optimization::joinByMerge (
1678+ const RelationOpPtr& plan,
1679+ const JoinCandidate& candidate,
1680+ PlanState& state,
1681+ std::vector<NextJoin>& toTry) {
1682+ checkTables (candidate);
1683+
1684+ // Merge join requires: [right, left] assignment
1685+ auto [right, left] = candidate.joinSides ();
1686+
1687+ // Check if all join keys are columns (not expressions)
1688+ for (auto * key : left.keys ) {
1689+ if (!key->isColumn ()) {
1690+ return ; // Cannot do merge join with non-column keys
1691+ }
1692+ }
1693+ for (auto * key : right.keys ) {
1694+ if (!key->isColumn ()) {
1695+ return ; // Cannot do merge join with non-column keys
1696+ }
1697+ }
1698+
1699+ // Check if left side (plan) has partitioning and ordering
1700+ const auto & distribution = plan->distribution ();
1701+ if (distribution.partition .empty () || distribution.orderKeys .empty ()) {
1702+ return ; // Need both partitioning and ordering for merge join
1703+ }
1704+
1705+ // Check that all ordering has ascending order
1706+ for (const auto & orderType : distribution.orderTypes ) {
1707+ if (orderType != OrderType::kAscNullsFirst &&
1708+ orderType != OrderType::kAscNullsLast ) {
1709+ return ; // Merge join requires ascending order
1710+ }
1711+ }
1712+
1713+ // Check if all left partition columns are in left.keys
1714+ ExprVector leftKeyExprs;
1715+ for (auto * col : left.keys ) {
1716+ leftKeyExprs.push_back (col);
1717+ }
1718+
1719+ PlanObjectSet leftKeySet;
1720+ leftKeySet.unionObjects (leftKeyExprs);
1721+
1722+ PlanObjectSet partitionSet;
1723+ partitionSet.unionObjects (distribution.partition );
1724+
1725+ if (!partitionSet.isSubset (leftKeySet)) {
1726+ return ; // Not all partition columns are in join keys
1727+ }
1728+
1729+ // Find the left keys that correspond to ordering columns
1730+ ColumnVector leftMergeColumns;
1731+ for (const auto & orderKey : distribution.orderKeys ) {
1732+ // Find matching column in left.keys
1733+ ColumnCP matchingKey = nullptr ;
1734+ for (size_t i = 0 ; i < left.keys .size (); ++i) {
1735+ if (left.keys [i] == orderKey) {
1736+ matchingKey = dynamic_cast <ColumnCP>(left.keys [i]);
1737+ break ;
1738+ }
1739+ }
1740+
1741+ if (!matchingKey) {
1742+ break ; // Stop looking if we don't find a match
1743+ }
1744+
1745+ leftMergeColumns.push_back (matchingKey);
1746+ }
1747+
1748+ if (leftMergeColumns.empty ()) {
1749+ return ; // No ordering columns match join keys
1750+ }
1751+
1752+ // At this point, left input (plan) is a candidate for merge join
1753+ PlanStateSaver save (state, candidate);
1754+
1755+ // Make the right side plan similar to hash join build side
1756+ PlanObjectSet rightFilterColumns;
1757+ rightFilterColumns.unionColumns (candidate.join ->filter ());
1758+ rightFilterColumns.intersect (availableColumns (candidate.tables [0 ]));
1759+
1760+ PlanObjectSet rightTables;
1761+ PlanObjectSet rightColumns;
1762+ for (auto rightTable : candidate.tables ) {
1763+ rightColumns.unionSet (availableColumns (rightTable));
1764+ rightTables.add (rightTable);
1765+ }
1766+
1767+ auto joinColumnMapping = makeJoinColumnMapping (candidate.join );
1768+
1769+ state.placed .unionSet (rightTables);
1770+ rightColumns.intersect (
1771+ translateToJoinInput (state.downstreamColumns (), joinColumnMapping));
1772+
1773+ rightColumns.unionColumns (right.keys );
1774+ rightColumns.unionSet (rightFilterColumns);
1775+ state.columns .unionSet (rightColumns);
1776+
1777+ MemoKey memoKey{
1778+ candidate.tables [0 ], rightColumns, rightTables, candidate.existences };
1779+
1780+ // Distribution for right side matching left partition and ordering
1781+ ExprVector rightPartition;
1782+ for (const auto & leftPartCol : distribution.partition ) {
1783+ // Find corresponding right key
1784+ for (size_t i = 0 ; i < left.keys .size (); ++i) {
1785+ if (left.keys [i] == leftPartCol) {
1786+ rightPartition.push_back (right.keys [i]);
1787+ break ;
1788+ }
1789+ }
1790+ }
1791+
1792+ ExprVector rightOrderKeys;
1793+ OrderTypeVector rightOrderTypes;
1794+ for (const auto * leftMergeCol : leftMergeColumns) {
1795+ // Find corresponding right key
1796+ for (size_t i = 0 ; i < left.keys .size (); ++i) {
1797+ if (left.keys [i] == leftMergeCol) {
1798+ rightOrderKeys.push_back (right.keys [i]);
1799+ // Use ascending order to match left side (default to kAscNullsLast)
1800+ rightOrderTypes.push_back (OrderType::kAscNullsLast );
1801+ break ;
1802+ }
1803+ }
1804+ }
1805+
1806+ Distribution forRight{
1807+ distribution.distributionType ,
1808+ rightPartition,
1809+ rightOrderKeys,
1810+ rightOrderTypes};
1811+
1812+ PlanObjectSet empty;
1813+ bool needsShuffle = false ;
1814+ auto rightPlan = makePlan (
1815+ *state.dt ,
1816+ memoKey,
1817+ forRight,
1818+ empty,
1819+ candidate.existsFanout ,
1820+ needsShuffle);
1821+
1822+ PlanState rightState (state.optimization , state.dt , rightPlan);
1823+ RelationOpPtr rightInput = rightPlan->op ;
1824+ RelationOpPtr probeInput = plan;
1825+
1826+ // Handle shuffle and sort for right side
1827+ if (!isSingleWorker_) {
1828+ if (needsShuffle) {
1829+ // Add shuffle
1830+ Distribution distribution{
1831+ plan->distribution ().distributionType , rightPartition};
1832+ auto * repartition = make<Repartition>(
1833+ rightInput, std::move (distribution), rightInput->columns ());
1834+ rightState.addCost (*repartition);
1835+ rightInput = repartition;
1836+
1837+ // Add sort after shuffle
1838+ PrecomputeProjection precomputeSort (rightInput, state.dt );
1839+ auto orderKeys = precomputeSort.toColumns (rightOrderKeys);
1840+ rightInput = std::move (precomputeSort).maybeProject (rightState);
1841+
1842+ auto * orderBy = make<OrderBy>(
1843+ rightInput,
1844+ std::move (orderKeys),
1845+ rightOrderTypes,
1846+ /* limit=*/ -1 ,
1847+ /* offset=*/ 0 );
1848+ rightState.addCost (*orderBy);
1849+ rightInput = orderBy;
1850+ } else {
1851+ // Check if ordering columns match expected
1852+ const auto & rightDist = rightInput->distribution ();
1853+ bool needsSort = rightDist.orderKeys .size () != rightOrderKeys.size ();
1854+ if (!needsSort) {
1855+ for (size_t i = 0 ; i < rightOrderKeys.size (); ++i) {
1856+ if (rightDist.orderKeys [i] != rightOrderKeys[i]) {
1857+ needsSort = true ;
1858+ break ;
1859+ }
1860+ }
1861+ }
1862+
1863+ if (needsSort) {
1864+ // Add sort without shuffle
1865+ PrecomputeProjection precomputeSort (rightInput, state.dt );
1866+ auto orderKeys = precomputeSort.toColumns (rightOrderKeys);
1867+ rightInput = std::move (precomputeSort).maybeProject (rightState);
1868+
1869+ auto * orderBy = make<OrderBy>(
1870+ rightInput,
1871+ std::move (orderKeys),
1872+ rightOrderTypes,
1873+ /* limit=*/ -1 ,
1874+ /* offset=*/ 0 );
1875+ rightState.addCost (*orderBy);
1876+ rightInput = orderBy;
1877+ }
1878+ }
1879+ }
1880+
1881+ PrecomputeProjection precomputeRight (rightInput, state.dt );
1882+ auto rightKeys = precomputeRight.toColumns (right.keys );
1883+ rightInput = std::move (precomputeRight).maybeProject (rightState);
1884+
1885+ auto joinType = right.leftJoinType ();
1886+ const bool probeOnly = joinType == velox::core::JoinType::kLeftSemiFilter ||
1887+ joinType == velox::core::JoinType::kLeftSemiProject ||
1888+ joinType == velox::core::JoinType::kAnti ;
1889+
1890+ PlanObjectSet probeColumns;
1891+ probeColumns.unionObjects (plan->columns ());
1892+
1893+ ColumnCP mark = nullptr ;
1894+
1895+ auto * joinEdge = candidate.join ;
1896+
1897+ PlanObjectSet joinColumns;
1898+ joinColumns.unionObjects (joinEdge->leftColumns ());
1899+ joinColumns.unionObjects (joinEdge->rightColumns ());
1900+
1901+ ProjectionBuilder projectionBuilder;
1902+ bool needsProjection = false ;
1903+
1904+ state.downstreamColumns ().forEach <Column>([&](auto column) {
1905+ if (column == right.markColumn ) {
1906+ mark = column;
1907+ return ;
1908+ }
1909+
1910+ if (joinColumns.contains (column)) {
1911+ projectionBuilder.add (column, joinColumnMapping.at (column));
1912+ needsProjection = true ;
1913+ return ;
1914+ }
1915+
1916+ if ((probeOnly || !rightColumns.contains (column)) &&
1917+ !probeColumns.contains (column)) {
1918+ return ;
1919+ }
1920+
1921+ projectionBuilder.add (column, column);
1922+ });
1923+
1924+ if (mark) {
1925+ setMarkTrueFraction (
1926+ mark, joinType, candidate.fanout , candidate.join ->rlFanout ());
1927+ }
1928+
1929+ tryOptimizeSemiProject (joinType, mark, state, negation_);
1930+
1931+ if (mark) {
1932+ projectionBuilder.add (mark, mark);
1933+ }
1934+
1935+ state.columns = projectionBuilder.outputColumns ();
1936+
1937+ const auto fanout = fanoutJoinTypeLimit (
1938+ joinType,
1939+ candidate.fanout ,
1940+ candidate.join ->rlFanout (),
1941+ rightState.cost .cardinality / state.cost .cardinality );
1942+
1943+ PrecomputeProjection precomputeProbe (probeInput, state.dt );
1944+ auto probeKeys = precomputeProbe.toColumns (left.keys );
1945+ probeInput = std::move (precomputeProbe).maybeProject (state);
1946+
1947+ const auto innerFanout = candidate.fanout ;
1948+
1949+ // Create merge join
1950+ RelationOp* join = make<Join>(
1951+ JoinMethod::kMerge ,
1952+ joinType,
1953+ probeInput,
1954+ rightInput,
1955+ std::move (probeKeys),
1956+ std::move (rightKeys),
1957+ candidate.join ->filter (),
1958+ fanout,
1959+ innerFanout,
1960+ projectionBuilder.inputColumns (),
1961+ state);
1962+
1963+ state.addCost (*join);
1964+ state.cost .cost += rightState.cost .cost ;
1965+
1966+ if (needsProjection) {
1967+ join = projectionBuilder.build (join, state);
1968+ }
1969+
1970+ state.addNextJoin (&candidate, join, toTry);
1971+ }
1972+
16771973void Optimization::crossJoin (
16781974 const RelationOpPtr& plan,
16791975 const JoinCandidate& candidate,
@@ -1746,6 +2042,21 @@ void Optimization::addJoin(
17462042 std::vector<NextJoin> toTry;
17472043 joinByIndex (plan, candidate, state, toTry);
17482044
2045+ // For testing: if testingUseMergeJoin is false, skip merge join entirely.
2046+ if (!options_.testingUseMergeJoin .has_value () ||
2047+ options_.testingUseMergeJoin .value ()) {
2048+ const auto sizeBeforeMerge = toTry.size ();
2049+ joinByMerge (plan, candidate, state, toTry);
2050+
2051+ // For testing: if testingUseMergeJoin is true and joinByMerge produced a
2052+ // result, return immediately without trying other join methods.
2053+ if (options_.testingUseMergeJoin .has_value () &&
2054+ options_.testingUseMergeJoin .value () && toTry.size () > sizeBeforeMerge) {
2055+ result.insert (result.end (), toTry.begin (), toTry.end ());
2056+ return ;
2057+ }
2058+ }
2059+
17492060 const auto sizeAfterIndex = toTry.size ();
17502061 joinByHash (plan, candidate, state, toTry);
17512062
0 commit comments