From 838513cfb6e8c72f777ab3f75b0d997df5975b2e Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 15 Jan 2026 04:29:35 -0800 Subject: [PATCH 01/22] feat: add strict option in StreamingRepartition Signed-off-by: machichima --- .../data/_internal/logical/operators/map_operator.py | 2 ++ python/ray/data/_internal/planner/plan_udf_map_op.py | 10 +++++++++- python/ray/data/dataset.py | 2 ++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 0994f1d0d4f9..40ef94c81248 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -426,6 +426,7 @@ def __init__( self, input_op: LogicalOperator, target_num_rows_per_block: int, + strict: bool = False, ): super().__init__( f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block}]", @@ -433,6 +434,7 @@ def __init__( can_modify_num_rows=False, ) self._target_num_rows_per_block = target_num_rows_per_block + self._strict = strict @property def target_num_rows_per_block(self) -> int: diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index c95018f7a180..45bc5fa4abb3 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -195,6 +195,13 @@ def plan_streaming_repartition_op( ) map_transformer = MapTransformer([transform_fn]) + if op._strict: + ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block) + supports_fusion = False + else: + ref_bundler = None + supports_fusion = True + # Disable fusion for streaming repartition with the downstream op. 
operator = MapOperator.create( map_transformer, @@ -202,7 +209,8 @@ def plan_streaming_repartition_op( data_context, name=op.name, compute_strategy=compute, - ref_bundler=StreamingRepartitionRefBundler(op.target_num_rows_per_block), + ref_bundler=ref_bundler, + supports_fusion=supports_fusion, ray_remote_args=op._ray_remote_args, ray_remote_args_fn=op._ray_remote_args_fn, ) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 8cc4f6698352..de814df566c5 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1629,6 +1629,7 @@ def repartition( num_blocks: Optional[int] = None, target_num_rows_per_block: Optional[int] = None, *, + strict: bool = False, shuffle: bool = False, keys: Optional[List[str]] = None, sort: bool = False, @@ -1735,6 +1736,7 @@ def repartition( op = StreamingRepartition( self._logical_plan.dag, target_num_rows_per_block=target_num_rows_per_block, + strict=strict, ) else: op = Repartition( From 08008d13e8fbb35dea35ebd7f7ed9838e0984ce6 Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 15 Jan 2026 05:47:04 -0800 Subject: [PATCH 02/22] feat: enable fusion for non-strict mode Signed-off-by: machichima --- .../logical/rules/operator_fusion.py | 57 +++++++++++++------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 96f94d135dca..aca573728860 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -274,20 +274,24 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # only allow fusion of MapBatches -> StreamingRepartition if isinstance(down_logical_op, StreamingRepartition): - return ( + if not ( isinstance(up_logical_op, MapBatches) and up_logical_op._batch_size is not None and down_logical_op.target_num_rows_per_block is not None and down_logical_op.target_num_rows_per_block > 0 - # When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks. - # See `_fuse_streaming_repartition_operators_in_dag` docstring for details. - # TODO: when the StreamingRepartition supports none_strict_mode, we can fuse - # `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are. - # https://anyscale1.atlassian.net/browse/DATA-1731 - and up_logical_op._batch_size - % down_logical_op.target_num_rows_per_block - == 0 - ) + ): + return False + + # Non-strict mode: can always fuse, no matter what batch_size is. + # This allows fusion without cross-task buffering by using default bundler. + if not down_logical_op.strict: + return True + + # Strict mode: only fuse when batch_size is a multiple of target_num_rows_per_block. + # When batch_size % target == 0, each batch can be perfectly sliced into chunks + # without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag` + # docstring for details. + return up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0 # Other operators cannot fuse with StreamingRepartition. 
if isinstance(up_logical_op, StreamingRepartition): return False @@ -309,11 +313,32 @@ def _get_fused_streaming_repartition_operator( up_logical_op = self._op_map.pop(up_op) assert isinstance(up_logical_op, MapBatches) assert isinstance(down_logical_op, StreamingRepartition) - assert ( - up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0 - ) + batch_size = up_logical_op._batch_size + # Choose ref_bundler and fusion behavior based on strict mode + if down_logical_op.strict: + # Strict mode: use StreamingRepartitionRefBundler for stitching. + # Only works when batch_size % target == 0 (verified in _can_fuse). + assert ( + batch_size % down_logical_op.target_num_rows_per_block == 0 + ), ( + f"Strict mode fusion requires batch_size ({batch_size}) to be " + f"a multiple of target_num_rows_per_block " + f"({down_logical_op.target_num_rows_per_block})" + ) + ref_bundler = StreamingRepartitionRefBundler(batch_size) + # No further fusion because StreamingRepartitionRefBundler is stateful + # and maintains internal buffering state across bundles. + supports_fusion = False + else: + # Non-strict mode: use default pass-through bundler. + # Works with any batch_size without cross-task buffering. + ref_bundler = None + # Can fuse further because the default bundler is stateless + # and processes each bundle independently. + supports_fusion = True + compute = self._fuse_compute_strategy( up_logical_op._compute, down_logical_op._compute ) @@ -336,13 +361,11 @@ def _get_fused_streaming_repartition_operator( up_op.data_context, name=name, compute_strategy=compute, - ref_bundler=StreamingRepartitionRefBundler(batch_size), + ref_bundler=ref_bundler, map_task_kwargs=map_task_kwargs, ray_remote_args=ray_remote_args, ray_remote_args_fn=ray_remote_args_fn, - # For now, we don't want to over-fuse StreamingRepartition with other map operators, - # so the result operator does not support further fusion. - supports_fusion=False, + supports_fusion=supports_fusion, ) op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators) for map_task_kwargs_fn in itertools.chain( From 591b00fcccbcaacf9431efd986ba966a75990e4e Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 19:27:46 +0800 Subject: [PATCH 03/22] fix: .strict to ._strict Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/operator_fusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index aca573728860..3062ca439fc1 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -284,7 +284,7 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # Non-strict mode: can always fuse, no matter what batch_size is. # This allows fusion without cross-task buffering by using default bundler. - if not down_logical_op.strict: + if not down_logical_op._strict: return True # Strict mode: only fuse when batch_size is a multiple of target_num_rows_per_block. @@ -317,7 +317,7 @@ def _get_fused_streaming_repartition_operator( batch_size = up_logical_op._batch_size # Choose ref_bundler and fusion behavior based on strict mode - if down_logical_op.strict: + if down_logical_op._strict: # Strict mode: use StreamingRepartitionRefBundler for stitching. # Only works when batch_size % target == 0 (verified in _can_fuse). 
assert ( @@ -361,6 +361,7 @@ def _get_fused_streaming_repartition_operator( up_op.data_context, name=name, compute_strategy=compute, + min_rows_per_bundle=batch_size, ref_bundler=ref_bundler, map_task_kwargs=map_task_kwargs, ray_remote_args=ray_remote_args, From ae95e039a50cb181aefe08e1aac6824995434e06 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 19:34:19 +0800 Subject: [PATCH 04/22] test: add non strict tests Signed-off-by: machichima --- python/ray/data/tests/test_repartition_e2e.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/python/ray/data/tests/test_repartition_e2e.py b/python/ray/data/tests/test_repartition_e2e.py index 337623c4ca53..b7aecde645ec 100644 --- a/python/ray/data/tests/test_repartition_e2e.py +++ b/python/ray/data/tests/test_repartition_e2e.py @@ -452,6 +452,95 @@ def test_streaming_repartition_with_partial_last_block( ), f"Expected all blocks except last to have 20 rows, got {block_row_counts}" +def test_streaming_repartition_non_strict_mode( + ray_start_regular_shared_2_cpus, + disable_fallback_to_object_extension, +): + """Test non-strict mode streaming repartition behavior. + + This test verifies: + 1. Non-strict mode produces at most 1 block < target per input block + 2. No stitching across input blocks + """ + num_rows = 100 + target = 20 + + # Create dataset with varying block sizes + ds = ray.data.range(num_rows, override_num_blocks=10) # 10 blocks of 10 rows each + + # Non-strict mode: should split each input block independently + ds_non_strict = ds.repartition(target_num_rows_per_block=target, strict=False) + ds_non_strict = ds_non_strict.materialize() + + # Collect block row counts + block_row_counts = [ + metadata.num_rows + for bundle in ds_non_strict.iter_internal_ref_bundles() + for metadata in bundle.metadata + ] + + # Verify non-strict mode behavior: no stitching across input blocks + # For non-strict mode with input blocks of 10 rows and target of 20: + # Each input block (10 rows) should produce exactly 1 block of 10 rows + # (since 10 < 20, no splitting needed, and no stitching with other blocks) + assert sum(block_row_counts) == num_rows, f"Expected {num_rows} total rows" + assert len(block_row_counts) == 10, f"Expected 10 blocks, got {len(block_row_counts)}" + assert all(count == 10 for count in block_row_counts), ( + f"Expected all blocks to have 10 rows (no stitching), got {block_row_counts}" + ) + + +@pytest.mark.parametrize("batch_size", [30, 35, 45]) +def test_streaming_repartition_fusion_non_strict( + ray_start_regular_shared_2_cpus, + disable_fallback_to_object_extension, + batch_size, +): + """Test that non-strict mode can fuse with any batch_size. + + This test verifies: + 1. 
MapBatches -> StreamingRepartition(strict=False) can fuse regardless of batch_size + """ + num_rows = 100 + target = 20 + + def fn(batch): + # Just pass through, but verify we got data + assert len(batch["id"]) > 0, "Batch should not be empty" + return batch + + # Create dataset with 10 blocks (10 rows each) to ensure varied input block sizes + ds = ray.data.range(num_rows, override_num_blocks=10) + + # Non-strict mode should fuse even when batch_size % target != 0 + ds = ds.map_batches(fn, batch_size=batch_size) + ds = ds.repartition(target_num_rows_per_block=target, strict=False) + + # Verify fusion happened + planner = create_planner() + physical_plan = planner.plan(ds._logical_plan) + physical_plan = PhysicalOptimizer().optimize(physical_plan) + physical_op = physical_plan.dag + + assert ( + f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target},strict=False]" + in physical_op.name + ), ( + f"Expected fusion for batch_size={batch_size}, target={target}, " + f"but got operator name: {physical_op.name}" + ) + + # Verify correctness: count total rows and verify output block sizes + assert ds.count() == num_rows, f"Expected {num_rows} rows" + + # In non-strict mode, blocks are NOT guaranteed to be exactly target size + # because no stitching happens across input blocks from map_batches. + # Just verify that data is preserved correctly. + result = sorted([row["id"] for row in ds.take_all()]) + expected = list(range(num_rows)) + assert result == expected, "Data should be preserved correctly after fusion" + + @pytest.mark.timeout(60) def test_streaming_repartition_empty_dataset( ray_start_regular_shared_2_cpus, From 8f7282a1c01e42bb51a948e0f6a9a91333162c99 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 19:34:51 +0800 Subject: [PATCH 05/22] test: set stirct=True for existing tests Signed-off-by: machichima --- python/ray/data/tests/test_repartition_e2e.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/ray/data/tests/test_repartition_e2e.py b/python/ray/data/tests/test_repartition_e2e.py index b7aecde645ec..810c4c094b52 100644 --- a/python/ray/data/tests/test_repartition_e2e.py +++ b/python/ray/data/tests/test_repartition_e2e.py @@ -144,6 +144,7 @@ def test_repartition_target_num_rows_per_block( # Each block is 8 ints ds = ray.data.range(total_rows, override_num_blocks=num_blocks).repartition( target_num_rows_per_block=target_num_rows_per_block, + strict=True, ) num_blocks = 0 @@ -270,16 +271,16 @@ def fn(batch): ds = ds.repartition(num_blocks=2, keys=[partition_col]) # mess up with the block size - ds = ds.repartition(target_num_rows_per_block=30) + ds = ds.repartition(target_num_rows_per_block=30, strict=True) # Verify fusion of StreamingRepartition and MapBatches operators b_s = target_num_rows * n_target_num_rows if streaming_repartition_first: - ds = ds.repartition(target_num_rows_per_block=target_num_rows) + ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) ds = ds.map_batches(fn, batch_size=b_s) else: ds = ds.map_batches(fn, batch_size=b_s) - ds = ds.repartition(target_num_rows_per_block=target_num_rows) + ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) planner = create_planner() physical_plan = planner.plan(ds._logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) @@ -290,7 +291,7 @@ def fn(batch): else: assert ( physical_op.name - == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows}]" + == 
f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]" ) # Write output to local Parquet files partitioned by key @@ -341,18 +342,18 @@ def fn(batch): ds = ds.repartition(num_blocks=2, keys=[partition_col]) # mess up with the block size - ds = ds.repartition(target_num_rows_per_block=30) + ds = ds.repartition(target_num_rows_per_block=30, strict=True) # Verify fusion of StreamingRepartition and MapBatches operators ds = ds.map_batches(fn, batch_size=20) - ds = ds.repartition(target_num_rows_per_block=20) + ds = ds.repartition(target_num_rows_per_block=20, strict=True) planner = create_planner() physical_plan = planner.plan(ds._logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) physical_op = physical_plan.dag assert ( physical_op.name - == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20]" + == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20,strict=True]" ) for block in ds.iter_batches(batch_size=None): @@ -382,6 +383,7 @@ def test_repartition_guarantee_row_num_to_be_exact( ds = ray.data.range(num_rows, override_num_blocks=override_num_blocks) ds = ds.repartition( target_num_rows_per_block=target_num_rows_per_block, + strict=True, ) ds = ds.materialize() @@ -431,7 +433,7 @@ def test_streaming_repartition_with_partial_last_block( table = [{"id": n} for n in range(num_rows)] ds = ray.data.from_items(table) - ds = ds.repartition(target_num_rows_per_block=20) + ds = ds.repartition(target_num_rows_per_block=20, strict=True) ds = ds.materialize() From cdf8f9ff5e4e73bb5ba7d3b034863f001c3d748a Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 19:44:34 +0800 Subject: [PATCH 06/22] fix: include strict=... in operator name Signed-off-by: machichima --- python/ray/data/_internal/logical/operators/map_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 40ef94c81248..2acddb5c5d0c 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -429,7 +429,7 @@ def __init__( strict: bool = False, ): super().__init__( - f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block}]", + f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block},strict={strict}]", input_op, can_modify_num_rows=False, ) From def13b2bb74662a49dc9dfe99013b12cb0f83684 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 19:50:16 +0800 Subject: [PATCH 07/22] fix: min_row_per_bundle and support fusion issue Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/operator_fusion.py | 7 ++++++- python/ray/data/_internal/planner/plan_udf_map_op.py | 4 ---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 3062ca439fc1..8204a26dfda2 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -355,13 +355,18 @@ def _get_fused_streaming_repartition_operator( input_op = input_deps[0] assert up_op.data_context is down_op.data_context + + # In non-strict mode, use min_rows_per_bundle to ensure creating batches with batch_size. + # In strict mode, ref_bundler handles bundling, so do not set min_rows_per_bundle. 
+ min_rows = None if down_logical_op._strict else batch_size + op = MapOperator.create( up_op.get_map_transformer().fuse(down_op.get_map_transformer()), input_op, up_op.data_context, name=name, compute_strategy=compute, - min_rows_per_bundle=batch_size, + min_rows_per_bundle=min_rows, ref_bundler=ref_bundler, map_task_kwargs=map_task_kwargs, ray_remote_args=ray_remote_args, diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 45bc5fa4abb3..b621dd82bf4e 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -197,12 +197,9 @@ def plan_streaming_repartition_op( if op._strict: ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block) - supports_fusion = False else: ref_bundler = None - supports_fusion = True - # Disable fusion for streaming repartition with the downstream op. operator = MapOperator.create( map_transformer, input_physical_dag, @@ -210,7 +207,6 @@ def plan_streaming_repartition_op( name=op.name, compute_strategy=compute, ref_bundler=ref_bundler, - supports_fusion=supports_fusion, ray_remote_args=op._ray_remote_args, ray_remote_args_fn=op._ray_remote_args_fn, ) From dc609e1a7d803eeaa9cb1edde770b8eb281fd3be Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 20:10:37 +0800 Subject: [PATCH 08/22] docs: update docstring Signed-off-by: machichima --- python/ray/data/dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index de814df566c5..65948020e16b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1679,6 +1679,12 @@ def repartition( optimal execution, based on the `target_num_rows_per_block`. This is the current behavior because of the implementation and may change in the future. + strict: If ``True``, ``repartition`` guarantees that all output blocks, + except for the last one, will have exactly ``target_num_rows_per_block`` rows. + If ``False``, ``repartition`` is more relaxed and may produce blocks smaller + than ``target_num_rows_per_block`` without stitching them together. + This parameter is only used with ``target_num_rows_per_block``. + Defaults to ``False``. shuffle: Whether to perform a distributed shuffle during the repartition. When shuffle is enabled, each output block contains a subset of data rows from each input block, which From 2c87758d9bb841ca50c800afea14dc2562159e52 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 19 Jan 2026 20:16:06 +0800 Subject: [PATCH 09/22] feat: validate strict with target_num_rows_per_block Signed-off-by: machichima --- python/ray/data/dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 65948020e16b..c9131d79f1d5 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1721,6 +1721,13 @@ def repartition( warnings.warn( "`shuffle` is ignored when `target_num_rows_per_block` is set." ) + else: + if strict: + # strict is used in row-based repartition only + warnings.warn( + "`strict` is ignored when `target_num_rows_per_block` is not set. " + "Use `target_num_rows_per_block` instead of `num_blocks` to enable `strict` mode." 
+ ) if (num_blocks is None) and (target_num_rows_per_block is None): raise ValueError( From d9d4295e26a66a04bb5606c5338501aa0d4d6ac4 Mon Sep 17 00:00:00 2001 From: machichima Date: Tue, 20 Jan 2026 04:49:45 +0800 Subject: [PATCH 10/22] refactor: precommit Signed-off-by: machichima --- ci/lint/pydoclint-baseline.txt | 1 - .../data/_internal/logical/rules/operator_fusion.py | 9 +++++---- python/ray/data/tests/test_repartition_e2e.py | 10 ++++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/ci/lint/pydoclint-baseline.txt b/ci/lint/pydoclint-baseline.txt index 177f0cff0f13..dc6b0f534367 100644 --- a/ci/lint/pydoclint-baseline.txt +++ b/ci/lint/pydoclint-baseline.txt @@ -1145,7 +1145,6 @@ python/ray/data/_internal/logical/operators/join_operator.py -------------------- python/ray/data/_internal/logical/operators/map_operator.py DOC101: Method `StreamingRepartition.__init__`: Docstring contains fewer arguments than in function signature. - DOC103: Method `StreamingRepartition.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [input_op: LogicalOperator]. -------------------- python/ray/data/_internal/logical/operators/n_ary_operator.py DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": "" diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 8204a26dfda2..687b1ccae4c5 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -291,7 +291,10 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # When batch_size % target == 0, each batch can be perfectly sliced into chunks # without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag` # docstring for details. - return up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0 + return ( + up_logical_op._batch_size % down_logical_op.target_num_rows_per_block + == 0 + ) # Other operators cannot fuse with StreamingRepartition. if isinstance(up_logical_op, StreamingRepartition): return False @@ -320,9 +323,7 @@ def _get_fused_streaming_repartition_operator( if down_logical_op._strict: # Strict mode: use StreamingRepartitionRefBundler for stitching. # Only works when batch_size % target == 0 (verified in _can_fuse). 
- assert ( - batch_size % down_logical_op.target_num_rows_per_block == 0 - ), ( + assert batch_size % down_logical_op.target_num_rows_per_block == 0, ( f"Strict mode fusion requires batch_size ({batch_size}) to be " f"a multiple of target_num_rows_per_block " f"({down_logical_op.target_num_rows_per_block})" diff --git a/python/ray/data/tests/test_repartition_e2e.py b/python/ray/data/tests/test_repartition_e2e.py index 810c4c094b52..17c8d29c060b 100644 --- a/python/ray/data/tests/test_repartition_e2e.py +++ b/python/ray/data/tests/test_repartition_e2e.py @@ -486,10 +486,12 @@ def test_streaming_repartition_non_strict_mode( # Each input block (10 rows) should produce exactly 1 block of 10 rows # (since 10 < 20, no splitting needed, and no stitching with other blocks) assert sum(block_row_counts) == num_rows, f"Expected {num_rows} total rows" - assert len(block_row_counts) == 10, f"Expected 10 blocks, got {len(block_row_counts)}" - assert all(count == 10 for count in block_row_counts), ( - f"Expected all blocks to have 10 rows (no stitching), got {block_row_counts}" - ) + assert ( + len(block_row_counts) == 10 + ), f"Expected 10 blocks, got {len(block_row_counts)}" + assert all( + count == 10 for count in block_row_counts + ), f"Expected all blocks to have 10 rows (no stitching), got {block_row_counts}" @pytest.mark.parametrize("batch_size", [30, 35, 45]) From 04964bc4ca7992c08fc463769284d5daaa65e40c Mon Sep 17 00:00:00 2001 From: machichima Date: Tue, 20 Jan 2026 05:24:56 +0800 Subject: [PATCH 11/22] docs: update docstring Signed-off-by: machichima --- ci/lint/pydoclint-baseline.txt | 3 --- .../ray/data/_internal/logical/operators/map_operator.py | 8 +++++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ci/lint/pydoclint-baseline.txt b/ci/lint/pydoclint-baseline.txt index dc6b0f534367..a28abe851d43 100644 --- a/ci/lint/pydoclint-baseline.txt +++ b/ci/lint/pydoclint-baseline.txt @@ -1143,9 +1143,6 @@ python/ray/data/_internal/logical/operators/join_operator.py DOC101: Method `Join.__init__`: Docstring contains fewer arguments than in function signature. DOC103: Method `Join.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [aggregator_ray_remote_args: Optional[Dict[str, Any]], join_type: str, left_columns_suffix: Optional[str], left_input_op: LogicalOperator, left_key_columns: Tuple[str], num_partitions: int, partition_size_hint: Optional[int], right_columns_suffix: Optional[str], right_input_op: LogicalOperator, right_key_columns: Tuple[str]]. -------------------- -python/ray/data/_internal/logical/operators/map_operator.py - DOC101: Method `StreamingRepartition.__init__`: Docstring contains fewer arguments than in function signature. --------------------- python/ray/data/_internal/logical/operators/n_ary_operator.py DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": "" DOC001: Function/method `__init__`: Potential formatting errors in docstring. Error message: No specification for "Args": "" (Note: DOC001 could trigger other unrelated violations under this function/method too. Please fix the docstring formatting first.) 
diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 2acddb5c5d0c..31c4a9e7d299 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -417,9 +417,15 @@ def __init__( class StreamingRepartition(AbstractMap): """Logical operator for streaming repartition operation. + Args: + input_op: The operator preceding this operator in the plan DAG. target_num_rows_per_block: The target number of rows per block granularity for - streaming repartition. + streaming repartition. + strict: If True, guarantees that all output blocks, except for the last one, + will have exactly target_num_rows_per_block rows. If False, is more relaxed + and may produce blocks smaller than target_num_rows_per_block without + stitching them together. Defaults to False. """ def __init__( From a9fbce02c18a8169b7e59ce45c5558f93b75c120 Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 21 Jan 2026 20:40:55 +0800 Subject: [PATCH 12/22] fix: pass strict param in CombineRepartitions Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/combine_repartitions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/data/_internal/logical/rules/combine_repartitions.py b/python/ray/data/_internal/logical/rules/combine_repartitions.py index a483676436cc..3e97faa5121e 100644 --- a/python/ray/data/_internal/logical/rules/combine_repartitions.py +++ b/python/ray/data/_internal/logical/rules/combine_repartitions.py @@ -35,9 +35,11 @@ def _combine_repartitions(op: LogicalOperator) -> LogicalOperator: elif isinstance(op, StreamingRepartition) and isinstance( input_op, StreamingRepartition ): + strict = input_op._strict or op._strict return StreamingRepartition( input_op.input_dependencies[0], target_num_rows_per_block=op.target_num_rows_per_block, + strict=strict, ) return op From 7b825e5604799e138b19aafb69100326934461d3 Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 21 Jan 2026 20:50:27 +0800 Subject: [PATCH 13/22] fix: verify target_num_rows_per_block in StreamingRepartition Signed-off-by: machichima --- python/ray/data/_internal/logical/operators/map_operator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 31c4a9e7d299..de24ec55ea6f 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -434,6 +434,11 @@ def __init__( target_num_rows_per_block: int, strict: bool = False, ): + if target_num_rows_per_block <= 0: + raise ValueError( + "target_num_rows_per_block must be positive for streaming repartition, " + f"got {target_num_rows_per_block}" + ) super().__init__( f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block},strict={strict}]", input_op, From accb54a80160503af669e57967ad399b42da2bb3 Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 21 Jan 2026 20:59:18 +0800 Subject: [PATCH 14/22] fix: pass strict param in CombineShuffles Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/combine_shuffles.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/data/_internal/logical/rules/combine_shuffles.py b/python/ray/data/_internal/logical/rules/combine_shuffles.py index afed879b370b..0494adf68c1d 100644 --- a/python/ray/data/_internal/logical/rules/combine_shuffles.py +++ 
b/python/ray/data/_internal/logical/rules/combine_shuffles.py @@ -52,9 +52,11 @@ def _combine(self, op: LogicalOperator) -> LogicalOperator: elif isinstance(input_op, StreamingRepartition) and isinstance( op, StreamingRepartition ): + strict = input_op._strict or op._strict return StreamingRepartition( input_op.input_dependencies[0], target_num_rows_per_block=op.target_num_rows_per_block, + strict=strict, ) elif isinstance(input_op, Repartition) and isinstance(op, Aggregate): return Aggregate( From 89965d04765212b1a1da8de6a97bc3ef69dae14d Mon Sep 17 00:00:00 2001 From: machichima Date: Fri, 23 Jan 2026 19:14:36 +0800 Subject: [PATCH 15/22] fix: pass min_rows_per_bundle in non-strict mode Signed-off-by: machichima --- python/ray/data/_internal/planner/plan_udf_map_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index b621dd82bf4e..fc033c43214a 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -197,8 +197,10 @@ def plan_streaming_repartition_op( if op._strict: ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block) + min_rows_per_bundle = None else: ref_bundler = None + min_rows_per_bundle = op.target_num_rows_per_block operator = MapOperator.create( map_transformer, @@ -206,6 +208,7 @@ def plan_streaming_repartition_op( data_context, name=op.name, compute_strategy=compute, + min_rows_per_bundle=min_rows_per_bundle, ref_bundler=ref_bundler, ray_remote_args=op._ray_remote_args, ray_remote_args_fn=op._ray_remote_args_fn, From f748b7985e2307cc06ec9ac7b526680f9a79e4d8 Mon Sep 17 00:00:00 2001 From: machichima Date: Fri, 23 Jan 2026 19:26:28 +0800 Subject: [PATCH 16/22] docs: update docstring Signed-off-by: machichima --- .../ray/data/_internal/logical/operators/map_operator.py | 7 ++++--- python/ray/data/dataset.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index de24ec55ea6f..a398f741e03f 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -423,9 +423,10 @@ class StreamingRepartition(AbstractMap): target_num_rows_per_block: The target number of rows per block granularity for streaming repartition. strict: If True, guarantees that all output blocks, except for the last one, - will have exactly target_num_rows_per_block rows. If False, is more relaxed - and may produce blocks smaller than target_num_rows_per_block without - stitching them together. Defaults to False. + will have exactly target_num_rows_per_block rows. If False, uses best-effort + bundling and may produce at most one block smaller than target_num_rows_per_block + per input block without forcing exact sizes through block splitting. + Defaults to False. """ def __init__( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4e1693ce5adc..cbb6ad0e5a6e 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1692,8 +1692,9 @@ def repartition( the future. strict: If ``True``, ``repartition`` guarantees that all output blocks, except for the last one, will have exactly ``target_num_rows_per_block`` rows. 
- If ``False``, ``repartition`` is more relaxed and may produce blocks smaller - than ``target_num_rows_per_block`` without stitching them together. + If ``False``, ``repartition`` uses best-effort bundling and may produce at most + one block smaller than ``target_num_rows_per_block`` per input block without + forcing exact sizes through block splitting. This parameter is only used with ``target_num_rows_per_block``. Defaults to ``False``. shuffle: Whether to perform a distributed shuffle during the From 49cc5fcdbcf5a7e63608814810a19a8647dc007e Mon Sep 17 00:00:00 2001 From: machichima Date: Fri, 23 Jan 2026 19:41:37 +0800 Subject: [PATCH 17/22] test: set strict=True Signed-off-by: machichima --- python/ray/data/tests/test_operator_fusion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py index 9d71689cbf97..125fad1637b6 100644 --- a/python/ray/data/tests/test_operator_fusion.py +++ b/python/ray/data/tests/test_operator_fusion.py @@ -768,12 +768,12 @@ def test_streaming_repartition_map_batches_fusion_order_and_params( if order == "map_then_sr": ds = ds.map_batches(lambda x: x, batch_size=batch_size) - ds = ds.repartition(target_num_rows_per_block=target_num_rows) - expected_fused_name = f"MapBatches()->StreamingRepartition[num_rows_per_block={target_num_rows}]" + ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) + expected_fused_name = f"MapBatches()->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]" else: # sr_then_map - ds = ds.repartition(target_num_rows_per_block=target_num_rows) + ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) ds = ds.map_batches(lambda x: x, batch_size=batch_size) - expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows}]->MapBatches()" + expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]->MapBatches()" assert len(ds.take_all()) == n @@ -813,7 +813,7 @@ def test_streaming_repartition_no_further_fuse( stats1 = ds1.stats() assert ( - f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows}]" + f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]" in stats1 ), stats1 assert "MapBatches()->MapBatches()" in stats1 From 68d01c4c48a59c7768ec9c2359a1859966c446b6 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 26 Jan 2026 21:11:45 +0800 Subject: [PATCH 18/22] fix: set min_rows_per_bundle to None Signed-off-by: machichima --- python/ray/data/_internal/planner/plan_udf_map_op.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index fc033c43214a..b621dd82bf4e 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -197,10 +197,8 @@ def plan_streaming_repartition_op( if op._strict: ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block) - min_rows_per_bundle = None else: ref_bundler = None - min_rows_per_bundle = op.target_num_rows_per_block operator = MapOperator.create( map_transformer, @@ -208,7 +206,6 @@ def plan_streaming_repartition_op( data_context, name=op.name, compute_strategy=compute, - min_rows_per_bundle=min_rows_per_bundle, ref_bundler=ref_bundler, ray_remote_args=op._ray_remote_args, ray_remote_args_fn=op._ray_remote_args_fn, From 
c77787cde618f8f49756d9f9159c472c0a6520fb Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 5 Feb 2026 18:48:10 +0800 Subject: [PATCH 19/22] fix: update _can_fuse logic for batch size Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/operator_fusion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 22f8638d0ed9..805339654218 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -280,7 +280,6 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: if isinstance(down_logical_op, StreamingRepartition): if not ( isinstance(up_logical_op, MapBatches) - and up_logical_op.batch_size is not None and down_logical_op.target_num_rows_per_block is not None and down_logical_op.target_num_rows_per_block > 0 ): @@ -296,8 +295,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag` # docstring for details. return ( - up_logical_op.batch_size % down_logical_op.target_num_rows_per_block - == 0 + up_logical_op.batch_size is not None + and up_logical_op.batch_size % down_logical_op.target_num_rows_per_block == 0 ) # Other operators cannot fuse with StreamingRepartition. if isinstance(up_logical_op, StreamingRepartition): From 6a2fec81cd09bb034f5e5f040031b3d33b60d128 Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 5 Feb 2026 19:07:03 +0800 Subject: [PATCH 20/22] refactor: precommit Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/operator_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 805339654218..e6796e2b16cb 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -296,7 +296,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # docstring for details. return ( up_logical_op.batch_size is not None - and up_logical_op.batch_size % down_logical_op.target_num_rows_per_block == 0 + and up_logical_op.batch_size % down_logical_op.target_num_rows_per_block + == 0 ) # Other operators cannot fuse with StreamingRepartition. if isinstance(up_logical_op, StreamingRepartition): From 111c05420d79e3fada83ff9f487fb5fca5258d2c Mon Sep 17 00:00:00 2001 From: machichima Date: Sat, 7 Feb 2026 14:07:44 +0800 Subject: [PATCH 21/22] fix: enable fuse with other operations in non-strict mode Signed-off-by: machichima --- python/ray/data/_internal/logical/rules/operator_fusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index e6796e2b16cb..15e05e43d4d2 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -299,9 +299,10 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: and up_logical_op.batch_size % down_logical_op.target_num_rows_per_block == 0 ) - # Other operators cannot fuse with StreamingRepartition. + # StreamingRepartition can only fuse in non-strict mode. 
+ # In strict mode, it does not support further fusion. if isinstance(up_logical_op, StreamingRepartition): - return False + return not up_logical_op._strict # Otherwise, ops are compatible for fusion. return True From 83c5ddbfc84d72d1100349daab9a01924f1fa2df Mon Sep 17 00:00:00 2001 From: machichima Date: Sat, 7 Feb 2026 14:41:42 +0800 Subject: [PATCH 22/22] test: add map>map>sr and map>sr>map case Signed-off-by: machichima --- python/ray/data/tests/test_operator_fusion.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py index c7aabb40d783..5dbd4afa27e3 100644 --- a/python/ray/data/tests/test_operator_fusion.py +++ b/python/ray/data/tests/test_operator_fusion.py @@ -844,6 +844,48 @@ def test_filter_operator_no_upstream_fusion(ray_start_regular_shared_2_cpus, cap assert "TaskPoolMapOperator[MapBatches()->Filter()]" in captured +def test_streaming_repartition_multiple_fusion_non_strict( + ray_start_regular_shared_2_cpus, +): + """Test that non-strict mode allows multiple operators to fuse with StreamingRepartition. + + Case 1: Map > Map > SR (non-strict) + Case 2: Map > SR (non-strict) > Map + """ + n = 100 + target_rows = 20 + + # Case 1: Map > Map > SR (non-strict) + ds1 = ray.data.range(n, override_num_blocks=2) + ds1 = ds1.map_batches(lambda x: x, batch_size=None) + ds1 = ds1.map_batches(lambda x: x, batch_size=None) + ds1 = ds1.repartition(target_num_rows_per_block=target_rows, strict=False) + + assert len(ds1.take_all()) == n + stats1 = ds1.stats() + + # Verify all three operators are fused together + assert ( + f"MapBatches()->MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]" + in stats1 + ), f"Expected full fusion in stats: {stats1}" + + # Case 2: Map > SR (non-strict) > Map + ds2 = ray.data.range(n, override_num_blocks=2) + ds2 = ds2.map_batches(lambda x: x, batch_size=None) + ds2 = ds2.repartition(target_num_rows_per_block=target_rows, strict=False) + ds2 = ds2.map_batches(lambda x: x, batch_size=None) + + assert len(ds2.take_all()) == n + stats2 = ds2.stats() + + # Verify all three operators are fused together + assert ( + f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]->MapBatches()" + in stats2 + ), f"Expected full fusion in stats: {stats2}" + + def test_combine_repartition_aggregate( ray_start_regular_shared_2_cpus, configure_shuffle_method, capsys ):
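
The series above adds a single user-facing knob, so here is a minimal usage sketch of the new strict flag, assuming the same 100-row, 10-block layout the tests use. The block counts in the comments follow the semantics stated in the patched repartition docstring and verified by the tests, not an independent claim:

    import ray

    # 100 rows spread across 10 input blocks of 10 rows each, mirroring the tests.
    ds = ray.data.range(100, override_num_blocks=10)

    # Non-strict (the default): each input block is split independently and nothing
    # is stitched across blocks, so this yields ten 10-row blocks even though the
    # target is 20.
    loose = ds.repartition(target_num_rows_per_block=20, strict=False).materialize()

    # Strict: every output block except possibly the last one has exactly
    # target_num_rows_per_block rows, i.e. five 20-row blocks here.
    exact = ds.repartition(target_num_rows_per_block=20, strict=True).materialize()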
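
The other behavior the later patches keep adjusting is fusion eligibility. A standalone sketch of the decision OperatorFusionRule._can_fuse ends up making for a MapBatches -> StreamingRepartition pair after patch 21; the helper name here is illustrative, not a function that exists in the patched file:

    def can_fuse_map_batches_into_streaming_repartition(
        batch_size, target_num_rows_per_block, strict
    ):
        # Streaming repartition is only fusable when the target is a positive row count.
        if target_num_rows_per_block is None or target_num_rows_per_block <= 0:
            return False
        # Non-strict mode fuses with any upstream MapBatches: the default bundler is
        # stateless, so no cross-task buffering is needed.
        if not strict:
            return True
        # Strict mode only fuses when each batch slices evenly into target-sized
        # chunks, so exact block sizes never require buffering across task boundaries.
        return (
            batch_size is not None
            and batch_size % target_num_rows_per_block == 0
        )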