diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 4091bbc53..eec03df73 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -23,12 +23,14 @@
     Filter,
     Index,
     Isin,
+    MaybeAlignPartitions,
     PartitionsFiltered,
     Projection,
     Unaryop,
     determine_column_projection,
     is_filter_pushdown_available,
 )
+from dask_expr._reductions import Reduction
 from dask_expr._repartition import Repartition
 from dask_expr._shuffle import (
     RearrangeByColumn,
@@ -36,6 +38,7 @@
     _select_columns_or_index,
 )
 from dask_expr._util import _convert_to_list, _tokenize_deterministic, is_scalar
+from dask_expr.io import IO

 _HASH_COLUMN_NAME = "__hash_partition"
 _PARTITION_COLUMN = "_partitions"
@@ -135,11 +138,59 @@ def _meta(self):
         kwargs["how"] = "left"
         return make_meta(left.merge(right, **kwargs))

+    def _find_partition_changer(self, expr):
+        # Look for an operation that reorganizes the number of partitions.
+        # We ignore Blockwise operations and reductions.
+        stack = [expr]
+        seen = set()
+        result_nodes = []
+        while stack:
+            node = stack.pop()
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, Reduction):
+                continue
+            elif node.ndim == 0 or node.npartitions == 1:
+                continue
+            elif isinstance(node, IO):
+                return node
+            elif isinstance(node, (Blockwise, MaybeAlignPartitions)):
+                stack.extend(node.dependencies())
+                continue
+
+            result_nodes.append(node)
+        if len(result_nodes):
+            # The node with the maximum number of partitions will most likely
+            # have dominated the resulting partition count.
+            return list(sorted(result_nodes, key=lambda x: x.npartitions))[-1]
+        return expr
+
     @functools.cached_property
     def _npartitions(self):
         if self.operand("_npartitions") is not None:
             return self.operand("_npartitions")
-        return max(self.left.npartitions, self.right.npartitions)
+        if min(self.left.npartitions, self.right.npartitions) == 1:
+            return max(self.left.npartitions, self.right.npartitions)
+        if self.left.npartitions <= self.right.npartitions:
+            df_lower = self.left
+            df_higher = self.right
+            merge_base_columns = self._find_partition_changer(self.right).columns
+        else:
+            df_lower = self.right
+            df_higher = self.left
+            merge_base_columns = self._find_partition_changer(self.left).columns
+        npartitions = df_higher.npartitions
+        common_merge_columns = []
+        if self.left_on is not None and self.right_on is not None:
+            common_merge_columns = set(_convert_to_list(self.left_on)) & set(
+                _convert_to_list(self.right_on)
+            )
+        factor = (
+            len(df_lower.columns) + len(df_higher.columns) - len(common_merge_columns)
+        ) / len(merge_base_columns)
+        return int(math.floor(npartitions * factor))

     @property
     def _bcast_left(self):
@@ -796,6 +847,12 @@ class BlockwiseMerge(Merge, Blockwise):

     is_broadcast_join = False

+    @functools.cached_property
+    def _npartitions(self):
+        if self.operand("_npartitions") is not None:
+            return self.operand("_npartitions")
+        return max(self.left.npartitions, self.right.npartitions)
+
     def _divisions(self):
         if self.left.npartitions == self.right.npartitions:
             return super()._divisions()
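Note on the new Merge._npartitions heuristic: when neither input has a single partition, the output partition count is no longer max(left.npartitions, right.npartitions). Instead, the larger side's partition count is scaled by how much wider the merged frame is than the expression that originally determined the partitioning (the "partition changer", typically the IO node). A minimal standalone sketch of that arithmetic follows; estimated_merge_npartitions and its argument names are illustrative, not part of the patch:

    import math

    def estimated_merge_npartitions(
        lower_ncols,         # column count of the side with fewer partitions
        higher_ncols,        # column count of the side with more partitions
        n_common_keys,       # len(set(left_on) & set(right_on)); 0 when unset
        base_ncols,          # column count of the partition changer (e.g. IO node)
        higher_npartitions,  # partition count of the larger side
    ):
        # Width of the merged output relative to the width the original
        # partitioning was sized for.
        factor = (lower_ncols + higher_ncols - n_common_keys) / base_ncols
        return int(math.floor(higher_npartitions * factor))

    # Mirrors df1.merge(df2, on="a") from test_merge_npartitions_adjustment below:
    # both inputs have 4 columns and 10 partitions, they share one join key,
    # and the IO node feeding the larger side also has 4 columns.
    assert estimated_merge_npartitions(4, 4, 1, 4, 10) == 17  # floor(10 * 7 / 4)
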
diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py
index 4c1813eac..281cb92e4 100644
--- a/dask_expr/tests/test_distributed.py
+++ b/dask_expr/tests/test_distributed.py
@@ -51,7 +51,7 @@ async def test_merge_p2p_shuffle(c, s, a, b, npartitions_left):
     right = from_pandas(df_right, npartitions=5)

     out = left.merge(right, shuffle_method="p2p")
-    assert out.npartitions == npartitions_left
+    assert out.npartitions == 8
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(x.reset_index(drop=True), df_left.merge(df_right))
@@ -88,7 +88,7 @@
     result = df.join(df2, shuffle_method=shuffle)

     x = await c.compute(result)
-    assert result.npartitions == 3
+    assert result.npartitions == 6
     pd.testing.assert_frame_equal(x.sort_index(ascending=False), pdf.join(pdf2))
@@ -222,7 +222,7 @@
     right = from_pandas(df_right, npartitions=5)

     out = left.merge(right, left_index=True, right_on="a", shuffle_method="p2p")
-    assert out.npartitions == npartitions_left
+    assert out.npartitions == (7 if npartitions_left == 5 else 18)
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(
@@ -239,7 +239,7 @@
     right = from_pandas(df_right, npartitions=5)

     out = left.merge(right, shuffle_method="p2p")[["b", "c"]]
-    assert out.npartitions == 6
+    assert out.npartitions == 8
     x = c.compute(out)
     x = await x
     pd.testing.assert_frame_equal(
diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index 048ab3a15..56faa0fd9 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -231,8 +231,8 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
     query["new"] = query.b + query.c
     query = query.groupby(["a", "e", "x"]).new.sum()
     assert (
-        len(query.optimize().__dask_graph__()) <= 25
-    )  # 45 is the non-combined version
+        len(query.optimize().__dask_graph__()) <= 37
+    )  # the non-combined version is higher

     expected = pdf.merge(pdf2)
     expected["new"] = expected.b + expected.c
@@ -899,3 +899,43 @@ def test_merge_leftsemi():
     df2 = from_pandas(pdf2, npartitions=2)
     with pytest.raises(NotImplementedError, match="on columns from the index"):
         df1.merge(df2, how="leftsemi", on="aa")
+
+
+def test_merge_npartitions_adjustment():
+    pdf1 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "c": 1, "d": 1}
+    )
+    pdf2 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "x": 1, "y": 1}
+    )
+    pdf3 = pd.DataFrame(
+        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "m": 1, "n": 1}
+    )
+    df1 = from_pandas(pdf1, npartitions=10)
+    df2 = from_pandas(pdf2, npartitions=10)
+    df3 = from_pandas(pdf3, npartitions=10)
+    result = df1.merge(df2, on="a")
+    assert result.optimize().npartitions == 17
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = df1.merge(df2, left_on=["a", "c"], right_on=["b", "x"])
+    assert result.optimize().npartitions == 20
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result[["a", "b"]].merge(df3)
+    assert result.optimize().npartitions == 10
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result[["a", "b", "x", "y"]].merge(df3)
+    assert result.optimize().npartitions == 15
+
+    result = df1.merge(df2)
+    assert result.optimize().npartitions == 15
+    result = result.dropna()  # block projections
+    result = result + result.a.sum()
+    result = result[["a", "b", "x", "y"]].merge(df3)
+    assert result.optimize().npartitions == 15
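For reference, the first scenario from test_merge_npartitions_adjustment can be reproduced end to end as below (assuming a dask-expr build that includes this patch):

    import pandas as pd
    from dask_expr import from_pandas

    pdf1 = pd.DataFrame(
        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "c": 1, "d": 1}
    )
    pdf2 = pd.DataFrame(
        {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 2, "b": 1, "x": 1, "y": 1}
    )
    df1 = from_pandas(pdf1, npartitions=10)
    df2 = from_pandas(pdf2, npartitions=10)

    # Previously the shuffle merge kept max(10, 10) == 10 output partitions.
    # Merging on the shared columns ["a", "b"] now applies the width factor
    # (4 + 4 - 2) / 4 == 1.5, so the optimized plan has 15 partitions.
    assert df1.merge(df2).optimize().npartitions == 15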