-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[Data] Support strict=False mode for StreamingRepartition #60295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 15 commits
838513c
08008d1
591b00f
ae95e03
8f7282a
cdf8f9f
def13b2
dc609e1
2c87758
d9d4295
04964bc
a9fbce0
7b825e5
55c79bd
accb54a
89965d0
f748b79
49cc5fc
68d01c4
8a48fdd
c77787c
6a2fec8
111c054
83c5ddb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -274,18 +274,25 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # only allow fusion of MapBatches -> StreamingRepartition | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(down_logical_op, StreamingRepartition): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| isinstance(up_logical_op, MapBatches) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and up_logical_op._batch_size is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and down_logical_op.target_num_rows_per_block is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and down_logical_op.target_num_rows_per_block > 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When the batch_size is a multiple of target_num_rows_per_block, fusing would still produce exactly identical sequence of blocks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See `_fuse_streaming_repartition_operators_in_dag` docstring for details. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: when the StreamingRepartition supports none_strict_mode, we can fuse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # `MapBatches -> StreamingRepartition` no matter what the `batch_size` and `target_num_rows` are. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # https://anyscale1.atlassian.net/browse/DATA-1731 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and up_logical_op._batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| % down_logical_op.target_num_rows_per_block | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
281
to
286
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this logic is correct -- if _batch_size is None we'd still allow to fuse StreamingRepartition
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @alexeykudinkin , ray/python/ray/data/_internal/logical/rules/operator_fusion.py Lines 280 to 294 in 8e2e0aa
Also, while we use ray/python/ray/data/_internal/streaming_repartition.py Lines 34 to 37 in 68d01c4
Therefor, I think we should keep this here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, you're relaxing this, right? There are now should be 2 modes:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sense! Thank you for pointing this out. Updated in c77787c
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Can we simplify the logic here like this? And add check and raise error at dataset api to check if (Maybe move this) ray/python/ray/data/_internal/streaming_repartition.py Lines 35 to 37 in f5a53c4
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-strict mode: can always fuse, no matter what batch_size is. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # This allows fusion without cross-task buffering by using default bundler. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not down_logical_op._strict: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Strict mode: only fuse when batch_size is a multiple of target_num_rows_per_block. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When batch_size % target == 0, each batch can be perfectly sliced into chunks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # without cross-task buffering. See `_fuse_streaming_repartition_operators_in_dag` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # docstring for details. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_logical_op._batch_size % down_logical_op.target_num_rows_per_block | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Other operators cannot fuse with StreamingRepartition. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -309,11 +316,30 @@ def _get_fused_streaming_repartition_operator( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_logical_op = self._op_map.pop(up_op) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(up_logical_op, MapBatches) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(down_logical_op, StreamingRepartition) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size = up_logical_op._batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Choose ref_bundler and fusion behavior based on strict mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if down_logical_op._strict: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Strict mode: use StreamingRepartitionRefBundler for stitching. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Only works when batch_size % target == 0 (verified in _can_fuse). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert batch_size % down_logical_op.target_num_rows_per_block == 0, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Strict mode fusion requires batch_size ({batch_size}) to be " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"a multiple of target_num_rows_per_block " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"({down_logical_op.target_num_rows_per_block})" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_bundler = StreamingRepartitionRefBundler(batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # No further fusion because StreamingRepartitionRefBundler is stateful | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # and maintains internal buffering state across bundles. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supports_fusion = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this prevent fusion when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but I think it's intended. As the original code (strict mode) hard-coded # For now, we don't want to over-fuse StreamingRepartition with other map operators,
# so the result operator does not support further fusion.
supports_fusion=False,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'd not be blocking any subsequent fusion like that Let's add a test that we're able to fuse multiple ops like this:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While the comment is on line 338 (
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Map > SR > SR case cannot work here because after the first Map > SR fusion, the logical operator becomes ray/python/ray/data/_internal/logical/rules/operator_fusion.py Lines 355 to 369 in f3d444a
The current implementation only allows MapBatches > SR fusion:
To support Map > SR > SR fusion, we will need more changes, which I think is a bit out of scope of this PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-strict mode: use default pass-through bundler. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Works with any batch_size without cross-task buffering. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_bundler = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Can fuse further because the default bundler is stateless | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # and processes each bundle independently. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supports_fusion = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compute = self._fuse_compute_strategy( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_logical_op._compute, down_logical_op._compute | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -330,19 +356,23 @@ def _get_fused_streaming_repartition_operator( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_op = input_deps[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert up_op.data_context is down_op.data_context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # In non-strict mode, use min_rows_per_bundle to ensure creating batches with batch_size. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # In strict mode, ref_bundler handles bundling, so do not set min_rows_per_bundle. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| min_rows = None if down_logical_op._strict else batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| op = MapOperator.create( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_op.get_map_transformer().fuse(down_op.get_map_transformer()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_op, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_op.data_context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compute_strategy=compute, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_bundler=StreamingRepartitionRefBundler(batch_size), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| min_rows_per_bundle=min_rows, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_bundler=ref_bundler, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| map_task_kwargs=map_task_kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ray_remote_args=ray_remote_args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ray_remote_args_fn=ray_remote_args_fn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For now, we don't want to over-fuse StreamingRepartition with other map operators, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # so the result operator does not support further fusion. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supports_fusion=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supports_fusion=supports_fusion, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for map_task_kwargs_fn in itertools.chain( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -195,14 +195,18 @@ def plan_streaming_repartition_op( | |||||||||||||||||||
| ) | ||||||||||||||||||||
| map_transformer = MapTransformer([transform_fn]) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Disable fusion for streaming repartition with the downstream op. | ||||||||||||||||||||
| if op._strict: | ||||||||||||||||||||
| ref_bundler = StreamingRepartitionRefBundler(op.target_num_rows_per_block) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| ref_bundler = None | ||||||||||||||||||||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| operator = MapOperator.create( | ||||||||||||||||||||
| map_transformer, | ||||||||||||||||||||
| input_physical_dag, | ||||||||||||||||||||
| data_context, | ||||||||||||||||||||
| name=op.name, | ||||||||||||||||||||
| compute_strategy=compute, | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated in 89965d0
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like when we set
Therefor, I think we should keep it as ray/python/ray/data/_internal/execution/operators/map_operator.py Lines 828 to 835 in 68d01c4
|
||||||||||||||||||||
| ref_bundler=StreamingRepartitionRefBundler(op.target_num_rows_per_block), | ||||||||||||||||||||
| ref_bundler=ref_bundler, | ||||||||||||||||||||
| ray_remote_args=op._ray_remote_args, | ||||||||||||||||||||
| ray_remote_args_fn=op._ray_remote_args_fn, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1640,6 +1640,7 @@ def repartition( | |
| num_blocks: Optional[int] = None, | ||
| target_num_rows_per_block: Optional[int] = None, | ||
| *, | ||
| strict: bool = False, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new You could add something like this to the strict: If ``True``, `repartition` guarantees that all output blocks, except for the last one, will have `target_num_rows_per_block` rows. If ``False``, `repartition` is more relaxed and may produce blocks smaller than `target_num_rows_per_block` without stitching them. This is only used with `target_num_rows_per_block`. Defaults to ``False``.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated in dc609e1 |
||
| shuffle: bool = False, | ||
| keys: Optional[List[str]] = None, | ||
| sort: bool = False, | ||
|
|
@@ -1689,6 +1690,12 @@ def repartition( | |
| optimal execution, based on the `target_num_rows_per_block`. This is | ||
| the current behavior because of the implementation and may change in | ||
| the future. | ||
| strict: If ``True``, ``repartition`` guarantees that all output blocks, | ||
| except for the last one, will have exactly ``target_num_rows_per_block`` rows. | ||
| If ``False``, ``repartition`` is more relaxed and may produce blocks smaller | ||
| than ``target_num_rows_per_block`` without stitching them together. | ||
| This parameter is only used with ``target_num_rows_per_block``. | ||
| Defaults to ``False``. | ||
|
Comment on lines
1685
to
1691
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be better to say that will only produce at most 1 block that is < target_num_rows_per_block per input block if strict is false.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated in f748b79 |
||
| shuffle: Whether to perform a distributed shuffle during the | ||
| repartition. When shuffle is enabled, each output block | ||
| contains a subset of data rows from each input block, which | ||
|
|
@@ -1725,6 +1732,13 @@ def repartition( | |
| warnings.warn( | ||
| "`shuffle` is ignored when `target_num_rows_per_block` is set." | ||
| ) | ||
| else: | ||
| if strict: | ||
| # strict is used in row-based repartition only | ||
| warnings.warn( | ||
| "`strict` is ignored when `target_num_rows_per_block` is not set. " | ||
| "Use `target_num_rows_per_block` instead of `num_blocks` to enable `strict` mode." | ||
| ) | ||
|
|
||
| if (num_blocks is None) and (target_num_rows_per_block is None): | ||
| raise ValueError( | ||
|
|
@@ -1746,6 +1760,7 @@ def repartition( | |
| op = StreamingRepartition( | ||
| self._logical_plan.dag, | ||
| target_num_rows_per_block=target_num_rows_per_block, | ||
| strict=strict, | ||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| else: | ||
| op = Repartition( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -144,6 +144,7 @@ def test_repartition_target_num_rows_per_block( | |||
| # Each block is 8 ints | ||||
| ds = ray.data.range(total_rows, override_num_blocks=num_blocks).repartition( | ||||
| target_num_rows_per_block=target_num_rows_per_block, | ||||
| strict=True, | ||||
| ) | ||||
|
|
||||
| num_blocks = 0 | ||||
|
|
@@ -270,16 +271,16 @@ def fn(batch): | |||
| ds = ds.repartition(num_blocks=2, keys=[partition_col]) | ||||
|
|
||||
| # mess up with the block size | ||||
| ds = ds.repartition(target_num_rows_per_block=30) | ||||
| ds = ds.repartition(target_num_rows_per_block=30, strict=True) | ||||
|
|
||||
| # Verify fusion of StreamingRepartition and MapBatches operators | ||||
| b_s = target_num_rows * n_target_num_rows | ||||
| if streaming_repartition_first: | ||||
| ds = ds.repartition(target_num_rows_per_block=target_num_rows) | ||||
| ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) | ||||
| ds = ds.map_batches(fn, batch_size=b_s) | ||||
| else: | ||||
| ds = ds.map_batches(fn, batch_size=b_s) | ||||
| ds = ds.repartition(target_num_rows_per_block=target_num_rows) | ||||
| ds = ds.repartition(target_num_rows_per_block=target_num_rows, strict=True) | ||||
| planner = create_planner() | ||||
| physical_plan = planner.plan(ds._logical_plan) | ||||
| physical_plan = PhysicalOptimizer().optimize(physical_plan) | ||||
|
|
@@ -290,7 +291,7 @@ def fn(batch): | |||
| else: | ||||
| assert ( | ||||
| physical_op.name | ||||
| == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows}]" | ||||
| == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows},strict=True]" | ||||
| ) | ||||
|
|
||||
| # Write output to local Parquet files partitioned by key | ||||
|
|
@@ -341,18 +342,18 @@ def fn(batch): | |||
| ds = ds.repartition(num_blocks=2, keys=[partition_col]) | ||||
|
|
||||
| # mess up with the block size | ||||
| ds = ds.repartition(target_num_rows_per_block=30) | ||||
| ds = ds.repartition(target_num_rows_per_block=30, strict=True) | ||||
|
|
||||
| # Verify fusion of StreamingRepartition and MapBatches operators | ||||
| ds = ds.map_batches(fn, batch_size=20) | ||||
| ds = ds.repartition(target_num_rows_per_block=20) | ||||
| ds = ds.repartition(target_num_rows_per_block=20, strict=True) | ||||
| planner = create_planner() | ||||
| physical_plan = planner.plan(ds._logical_plan) | ||||
| physical_plan = PhysicalOptimizer().optimize(physical_plan) | ||||
| physical_op = physical_plan.dag | ||||
| assert ( | ||||
| physical_op.name | ||||
| == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20]" | ||||
| == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20,strict=True]" | ||||
| ) | ||||
|
|
||||
| for block in ds.iter_batches(batch_size=None): | ||||
|
|
@@ -382,6 +383,7 @@ def test_repartition_guarantee_row_num_to_be_exact( | |||
| ds = ray.data.range(num_rows, override_num_blocks=override_num_blocks) | ||||
| ds = ds.repartition( | ||||
| target_num_rows_per_block=target_num_rows_per_block, | ||||
| strict=True, | ||||
| ) | ||||
| ds = ds.materialize() | ||||
|
|
||||
|
|
@@ -431,7 +433,7 @@ def test_streaming_repartition_with_partial_last_block( | |||
| table = [{"id": n} for n in range(num_rows)] | ||||
| ds = ray.data.from_items(table) | ||||
|
|
||||
| ds = ds.repartition(target_num_rows_per_block=20) | ||||
| ds = ds.repartition(target_num_rows_per_block=20, strict=True) | ||||
|
|
||||
| ds = ds.materialize() | ||||
|
|
||||
|
|
@@ -452,6 +454,97 @@ def test_streaming_repartition_with_partial_last_block( | |||
| ), f"Expected all blocks except last to have 20 rows, got {block_row_counts}" | ||||
|
|
||||
|
|
||||
| def test_streaming_repartition_non_strict_mode( | ||||
| ray_start_regular_shared_2_cpus, | ||||
| disable_fallback_to_object_extension, | ||||
| ): | ||||
| """Test non-strict mode streaming repartition behavior. | ||||
|
|
||||
| This test verifies: | ||||
| 1. Non-strict mode produces at most 1 block < target per input block | ||||
| 2. No stitching across input blocks | ||||
| """ | ||||
| num_rows = 100 | ||||
| target = 20 | ||||
|
|
||||
| # Create dataset with varying block sizes | ||||
| ds = ray.data.range(num_rows, override_num_blocks=10) # 10 blocks of 10 rows each | ||||
|
|
||||
| # Non-strict mode: should split each input block independently | ||||
| ds_non_strict = ds.repartition(target_num_rows_per_block=target, strict=False) | ||||
| ds_non_strict = ds_non_strict.materialize() | ||||
|
|
||||
| # Collect block row counts | ||||
| block_row_counts = [ | ||||
| metadata.num_rows | ||||
| for bundle in ds_non_strict.iter_internal_ref_bundles() | ||||
| for metadata in bundle.metadata | ||||
| ] | ||||
|
|
||||
| # Verify non-strict mode behavior: no stitching across input blocks | ||||
| # For non-strict mode with input blocks of 10 rows and target of 20: | ||||
| # Each input block (10 rows) should produce exactly 1 block of 10 rows | ||||
| # (since 10 < 20, no splitting needed, and no stitching with other blocks) | ||||
| assert sum(block_row_counts) == num_rows, f"Expected {num_rows} total rows" | ||||
| assert ( | ||||
| len(block_row_counts) == 10 | ||||
| ), f"Expected 10 blocks, got {len(block_row_counts)}" | ||||
| assert all( | ||||
| count == 10 for count in block_row_counts | ||||
| ), f"Expected all blocks to have 10 rows (no stitching), got {block_row_counts}" | ||||
|
|
||||
|
|
||||
| @pytest.mark.parametrize("batch_size", [30, 35, 45]) | ||||
| def test_streaming_repartition_fusion_non_strict( | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think fusion test should be in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's existing fusion and streaming repartition related test in this file, I think we can put this here as it align with existing tests. WDYT?
|
||||
| ray_start_regular_shared_2_cpus, | ||||
| disable_fallback_to_object_extension, | ||||
| batch_size, | ||||
| ): | ||||
| """Test that non-strict mode can fuse with any batch_size. | ||||
|
|
||||
| This test verifies: | ||||
| 1. MapBatches -> StreamingRepartition(strict=False) can fuse regardless of batch_size | ||||
| """ | ||||
| num_rows = 100 | ||||
| target = 20 | ||||
|
|
||||
| def fn(batch): | ||||
| # Just pass through, but verify we got data | ||||
| assert len(batch["id"]) > 0, "Batch should not be empty" | ||||
| return batch | ||||
|
|
||||
| # Create dataset with 10 blocks (10 rows each) to ensure varied input block sizes | ||||
| ds = ray.data.range(num_rows, override_num_blocks=10) | ||||
|
|
||||
| # Non-strict mode should fuse even when batch_size % target != 0 | ||||
| ds = ds.map_batches(fn, batch_size=batch_size) | ||||
| ds = ds.repartition(target_num_rows_per_block=target, strict=False) | ||||
|
|
||||
| # Verify fusion happened | ||||
| planner = create_planner() | ||||
| physical_plan = planner.plan(ds._logical_plan) | ||||
| physical_plan = PhysicalOptimizer().optimize(physical_plan) | ||||
| physical_op = physical_plan.dag | ||||
|
|
||||
| assert ( | ||||
| f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target},strict=False]" | ||||
| in physical_op.name | ||||
| ), ( | ||||
| f"Expected fusion for batch_size={batch_size}, target={target}, " | ||||
| f"but got operator name: {physical_op.name}" | ||||
| ) | ||||
|
|
||||
| # Verify correctness: count total rows and verify output block sizes | ||||
| assert ds.count() == num_rows, f"Expected {num_rows} rows" | ||||
|
|
||||
| # In non-strict mode, blocks are NOT guaranteed to be exactly target size | ||||
| # because no stitching happens across input blocks from map_batches. | ||||
| # Just verify that data is preserved correctly. | ||||
| result = sorted([row["id"] for row in ds.take_all()]) | ||||
| expected = list(range(num_rows)) | ||||
| assert result == expected, "Data should be preserved correctly after fusion" | ||||
|
|
||||
|
|
||||
| @pytest.mark.timeout(60) | ||||
| def test_streaming_repartition_empty_dataset( | ||||
| ray_start_regular_shared_2_cpus, | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto with the comment in
dataset.pyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in f748b79