Changes from 19 commits

Commits (24)
838513c
feat: add strict option in StreamingRepartition
machichima Jan 15, 2026
08008d1
feat: enable fusion for non-strict mode
machichima Jan 15, 2026
591b00f
fix: .strict to ._strict
machichima Jan 19, 2026
ae95e03
test: add non strict tests
machichima Jan 19, 2026
8f7282a
test: set stirct=True for existing tests
machichima Jan 19, 2026
cdf8f9f
fix: include strict=... in operator name
machichima Jan 19, 2026
def13b2
fix: min_row_per_bundle and support fusion issue
machichima Jan 19, 2026
dc609e1
docs: update docstring
machichima Jan 19, 2026
2c87758
feat: validate strict with target_num_rows_per_block
machichima Jan 19, 2026
d9d4295
refactor: precommit
machichima Jan 19, 2026
04964bc
docs: update docstring
machichima Jan 19, 2026
a9fbce0
fix: pass strict param in CombineRepartitions
machichima Jan 21, 2026
7b825e5
fix: verify target_num_rows_per_block in StreamingRepartition
machichima Jan 21, 2026
55c79bd
Merge branch 'master' of github.com:ray-project/ray into streamingrep…
machichima Jan 21, 2026
accb54a
fix: pass strict param in CombineShuffles
machichima Jan 21, 2026
89965d0
fix: pass min_rows_per_bundle in non-strict mode
machichima Jan 23, 2026
f748b79
docs: update docstring
machichima Jan 23, 2026
49cc5fc
test: set strict=True
machichima Jan 23, 2026
68d01c4
fix: set min_rows_per_bundle to None
machichima Jan 26, 2026
8a48fdd
Merge branch 'master' of github.com:ray-project/ray into streamingrep…
machichima Feb 5, 2026
c77787c
fix: update _can_fuse logic for batch size
machichima Feb 5, 2026
6a2fec8
refactor: precommit
machichima Feb 5, 2026
111c054
fix: enable fuse with other operations in non-strict mode
machichima Feb 7, 2026
83c5ddb
test: add map>map>sr and map>sr>map case
machichima Feb 7, 2026
4 changes: 0 additions & 4 deletions ci/lint/pydoclint-baseline.txt
@@ -1138,10 +1138,6 @@ python/ray/data/_internal/logical/operators/join_operator.py
DOC101: Method `Join.__init__`: Docstring contains fewer arguments than in function signature.
DOC103: Method `Join.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [aggregator_ray_remote_args: Optional[Dict[str, Any]], join_type: str, left_columns_suffix: Optional[str], left_input_op: LogicalOperator, left_key_columns: Tuple[str], num_partitions: int, partition_size_hint: Optional[int], right_columns_suffix: Optional[str], right_input_op: LogicalOperator, right_key_columns: Tuple[str]].
--------------------
python/ray/data/_internal/logical/operators/map_operator.py
DOC101: Method `StreamingRepartition.__init__`: Docstring contains fewer arguments than in function signature.
DOC103: Method `StreamingRepartition.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [input_op: LogicalOperator].
--------------------
python/ray/data/_internal/logical/operators/n_ary_operator.py
DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": ""
DOC001: Function/method `__init__`: Potential formatting errors in docstring. Error message: No specification for "Args": "" (Note: DOC001 could trigger other unrelated violations under this function/method too. Please fix the docstring formatting first.)
18 changes: 16 additions & 2 deletions python/ray/data/_internal/logical/operators/map_operator.py
@@ -417,22 +417,36 @@ def __init__(

class StreamingRepartition(AbstractMap):
"""Logical operator for streaming repartition operation.

Args:
input_op: The operator preceding this operator in the plan DAG.
target_num_rows_per_block: The target number of rows per block granularity for
            streaming repartition.
strict: If True, guarantees that all output blocks, except for the last one,
will have exactly target_num_rows_per_block rows. If False, uses best-effort
bundling and may produce at most one block smaller than target_num_rows_per_block
per input block without forcing exact sizes through block splitting.
Defaults to False.
"""

def __init__(
self,
input_op: LogicalOperator,
target_num_rows_per_block: int,
strict: bool = False,
):
if target_num_rows_per_block <= 0:
raise ValueError(
"target_num_rows_per_block must be positive for streaming repartition, "
f"got {target_num_rows_per_block}"
)
super().__init__(
f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block}]",
f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block},strict={strict}]",
input_op,
can_modify_num_rows=False,
)
self._target_num_rows_per_block = target_num_rows_per_block
self._strict = strict

@property
def target_num_rows_per_block(self) -> int:
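For quick reference, a minimal standalone sketch of the contract added above (the helper below is made up for illustration; it is not the Ray class itself):

def streaming_repartition_name(target_num_rows_per_block: int, strict: bool = False) -> str:
    # Mirrors the validation and naming shown in the diff: the target must be
    # positive, and the operator name now encodes both parameters.
    if target_num_rows_per_block <= 0:
        raise ValueError(
            "target_num_rows_per_block must be positive for streaming repartition, "
            f"got {target_num_rows_per_block}"
        )
    return f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block},strict={strict}]"

assert streaming_repartition_name(100, strict=True) == (
    "StreamingRepartition[num_rows_per_block=100,strict=True]"
)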
2 changes: 2 additions & 0 deletions python/ray/data/_internal/logical/rules/combine_shuffles.py
@@ -52,9 +52,11 @@ def _combine(self, op: LogicalOperator) -> LogicalOperator:
elif isinstance(input_op, StreamingRepartition) and isinstance(
op, StreamingRepartition
):
strict = input_op._strict or op._strict
return StreamingRepartition(
input_op.input_dependencies[0],
target_num_rows_per_block=op.target_num_rows_per_block,
strict=strict,
)
elif isinstance(input_op, Repartition) and isinstance(op, Aggregate):
return Aggregate(
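A standalone sketch of the combine rule above (a paraphrase for illustration, not the Ray implementation): two back-to-back streaming repartitions collapse into one that keeps the downstream target and ORs the strict flags, so strictness is never silently dropped.

def combine_streaming_repartitions(upstream: dict, downstream: dict) -> dict:
    # Keep the downstream target; strict wins if either side requested it.
    return {
        "target_num_rows_per_block": downstream["target_num_rows_per_block"],
        "strict": upstream["strict"] or downstream["strict"],
    }

combined = combine_streaming_repartitions(
    {"target_num_rows_per_block": 10, "strict": True},
    {"target_num_rows_per_block": 5, "strict": False},
)
assert combined == {"target_num_rows_per_block": 5, "strict": True}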
60 changes: 45 additions & 15 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -274,18 +274,25 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:

# only allow fusion of MapBatches -> StreamingRepartition
if isinstance(down_logical_op, StreamingRepartition):
return (
if not (
isinstance(up_logical_op, MapBatches)
and up_logical_op._batch_size is not None
and down_logical_op.target_num_rows_per_block is not None
and down_logical_op.target_num_rows_per_block > 0
# When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks.
# See `_fuse_streaming_repartition_operators_in_dag` docstring for details.
# TODO: when the StreamingRepartition supports none_strict_mode, we can fuse
# `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are.
# https://anyscale1.atlassian.net/browse/DATA-1731
and up_logical_op._batch_size
% down_logical_op.target_num_rows_per_block
):
return False
Comment on lines 281 to 286

Contributor:
I don't think this logic is correct -- if _batch_size is None we'd still allow to fuse StreamingRepartition

Contributor Author:
Hi @alexeykudinkin,
I was following the original logic here, which also returns False when _batch_size is None:

if isinstance(down_logical_op, StreamingRepartition):
return (
isinstance(up_logical_op, MapBatches)
and up_logical_op._batch_size is not None
and down_logical_op.target_num_rows_per_block is not None
and down_logical_op.target_num_rows_per_block > 0
# When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks.
# See `_fuse_streaming_repartition_operators_in_dag` docstring for details.
# TODO: when the StreamingRepartition supports none_strict_mode, we can fuse
# `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are.
# https://anyscale1.atlassian.net/browse/DATA-1731
and up_logical_op._batch_size
% down_logical_op.target_num_rows_per_block
== 0
)

Also, while we use StreamingRepartitionRefBundler(batch_size), based on the class def, the batch_size cannot be None

def __init__(self, target_num_rows_per_block: int):
assert (
target_num_rows_per_block > 0
), "target_num_rows_per_block must be positive for streaming repartition."

Therefore, I think we should keep this here?

Contributor:
Well, you're relaxing this, right?

There should now be two modes:

  • StreamingRepartition(strict=True): batch_size needs to be an exact multiple of target_num_rows_per_block to produce correct results.
  • StreamingRepartition(strict=False): batch_size can be anything (even None).

Contributor Author:
Makes sense! Thank you for pointing this out. Updated in c77787c.

Member:
Suggested change
if not (
isinstance(up_logical_op, MapBatches)
and up_logical_op._batch_size is not None
and down_logical_op.target_num_rows_per_block is not None
and down_logical_op.target_num_rows_per_block > 0
# When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks.
# See `_fuse_streaming_repartition_operators_in_dag` docstring for details.
# TODO: when the StreamingRepartition supports none_strict_mode, we can fuse
# `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are.
# https://anyscale1.atlassian.net/browse/DATA-1731
and up_logical_op._batch_size
% down_logical_op.target_num_rows_per_block
):
return False
if (
not isinstance(up_logical_op, MapBatches)
or not down_logical_op.target_num_rows_per_block
):
return False

Can we simplify the logic here like this? We could also add a check in the dataset API that raises an error when target_num_rows_per_block is set but is not positive.

(Maybe move this)

assert (
target_num_rows_per_block > 0
), "target_num_rows_per_block must be positive for streaming repartition."


# Non-strict mode: can always fuse, no matter what batch_size is.
# This allows fusion without cross-task buffering by using default bundler.
if not down_logical_op._strict:
return True

# Strict mode: only fuse when batch_size is a multiple of target_num_rows_per_block.
# When batch_size % target == 0, each batch can be perfectly sliced into chunks
# without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag`
# docstring for details.
return (
up_logical_op._batch_size % down_logical_op.target_num_rows_per_block
== 0
)
# Other operators cannot fuse with StreamingRepartition.
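
To summarize the predicate this thread converges on, here is a minimal standalone sketch (a paraphrase of the behavior described by the reviewers, not the code in this revision; strict needs an exact multiple, non-strict fuses regardless of batch_size):

def can_fuse_with_streaming_repartition(
    upstream_is_map_batches: bool,
    batch_size,  # may be None
    target_num_rows_per_block: int,
    strict: bool,
) -> bool:
    if not upstream_is_map_batches or not target_num_rows_per_block:
        return False
    if not strict:
        # Non-strict mode fuses for any batch_size, including None.
        return True
    # Strict mode needs an exact multiple so slices line up without buffering.
    return batch_size is not None and batch_size % target_num_rows_per_block == 0

assert can_fuse_with_streaming_repartition(True, None, 20, strict=False) is True
assert can_fuse_with_streaming_repartition(True, 30, 20, strict=True) is False
assert can_fuse_with_streaming_repartition(True, 40, 20, strict=True) is True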
@@ -309,11 +316,30 @@ def _get_fused_streaming_repartition_operator(
up_logical_op = self._op_map.pop(up_op)
assert isinstance(up_logical_op, MapBatches)
assert isinstance(down_logical_op, StreamingRepartition)
assert (
up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0
)

batch_size = up_logical_op._batch_size

# Choose ref_bundler and fusion behavior based on strict mode
if down_logical_op._strict:
# Strict mode: use StreamingRepartitionRefBundler for stitching.
# Only works when batch_size % target == 0 (verified in _can_fuse).
assert batch_size % down_logical_op.target_num_rows_per_block == 0, (
f"Strict mode fusion requires batch_size ({batch_size}) to be "
f"a multiple of target_num_rows_per_block "
f"({down_logical_op.target_num_rows_per_block})"
)
ref_bundler = StreamingRepartitionRefBundler(batch_size)
# No further fusion because StreamingRepartitionRefBundler is stateful
# and maintains internal buffering state across bundles.
supports_fusion = False

Member:
Will this prevent fusion when batch_size == target_num_rows_per_block?

Contributor Author (@machichima, Jan 23, 2026):
Yes, but I think it's intended, as the original code (now the strict mode) hard-coded supports_fusion=False to prevent further fusion:

            # For now, we don't want to over-fuse StreamingRepartition with other map operators,
            # so the result operator does not support further fusion.
            supports_fusion=False,

Contributor:
We shouldn't be blocking any subsequent fusion like that.

Let's add a test that we're able to fuse multiple ops like this:

  • Map > Map > SR
  • Map > SR > SR

Contributor Author (@machichima, Feb 7, 2026):
While the comment is on line 338 (supports_fusion=False), I want to make sure: do we want to support fusion for strict mode, or just add tests for non-strict mode? I think it's the latter?

Contributor Author:
The Map > SR > SR case cannot work here because after the first Map > SR fusion, the logical operator becomes AbstractUDFMap rather than MapBatches.

logical_op = AbstractUDFMap(
name,
input_op,
up_logical_op.fn,
can_modify_num_rows=up_logical_op.can_modify_num_rows,
fn_args=up_logical_op.fn_args,
fn_kwargs=up_logical_op.fn_kwargs,
fn_constructor_args=up_logical_op.fn_constructor_args,
fn_constructor_kwargs=up_logical_op.fn_constructor_kwargs,
min_rows_per_bundled_input=batch_size,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
self._op_map[op] = logical_op

The current implementation only allows MapBatches > SR fusion:

and isinstance(self._op_map[upstream_ops[0]], MapBatches)

To support Map > SR > SR fusion, we will need more changes, which I think is a bit out of scope of this PR.

Contributor Author:
Updated in:

else:
# Non-strict mode: use default pass-through bundler.
# Works with any batch_size without cross-task buffering.
ref_bundler = None
# Can fuse further because the default bundler is stateless
# and processes each bundle independently.
supports_fusion = True

compute = self._fuse_compute_strategy(
up_logical_op._compute, down_logical_op._compute
)
@@ -330,19 +356,23 @@ def _get_fused_streaming_repartition_operator(
input_op = input_deps[0]

assert up_op.data_context is down_op.data_context

        # In non-strict mode, set min_rows_per_bundle so input bundles reach batch_size rows.
# In strict mode, ref_bundler handles bundling, so do not set min_rows_per_bundle.
min_rows = None if down_logical_op._strict else batch_size

op = MapOperator.create(
up_op.get_map_transformer().fuse(down_op.get_map_transformer()),
input_op,
up_op.data_context,
name=name,
compute_strategy=compute,
ref_bundler=StreamingRepartitionRefBundler(batch_size),
min_rows_per_bundle=min_rows,
ref_bundler=ref_bundler,
map_task_kwargs=map_task_kwargs,
ray_remote_args=ray_remote_args,
ray_remote_args_fn=ray_remote_args_fn,
# For now, we don't want to over-fuse StreamingRepartition with other map operators,
# so the result operator does not support further fusion.
supports_fusion=False,
supports_fusion=supports_fusion,
)
op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)
for map_task_kwargs_fn in itertools.chain(
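A compact restatement of the fusion wiring above (a summary for illustration, not code from the PR): strict fusion needs the stateful StreamingRepartitionRefBundler and blocks further fusion, while non-strict fusion keeps the default bundler, hints bundle sizes via min_rows_per_bundle, and stays fusable.

def fusion_plan(strict: bool, batch_size: int, target_num_rows_per_block: int) -> dict:
    if strict:
        # Verified earlier in _can_fuse; repeated here for clarity.
        assert batch_size % target_num_rows_per_block == 0
        return {
            "ref_bundler": "StreamingRepartitionRefBundler(batch_size)",
            "min_rows_per_bundle": None,
            "supports_fusion": False,
        }
    return {
        "ref_bundler": None,
        "min_rows_per_bundle": batch_size,
        "supports_fusion": True,
    }

assert fusion_plan(True, 100, 20)["supports_fusion"] is False
assert fusion_plan(False, 7, 20) == {
    "ref_bundler": None,
    "min_rows_per_bundle": 7,
    "supports_fusion": True,
}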
8 changes: 6 additions & 2 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -195,14 +195,18 @@ def plan_streaming_repartition_op(
)
map_transformer = MapTransformer([transform_fn])

# Disable fusion for streaming repartition with the downstream op.
if op._strict:
ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block)
else:
ref_bundler = None

operator = MapOperator.create(
map_transformer,
input_physical_dag,
data_context,
name=op.name,
compute_strategy=compute,

Member:
I think we need min_rows_per_bundle = op.target_num_rows_per_block here if strict=False?

Contributor Author:
Updated in 89965d0

Contributor Author (@machichima, Jan 26, 2026):
Seems like when we set min_rows_per_bundle here, the BlockRefBundler will try to stitch the output:

return list(output_buffer), _merge_ref_bundles(*output_buffer)

Therefore, I think we should keep it as None here to prevent stitching.

if self._min_rows_per_bundle is None:
# Short-circuit if no bundle row target was defined.
assert len(self._bundle_buffer) == 1
bundle = self._bundle_buffer[0]
self._bundle_buffer = []
self._bundle_buffer_size = 0
self._bundle_buffer_size_bytes = 0
return [bundle], bundle

ref_bundler=StreamingRepartitionRefBundler(op.target_num_rows_per_block),
ref_bundler=ref_bundler,
ray_remote_args=op._ray_remote_args,
ray_remote_args_fn=op._ray_remote_args_fn,
)
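For intuition on why min_rows_per_bundle stays None here, a toy sketch (an illustration of the documented contract, not Ray code) of the block shapes each mode yields for two 25-row input blocks with a target of 10:

def split_block(num_rows: int, target: int) -> list:
    # Chop one input block into chunks of at most `target` rows.
    chunks = []
    while num_rows > 0:
        chunks.append(min(target, num_rows))
        num_rows -= chunks[-1]
    return chunks

# strict=False: each block is handled independently, so each keeps its 5-row leftover.
assert [split_block(25, 10), split_block(25, 10)] == [[10, 10, 5], [10, 10, 5]]
# strict=True: StreamingRepartitionRefBundler stitches leftovers across blocks,
# so the 50 rows would instead come out as five exact 10-row blocks.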
16 changes: 16 additions & 0 deletions python/ray/data/dataset.py
@@ -1640,6 +1640,7 @@ def repartition(
num_blocks: Optional[int] = None,
target_num_rows_per_block: Optional[int] = None,
*,
strict: bool = False,

Contributor:
medium

The new strict parameter should be documented in the repartition method's docstring. Explaining the difference between strict=True (the old behavior) and strict=False (the new default) is important for users to understand its impact on block sizes and fusion.

You could add something like this to the Args section:

strict: If ``True``, `repartition` guarantees that all output blocks, except for the last one, will have `target_num_rows_per_block` rows. If ``False``, `repartition` is more relaxed and may produce blocks smaller than `target_num_rows_per_block` without stitching them. This is only used with `target_num_rows_per_block`. Defaults to ``False``.

Contributor Author:
Updated in dc609e1

shuffle: bool = False,
keys: Optional[List[str]] = None,
sort: bool = False,
@@ -1689,6 +1690,13 @@ def repartition(
optimal execution, based on the `target_num_rows_per_block`. This is
the current behavior because of the implementation and may change in
the future.
strict: If ``True``, ``repartition`` guarantees that all output blocks,
except for the last one, will have exactly ``target_num_rows_per_block`` rows.
If ``False``, ``repartition`` uses best-effort bundling and may produce at most
one block smaller than ``target_num_rows_per_block`` per input block without
forcing exact sizes through block splitting.
This parameter is only used with ``target_num_rows_per_block``.
Defaults to ``False``.
shuffle: Whether to perform a distributed shuffle during the
repartition. When shuffle is enabled, each output block
contains a subset of data rows from each input block, which
@@ -1725,6 +1733,13 @@ def repartition(
warnings.warn(
"`shuffle` is ignored when `target_num_rows_per_block` is set."
)
else:
if strict:
# strict is used in row-based repartition only
warnings.warn(
"`strict` is ignored when `target_num_rows_per_block` is not set. "
"Use `target_num_rows_per_block` instead of `num_blocks` to enable `strict` mode."
)

if (num_blocks is None) and (target_num_rows_per_block is None):
raise ValueError(
@@ -1746,6 +1761,7 @@ def repartition(
op = StreamingRepartition(
self._logical_plan.dag,
target_num_rows_per_block=target_num_rows_per_block,
strict=strict,
)
else:
op = Repartition(
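A short usage sketch of the new flag (assuming the API lands as shown in this diff):

import ray

ds = ray.data.range(1000)

# Default, best-effort mode: blocks may fall short of 128 rows, and the op can
# still fuse with neighboring map operators.
relaxed = ds.repartition(target_num_rows_per_block=128)

# Strict mode: every output block except possibly the last has exactly 128 rows.
exact = ds.repartition(target_num_rows_per_block=128, strict=True)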
10 changes: 5 additions & 5 deletions python/ray/data/tests/test_operator_fusion.py
@@ -768,12 +768,12 @@ def test_streaming_repartition_map_batches_fusion_order_and_params(

if order == "map_then_sr":
ds = ds.map_batches(lambda x: x, batch_size=batch_size)
ds = ds.repartition(target_num_rows_per_block=target_num_rows)
expected_fused_name = f"MapBatches(<lambda>)->StreamingRepartition[num_rows_per_block={target_num_rows}]"
ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
expected_fused_name = f"MapBatches(<lambda>)->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]"
else: # sr_then_map
ds = ds.repartition(target_num_rows_per_block=target_num_rows)
ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True)
ds = ds.map_batches(lambda x: x, batch_size=batch_size)
expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows}]->MapBatches(<lambda>)"
expected_fused_name = f"StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]->MapBatches(<lambda>)"

assert len(ds.take_all()) == n

@@ -813,7 +813,7 @@ def test_streaming_repartition_no_further_fuse(
stats1 = ds1.stats()

assert (
f"MapBatches(<lambda>)->StreamingRepartition[num_rows_per_block={target_rows}]"
f"MapBatches(<lambda>)->StreamingRepartition[num_rows_per_block={target_rows},strict=False]"
in stats1
), stats1
assert "MapBatches(<lambda>)->MapBatches(<lambda>)" in stats1