diff --git a/ci/lint/pydoclint-baseline.txt b/ci/lint/pydoclint-baseline.txt
index 4bcdde7631ea..231f3411d059 100644
--- a/ci/lint/pydoclint-baseline.txt
+++ b/ci/lint/pydoclint-baseline.txt
@@ -1138,10 +1138,6 @@ python/ray/data/_internal/logical/operators/join_operator.py
     DOC101: Method `Join.__init__`: Docstring contains fewer arguments than in function signature.
     DOC103: Method `Join.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [aggregator_ray_remote_args: Optional[Dict[str, Any]], join_type: str, left_columns_suffix: Optional[str], left_input_op: LogicalOperator, left_key_columns: Tuple[str], num_partitions: int, partition_size_hint: Optional[int], right_columns_suffix: Optional[str], right_input_op: LogicalOperator, right_key_columns: Tuple[str]].
 --------------------
-python/ray/data/_internal/logical/operators/map_operator.py
-    DOC101: Method `StreamingRepartition.__init__`: Docstring contains fewer arguments than in function signature.
-    DOC103: Method `StreamingRepartition.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [input_op: LogicalOperator].
---------------------
 python/ray/data/_internal/logical/operators/n_ary_operator.py
     DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": ""
     DOC001: Function/method `__init__`: Potential formatting errors in docstring. Error message: No specification for "Args": "" (Note: DOC001 could trigger other unrelated violations under this function/method too. Please fix the docstring formatting first.)
diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py
index ac029dabf9e3..0b0b73de228e 100644
--- a/python/ray/data/_internal/logical/operators/map_operator.py
+++ b/python/ray/data/_internal/logical/operators/map_operator.py
@@ -425,19 +425,33 @@ def __init__(
 class StreamingRepartition(AbstractMap):
     """Logical operator for streaming repartition operation.
 
+    Args:
+        input_op: The operator preceding this operator in the plan DAG.
         target_num_rows_per_block: The target number of rows per block granularity for
-            streaming repartition.
+            streaming repartition.
+        strict: If True, guarantees that all output blocks, except for the last one,
+            will have exactly target_num_rows_per_block rows. If False, uses best-effort
+            bundling and may produce at most one block smaller than target_num_rows_per_block
+            per input block without forcing exact sizes through block splitting.
+            Defaults to False.
""" def __init__( self, input_op: LogicalOperator, target_num_rows_per_block: int, + strict: bool = False, ): + if target_num_rows_per_block <= 0: + raise ValueError( + "target_num_rows_per_block must be positive for streaming repartition, " + f"got {target_num_rows_per_block}" + ) super().__init__( - f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block}]", + f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block},strict={strict}]", input_op, can_modify_num_rows=False, ) self.target_num_rows_per_block = target_num_rows_per_block + self._strict = strict diff --git a/python/ray/data/_internal/logical/rules/combine_shuffles.py b/python/ray/data/_internal/logical/rules/combine_shuffles.py index d5de5c547af3..401861e3508c 100644 --- a/python/ray/data/_internal/logical/rules/combine_shuffles.py +++ b/python/ray/data/_internal/logical/rules/combine_shuffles.py @@ -56,9 +56,11 @@ def _combine(self, op: LogicalOperator) -> LogicalOperator: elif isinstance(input_op, StreamingRepartition) and isinstance( op, StreamingRepartition ): + strict = input_op._strict or op._strict return StreamingRepartition( input_op.input_dependencies[0], target_num_rows_per_block=op.target_num_rows_per_block, + strict=strict, ) elif isinstance(input_op, Repartition) and isinstance(op, Aggregate): return Aggregate( diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 9288d7db9f02..15e05e43d4d2 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -278,22 +278,31 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # only allow fusion of MapBatches -> StreamingRepartition if isinstance(down_logical_op, StreamingRepartition): - return ( + if not ( isinstance(up_logical_op, MapBatches) - and up_logical_op.batch_size is not None and down_logical_op.target_num_rows_per_block is not None and down_logical_op.target_num_rows_per_block > 0 - # When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks. - # See `_fuse_streaming_repartition_operators_in_dag` docstring for details. - # TODO: when the StreamingRepartition supports none_strict_mode, we can fuse - # `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are. - # https://anyscale1.atlassian.net/browse/DATA-1731 + ): + return False + + # Non-strict mode: can always fuse, no matter what batch_size is. + # This allows fusion without cross-task buffering by using default bundler. + if not down_logical_op._strict: + return True + + # Strict mode: only fuse when batch_size is a multiple of target_num_rows_per_block. + # When batch_size % target == 0, each batch can be perfectly sliced into chunks + # without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag` + # docstring for details. + return ( + up_logical_op.batch_size is not None and up_logical_op.batch_size % down_logical_op.target_num_rows_per_block == 0 ) - # Other operators cannot fuse with StreamingRepartition. + # StreamingRepartition can only fuse in non-strict mode. + # In strict mode, it does not support further fusion. if isinstance(up_logical_op, StreamingRepartition): - return False + return not up_logical_op._strict # Otherwise, ops are compatible for fusion. 
         return True
@@ -312,9 +321,30 @@ def _get_fused_streaming_repartition_operator(
         up_logical_op = self._op_map.pop(up_op)
         assert isinstance(up_logical_op, MapBatches)
         assert isinstance(down_logical_op, StreamingRepartition)
-        assert up_logical_op.batch_size % down_logical_op.target_num_rows_per_block == 0
+        batch_size = up_logical_op.batch_size
+        # Choose ref_bundler and fusion behavior based on strict mode
+        if down_logical_op._strict:
+            # Strict mode: use StreamingRepartitionRefBundler for stitching.
+            # Only works when batch_size % target == 0 (verified in _can_fuse).
+            assert batch_size % down_logical_op.target_num_rows_per_block == 0, (
+                f"Strict mode fusion requires batch_size ({batch_size}) to be "
+                f"a multiple of target_num_rows_per_block "
+                f"({down_logical_op.target_num_rows_per_block})"
+            )
+            ref_bundler = StreamingRepartitionRefBundler(batch_size)
+            # No further fusion because StreamingRepartitionRefBundler is stateful
+            # and maintains internal buffering state across bundles.
+            supports_fusion = False
+        else:
+            # Non-strict mode: use default pass-through bundler.
+            # Works with any batch_size without cross-task buffering.
+            ref_bundler = None
+            # Can fuse further because the default bundler is stateless
+            # and processes each bundle independently.
+            supports_fusion = True
+
         compute = self._fuse_compute_strategy(
             up_logical_op.compute, down_logical_op.compute
         )
@@ -331,19 +361,23 @@
         input_op = input_deps[0]
 
         assert up_op.data_context is down_op.data_context
+
+        # In non-strict mode, use min_rows_per_bundle to ensure creating batches with batch_size.
+        # In strict mode, ref_bundler handles bundling, so do not set min_rows_per_bundle.
+        min_rows = None if down_logical_op._strict else batch_size
+
         op = MapOperator.create(
             up_op.get_map_transformer().fuse(down_op.get_map_transformer()),
             input_op,
             up_op.data_context,
             name=name,
             compute_strategy=compute,
-            ref_bundler=StreamingRepartitionRefBundler(batch_size),
+            min_rows_per_bundle=min_rows,
+            ref_bundler=ref_bundler,
             map_task_kwargs=map_task_kwargs,
             ray_remote_args=ray_remote_args,
             ray_remote_args_fn=ray_remote_args_fn,
-            # For now, we don't want to over-fuse StreamingRepartition with other map operators,
-            # so the result operator does not support further fusion.
-            supports_fusion=False,
+            supports_fusion=supports_fusion,
         )
         op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)
         for map_task_kwargs_fn in itertools.chain(
diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py
index 18e8b136a181..ec51e85f36de 100644
--- a/python/ray/data/_internal/planner/plan_udf_map_op.py
+++ b/python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -195,14 +195,18 @@ def plan_streaming_repartition_op(
     )
     map_transformer = MapTransformer([transform_fn])
 
-    # Disable fusion for streaming repartition with the downstream op.
+    if op._strict:
+        ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block)
+    else:
+        ref_bundler = None
+
     operator = MapOperator.create(
         map_transformer,
         input_physical_dag,
         data_context,
         name=op.name,
         compute_strategy=compute,
-        ref_bundler=StreamingRepartitionRefBundler(op.target_num_rows_per_block),
+        ref_bundler=ref_bundler,
         ray_remote_args=op.ray_remote_args,
         ray_remote_args_fn=op.ray_remote_args_fn,
     )
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index c0867d7427b3..ee908587fcd3 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -1632,6 +1632,7 @@ def repartition(
         num_blocks: Optional[int] = None,
         target_num_rows_per_block: Optional[int] = None,
         *,
+        strict: bool = False,
         shuffle: bool = False,
         keys: Optional[List[str]] = None,
         sort: bool = False,
@@ -1681,6 +1682,13 @@ def repartition(
                 optimal execution, based on the `target_num_rows_per_block`. This
                 is the current behavior because of the implementation and may change
                 in the future.
+            strict: If ``True``, ``repartition`` guarantees that all output blocks,
+                except for the last one, will have exactly ``target_num_rows_per_block`` rows.
+                If ``False``, ``repartition`` uses best-effort bundling and may produce at most
+                one block smaller than ``target_num_rows_per_block`` per input block without
+                forcing exact sizes through block splitting.
+                This parameter is only used with ``target_num_rows_per_block``.
+                Defaults to ``False``.
             shuffle: Whether to perform a distributed shuffle during the
                 repartition. When shuffle is enabled, each output block
                 contains a subset of data rows from each input block, which
@@ -1717,6 +1725,13 @@ def repartition(
                 warnings.warn(
                     "`shuffle` is ignored when `target_num_rows_per_block` is set."
                 )
+        else:
+            if strict:
+                # strict is used in row-based repartition only
+                warnings.warn(
+                    "`strict` is ignored when `target_num_rows_per_block` is not set. "
+                    "Use `target_num_rows_per_block` instead of `num_blocks` to enable `strict` mode."
+                )
 
         if (num_blocks is None) and (target_num_rows_per_block is None):
             raise ValueError(
@@ -1738,6 +1753,7 @@ def repartition(
             op = StreamingRepartition(
                 self._logical_plan.dag,
                 target_num_rows_per_block=target_num_rows_per_block,
+                strict=strict,
             )
         else:
             op = Repartition(
diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py
index b5863f0cf039..5dbd4afa27e3 100644
--- a/python/ray/data/tests/test_operator_fusion.py
+++ b/python/ray/data/tests/test_operator_fusion.py
@@ -771,12 +771,12 @@ def test_streaming_repartition_map_batches_fusion_order_and_params(
 
     if order == "map_then_sr":
         ds = ds.map_batches(lambda x: x, batch_size=batch_size)
-        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
-        expected_fused_name = f"MapBatches()->StreamingRepartition[num_rows_per_block={target_num_rows}]"
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
+        expected_fused_name = f"MapBatches()->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]"
     else:  # sr_then_map
-        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
         ds = ds.map_batches(lambda x: x, batch_size=batch_size)
-        expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows}]->MapBatches()"
+        expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]->MapBatches()"
 
     assert len(ds.take_all()) == n
@@ -816,7 +816,7 @@ def test_streaming_repartition_no_further_fuse(
 
     stats1 = ds1.stats()
     assert (
-        f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows}]"
+        f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]"
         in stats1
     ), stats1
     assert "MapBatches()->MapBatches()" in stats1
@@ -844,6 +844,48 @@ def test_filter_operator_no_upstream_fusion(ray_start_regular_shared_2_cpus, cap
     assert "TaskPoolMapOperator[MapBatches()->Filter()]" in captured
 
 
+def test_streaming_repartition_multiple_fusion_non_strict(
+    ray_start_regular_shared_2_cpus,
+):
+    """Test that non-strict mode allows multiple operators to fuse with StreamingRepartition.
+
+    Case 1: Map > Map > SR (non-strict)
+    Case 2: Map > SR (non-strict) > Map
+    """
+    n = 100
+    target_rows = 20
+
+    # Case 1: Map > Map > SR (non-strict)
+    ds1 = ray.data.range(n, override_num_blocks=2)
+    ds1 = ds1.map_batches(lambda x: x, batch_size=None)
+    ds1 = ds1.map_batches(lambda x: x, batch_size=None)
+    ds1 = ds1.repartition(target_num_rows_per_block=target_rows, strict=False)
+
+    assert len(ds1.take_all()) == n
+    stats1 = ds1.stats()
+
+    # Verify all three operators are fused together
+    assert (
+        f"MapBatches()->MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]"
+        in stats1
+    ), f"Expected full fusion in stats: {stats1}"
+
+    # Case 2: Map > SR (non-strict) > Map
+    ds2 = ray.data.range(n, override_num_blocks=2)
+    ds2 = ds2.map_batches(lambda x: x, batch_size=None)
+    ds2 = ds2.repartition(target_num_rows_per_block=target_rows, strict=False)
+    ds2 = ds2.map_batches(lambda x: x, batch_size=None)
+
+    assert len(ds2.take_all()) == n
+    stats2 = ds2.stats()
+
+    # Verify all three operators are fused together
+    assert (
+        f"MapBatches()->StreamingRepartition[num_rows_per_block={target_rows},strict=False]->MapBatches()"
+        in stats2
+    ), f"Expected full fusion in stats: {stats2}"
+
+
 def test_combine_repartition_aggregate(
     ray_start_regular_shared_2_cpus, configure_shuffle_method, capsys
 ):
diff --git a/python/ray/data/tests/test_repartition_e2e.py b/python/ray/data/tests/test_repartition_e2e.py
index 337623c4ca53..17c8d29c060b 100644
--- a/python/ray/data/tests/test_repartition_e2e.py
+++ b/python/ray/data/tests/test_repartition_e2e.py
@@ -144,6 +144,7 @@ def test_repartition_target_num_rows_per_block(
     # Each block is 8 ints
     ds = ray.data.range(total_rows, override_num_blocks=num_blocks).repartition(
         target_num_rows_per_block=target_num_rows_per_block,
+        strict=True,
     )
 
     num_blocks = 0
@@ -270,16 +271,16 @@ def fn(batch):
     ds = ds.repartition(num_blocks=2, keys=[partition_col])
 
     # mess up with the block size
-    ds = ds.repartition(target_num_rows_per_block=30)
+    ds = ds.repartition(target_num_rows_per_block=30, strict=True)
 
     # Verify fusion of StreamingRepartition and MapBatches operators
     b_s = target_num_rows * n_target_num_rows
     if streaming_repartition_first:
-        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
         ds = ds.map_batches(fn, batch_size=b_s)
     else:
         ds = ds.map_batches(fn, batch_size=b_s)
-        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
     planner = create_planner()
     physical_plan = planner.plan(ds._logical_plan)
     physical_plan = PhysicalOptimizer().optimize(physical_plan)
@@ -290,7 +291,7 @@ def fn(batch):
     else:
         assert (
             physical_op.name
-            == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows}]"
+            == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]"
         )
 
     # Write output to local Parquet files partitioned by key
@@ -341,18 +342,18 @@ def fn(batch):
     ds = ds.repartition(num_blocks=2, keys=[partition_col])
 
     # mess up with the block size
-    ds = ds.repartition(target_num_rows_per_block=30)
+    ds = ds.repartition(target_num_rows_per_block=30, strict=True)
 
     # Verify fusion of StreamingRepartition and MapBatches operators
     ds = ds.map_batches(fn, batch_size=20)
-    ds = ds.repartition(target_num_rows_per_block=20)
+    ds = ds.repartition(target_num_rows_per_block=20, strict=True)
     planner = create_planner()
     physical_plan = planner.plan(ds._logical_plan)
     physical_plan = PhysicalOptimizer().optimize(physical_plan)
     physical_op = physical_plan.dag
     assert (
         physical_op.name
-        == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20]"
+        == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20,strict=True]"
     )
 
     for block in ds.iter_batches(batch_size=None):
@@ -382,6 +383,7 @@ def test_repartition_guarantee_row_num_to_be_exact(
     ds = ray.data.range(num_rows, override_num_blocks=override_num_blocks)
     ds = ds.repartition(
         target_num_rows_per_block=target_num_rows_per_block,
+        strict=True,
     )
     ds = ds.materialize()
 
@@ -431,7 +433,7 @@ def test_streaming_repartition_with_partial_last_block(
     table = [{"id": n} for n in range(num_rows)]
     ds = ray.data.from_items(table)
 
-    ds = ds.repartition(target_num_rows_per_block=20)
+    ds = ds.repartition(target_num_rows_per_block=20, strict=True)
 
     ds = ds.materialize()
@@ -452,6 +454,97 @@ def test_streaming_repartition_with_partial_last_block(
     ), f"Expected all blocks except last to have 20 rows, got {block_row_counts}"
 
 
+def test_streaming_repartition_non_strict_mode(
+    ray_start_regular_shared_2_cpus,
+    disable_fallback_to_object_extension,
+):
+    """Test non-strict mode streaming repartition behavior.
+
+    This test verifies:
+    1. Non-strict mode produces at most 1 block < target per input block
+    2. No stitching across input blocks
+    """
+    num_rows = 100
+    target = 20
+
+    # Create dataset with varying block sizes
+    ds = ray.data.range(num_rows, override_num_blocks=10)  # 10 blocks of 10 rows each
+
+    # Non-strict mode: should split each input block independently
+    ds_non_strict = ds.repartition(target_num_rows_per_block=target, strict=False)
+    ds_non_strict = ds_non_strict.materialize()
+
+    # Collect block row counts
+    block_row_counts = [
+        metadata.num_rows
+        for bundle in ds_non_strict.iter_internal_ref_bundles()
+        for metadata in bundle.metadata
+    ]
+
+    # Verify non-strict mode behavior: no stitching across input blocks
+    # For non-strict mode with input blocks of 10 rows and target of 20:
+    # Each input block (10 rows) should produce exactly 1 block of 10 rows
+    # (since 10 < 20, no splitting needed, and no stitching with other blocks)
+    assert sum(block_row_counts) == num_rows, f"Expected {num_rows} total rows"
+    assert (
+        len(block_row_counts) == 10
+    ), f"Expected 10 blocks, got {len(block_row_counts)}"
+    assert all(
+        count == 10 for count in block_row_counts
+    ), f"Expected all blocks to have 10 rows (no stitching), got {block_row_counts}"
+
+
+@pytest.mark.parametrize("batch_size", [30, 35, 45])
+def test_streaming_repartition_fusion_non_strict(
+    ray_start_regular_shared_2_cpus,
+    disable_fallback_to_object_extension,
+    batch_size,
+):
+    """Test that non-strict mode can fuse with any batch_size.
+
+    This test verifies:
+    1. MapBatches -> StreamingRepartition(strict=False) can fuse regardless of batch_size
+    """
+    num_rows = 100
+    target = 20
+
+    def fn(batch):
+        # Just pass through, but verify we got data
+        assert len(batch["id"]) > 0, "Batch should not be empty"
+        return batch
+
+    # Create dataset with 10 blocks (10 rows each) to ensure varied input block sizes
+    ds = ray.data.range(num_rows, override_num_blocks=10)
+
+    # Non-strict mode should fuse even when batch_size % target != 0
+    ds = ds.map_batches(fn, batch_size=batch_size)
+    ds = ds.repartition(target_num_rows_per_block=target, strict=False)
+
+    # Verify fusion happened
+    planner = create_planner()
+    physical_plan = planner.plan(ds._logical_plan)
+    physical_plan = PhysicalOptimizer().optimize(physical_plan)
+    physical_op = physical_plan.dag
+
+    assert (
+        f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target},strict=False]"
+        in physical_op.name
+    ), (
+        f"Expected fusion for batch_size={batch_size}, target={target}, "
+        f"but got operator name: {physical_op.name}"
+    )
+
+    # Verify correctness: count total rows and verify output block sizes
+    assert ds.count() == num_rows, f"Expected {num_rows} rows"
+
+    # In non-strict mode, blocks are NOT guaranteed to be exactly target size
+    # because no stitching happens across input blocks from map_batches.
+    # Just verify that data is preserved correctly.
+    result = sorted([row["id"] for row in ds.take_all()])
+    expected = list(range(num_rows))
+    assert result == expected, "Data should be preserved correctly after fusion"
+
+
 @pytest.mark.timeout(60)
 def test_streaming_repartition_empty_dataset(
     ray_start_regular_shared_2_cpus,
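For reference, a minimal usage sketch of the `strict` flag introduced above, based on the behavior documented in the new `repartition` docstring and exercised in `test_streaming_repartition_non_strict_mode` (the dataset shape and targets mirror that test; this snippet is illustrative and not part of the patch):

    import ray

    # 10 input blocks of 10 rows each.
    ds = ray.data.range(100, override_num_blocks=10)

    # strict=True: every output block except possibly the last has exactly
    # target_num_rows_per_block rows, stitching rows across input blocks as needed.
    exact = ds.repartition(target_num_rows_per_block=20, strict=True).materialize()

    # strict=False (the default): best-effort bundling with no stitching across
    # input blocks, so each 10-row input block passes through as one 10-row output block.
    loose = ds.repartition(target_num_rows_per_block=20, strict=False).materialize()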