[Data] Avoid redundant reads in train_test_split #60274
myandpr wants to merge 2 commits into ray-project:master
Conversation
Signed-off-by: yaommen <myanstu@163.com>
Code Review
This pull request effectively addresses the issue of redundant dataset reads in train_test_split and split_proportionately by optimizing how the dataset length is calculated. The changes prioritize using metadata (_meta_count()) and materializing the dataset once when necessary, falling back to count() only as a last resort. This significantly improves performance by avoiding multiple full passes over the data. The modifications to _validate_test_size_int to accept a precomputed length are also a good enhancement. I have one suggestion to improve code clarity.
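For illustration, the shape of the `_validate_test_size_int` change can be inferred from its call site in the diff further down (`self._validate_test_size_int(test_size, ds, ds_length=ds_length)`); the body below is a hedged sketch with assumed validation details, not the exact code from the PR:

```python
def _validate_test_size_int(self, test_size: int, ds, ds_length=None) -> int:
    """Validate an integer test_size, reusing a precomputed dataset length if given."""
    # Reuse the caller's precomputed length so validation does not trigger
    # another pass over the data (assumed behavior; details may differ in the PR).
    if ds_length is None:
        ds_length = ds._meta_count()
        if ds_length is None:
            ds_length = ds.count()
    if not 0 < test_size < ds_length:
        raise ValueError(
            f"test_size ({test_size}) must be between 1 and the dataset "
            f"length ({ds_length})."
        )
    return ds_length
```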
Signed-off-by: yaommen <myanstu@163.com>
This pull request has been automatically marked as stale because it has not had recent activity. You can always ask for help on our discussion forum or Ray's public Slack channel. If you'd like to keep this open, just leave any comment, and the stale label will be removed.
ds = ds.materialize()
dataset_length = ds._meta_count()
if dataset_length is None:
    dataset_length = ds.count()
Actually, can't we just call ds.count() here? Because materializing defeats the purpose of this optimization.
Good question. The reason for materializing here is to avoid re-executing the pipeline.
If _meta_count() returns None and we call ds.count(), the plan executes once for counting, and then split_at_indices() executes the upstream pipeline again to produce the splits (the original issue was exactly this redundant read). By materializing once, we can get the row count from block metadata and then split the already computed snapshot, avoiding a second upstream read.
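A rough sketch of the two paths, assuming a dataset `ds` whose row count is not available from metadata (the names here are illustrative, not the exact PR code):

```python
# Illustrative comparison, assuming `ds` has no row count in its metadata.
test_rows = 4  # example test-set size

# Path A: count() first, then split -- the upstream pipeline executes twice.
n = ds.count()                                        # execution #1 (AggregateNumRows)
train, test = ds.split_at_indices([n - test_rows])    # execution #2 re-reads upstream

# Path B: materialize once, then split -- the upstream pipeline executes once.
mds = ds.materialize()                                # single execution; blocks are a snapshot
n = mds._meta_count()                                 # row count from block metadata, no re-read
train, test = mds.split_at_indices([n - test_rows])   # splits the materialized blocks
```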
To make this concrete, I temporarily applied the suggested change (diff below) and reran the repro script demo_torch_detection_like_issue.py from this PR description. The logs show two executions:
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index ecd0ddce08..9ddc971ee9 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2378,10 +2378,7 @@ class Dataset:
ds = self
dataset_length = ds._meta_count()
if dataset_length is None:
- ds = ds.materialize()
- dataset_length = ds._meta_count()
- if dataset_length is None:
- dataset_length = ds.count()
+ dataset_length = ds.count()
cumulative_proportions = np.cumsum(proportions)
split_indices = [
int(dataset_length * proportion) for proportion in cumulative_proportions
@@ -2470,8 +2467,7 @@ class Dataset:
else:
ds_length = ds._meta_count()
if ds_length is None:
- ds = ds.materialize()
- ds_length = ds._meta_count()
+ ds_length = ds.count()
ds_length = self._validate_test_size_int(test_size, ds, ds_length=ds_length)
return ds.split_at_indices([ds_length - test_size])
2026-02-07 22:36:45,085 INFO worker.py:2007 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
/Users/xx/work/community/ray/python/ray/_private/worker.py:2055: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
warnings.warn(
2026-02-07 22:36:56,331 INFO logging.py:397 -- Registered dataset logger for dataset dataset_4_0
2026-02-07 22:36:56,367 INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2026-02-07_22-36-39_967896_16444/logs/ray-data
2026-02-07 22:36:56,367 INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)->Project] -> AggregateNumRows[AggregateNumRows]
2026-02-07 22:36:56,375 WARNING resource_manager.py:134 -- ⚠️ Ray's object store is configured to use only 23.6% of available memory (2.0GiB out of 8.5GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
2026-02-07 22:36:56,376 INFO streaming_executor.py:661 -- [dataset]: A new progress UI is available. To enable, set `ray.data.DataContext.get_current().enable_rich_progress_bars = True` and `ray.data.DataContext.get_current().use_ray_tqdm = False`.
Running Dataset dataset_4_0.: 0.00 row [00:00, ? row/s] 2026-02-07 22:36:56,402 WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up. | 0.00/1.00 [00:00<?, ? row/s]
2026-02-07 22:36:56,417 WARNING utils.py:33 -- Truncating long operator name to 100 characters. To disable this behavior, set `ray.data.DataContext.get_current().DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`. | 0.00/1.00 [00:00<?, ? row/s]
✔️ Dataset dataset_4_0 execution finished in 33.39 seconds: 100%|█████████████████████████████████████| 1.00/1.00 [00:33<00:00, 33.4s/ row]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 140/140 [00:33<00:00, 4.19 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20.0 [00:33<00:00, 1.67s/ row]
- Map(decode_annotation)->...->Project: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20
- AggregateNumRows: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 8.0B object store: 100%|█| 1.00/1.00 [00:33<00:00, 33.
2026-02-07 22:37:29,762 INFO streaming_executor.py:302 -- ✔️ Dataset dataset_4_0 execution finished in 33.39 seconds
2026-02-07 22:37:29,766 INFO logging.py:397 -- Registered dataset logger for dataset dataset_3_0
2026-02-07 22:37:29,791 INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_3_0. Full logs are in /tmp/ray/session_2026-02-07_22-36-39_967896_16444/logs/ray-data
2026-02-07 22:37:29,791 INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_3_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)]
Running Dataset dataset_3_0.: 0.00 row [00:00, ? row/s] 2026-02-07 22:37:29,812 WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up. | 0.00/1.00 [00:00<?, ? row/s]
✔️ Dataset dataset_3_0 execution finished in 38.60 seconds: : 20.0 row [00:38, 1.93s/ row] 0<?, ? row/s]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 140/140 [00:38<00:00, 3.63 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20.0 [00:38<00:00, 1.93s/ row]
- Map(decode_annotation)->Map(read_images): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.
2026-02-07 22:38:08,398 INFO streaming_executor.py:302 -- ✔️ Dataset dataset_3_0 execution finished in 38.60 seconds
train count: 16
test count: 4
- a run with `... -> AggregateNumRows` (from `count()`), and
- another run of the upstream pipeline for the split (`... -> Map(read_images)`).
So reverting to ds.count() regresses to double reads. That's why the current materialize -> split path is intentional: it executes once, reuses metadata, and avoids a second upstream read.
Description
This PR fixes train_test_split/split_proportionately reading the dataset twice (as seen in the repro example below) by avoiding redundant execution.
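As a minimal, self-contained illustration of the affected API (not the repro from this PR; the dataset and map are made up):

```python
import ray

# A small pipeline with a non-trivial map stage; before this PR, computing the
# split point could trigger a count() that re-executed this pipeline, and the
# split itself executed it a second time.
ds = ray.data.range(1000).map(lambda row: {"value": row["id"] * 2})

train_ds, test_ds = ds.train_test_split(test_size=0.2)
print(train_ds.count(), test_ds.count())
```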
Related issues
Additional information
Implementation
Local quick check
Reproduce script
demo_torch_detection_like_issue.py (based on the Torch object detection example https://docs.ray.io/en/latest/train/examples/pytorch/torch_detection.html)
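The full script is not inlined here; a rough, hypothetical sketch of its shape (paths and helper bodies are placeholders, chosen to match the operator names in the logs above):

```python
# Hypothetical sketch of demo_torch_detection_like_issue.py; the real script in the
# PR description follows the Torch object detection example more closely.
import ray

def decode_annotation(row):
    # Placeholder: parse the binary annotation payload into fields.
    return {"bytes": row["bytes"], "label": 0}

def read_images(row):
    # Placeholder: decode the image bytes.
    return row

ds = (
    ray.data.read_binary_files("/path/to/voc")  # hypothetical path
    .limit(20)
    .map(decode_annotation)
    .map(read_images)
)

# The split below is where the redundant read showed up before this PR.
train_ds, test_ds = ds.train_test_split(test_size=0.2)
print("train count:", train_ds.count())
print("test count:", test_ds.count())
```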
Before:
After: