
[Data] Avoid redundant reads in train_test_split #60274

Open

myandpr wants to merge 2 commits into ray-project:master from myandpr:train-test-split

Conversation

@myandpr (Member) commented Jan 18, 2026

Description

This PR fixes train_test_split / split_proportionately reading the dataset twice: counting rows before splitting no longer forces a redundant execution of the upstream pipeline.

Related issues

Fixes #51223

Additional information

Implementation

  • train_test_split(test_size=int) now tries _meta_count() first; if the row count is unknown, it calls materialize() once and reuses the metadata row count, falling back to count() only if the count is still unknown. It then splits the (materialized) dataset to avoid re-reading the upstream pipeline (see the sketch after this list).
  • split_proportionately() uses the same pattern, so count() no longer forces an extra execution for float splits.
  • _validate_test_size_int() now accepts an optional precomputed ds_length to avoid redundant count() calls.
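
A minimal sketch of this fallback pattern (not the PR's exact code): it uses the public materialize()/count() APIs and the private Dataset._meta_count() helper named above, and the function name here is hypothetical, for illustration only.

def _dataset_length_without_rereading(ds):
    """Return a (possibly materialized) dataset plus its row count, executing the pipeline at most once."""
    # Cheapest path: the row count may already be known from block metadata.
    length = ds._meta_count()
    if length is not None:
        return ds, length
    # Otherwise materialize once; the snapshot's metadata normally carries the count,
    # and a later split_at_indices() reuses the same blocks instead of re-reading upstream.
    ds = ds.materialize()
    length = ds._meta_count()
    if length is None:
        # Last resort: count() on a materialized dataset does not re-run the upstream pipeline.
        length = ds.count()
    return ds, length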

Local quick check

Reproduce script

demo_torch_detection_like_issue.py

(based on Torch object detection example https://docs.ray.io/en/latest/train/examples/pytorch/torch_detection.html)

import io
import os
import ray
import requests
import xmltodict
import numpy as np
from PIL import Image
from typing import Any, Dict, List, Tuple

CLASS_TO_LABEL = {"background": 0, "cat": 1, "dog": 2}

def decode_annotation(row: Dict[str, Any]) -> Dict[str, Any]:
    text = row["bytes"].decode("utf-8")
    annotation = xmltodict.parse(text)["annotation"]
    objects = annotation["object"]
    if isinstance(objects, dict):
        objects = [objects]
    boxes: List[Tuple] = []
    for obj in objects:
        x1 = float(obj["bndbox"]["xmin"])
        y1 = float(obj["bndbox"]["ymin"])
        x2 = float(obj["bndbox"]["xmax"])
        y2 = float(obj["bndbox"]["ymax"])
        boxes.append((x1, y1, x2, y2))
    labels = [CLASS_TO_LABEL[obj["name"]] for obj in objects]
    return {"boxes": boxes, "labels": labels, "filename": annotation["filename"]}


def read_images(row: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    url = os.path.join(
        "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages",
        row["filename"],
    )
    response = requests.get(url)
    image = Image.open(io.BytesIO(response.content))
    row["image"] = np.array(image)
    return row

if __name__ == "__main__":
    ray.init()
    annotations = ray.data.read_binary_files("s3://anonymous@air-example-data/AnimalDetection/Annotations").map(decode_annotation)
    dataset = annotations.limit(20).map(read_images)
    train, test = dataset.train_test_split(test_size=0.2)
    print("train count:", train.count())
    print("test count:", test.count())
    ray.shutdown()

Before:

2026-01-19 02:12:40,039	INFO worker.py:2007 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
/Users/xxx/miniforge3/envs/clion-ray-ce/lib/python3.10/site-packages/ray/_private/worker.py:2055: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
RayContext(dashboard_url='127.0.0.1:8265', python_version='3.10.19', ray_version='3.0.0.dev0', ray_commit='{{RAY_COMMIT_SHA}}')
2026-01-19 02:12:53,068	INFO logging.py:397 -- Registered dataset logger for dataset dataset_4_0
2026-01-19 02:12:53,091	INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2026-01-19_02-12-36_655885_24877/logs/ray-data
2026-01-19 02:12:53,091	INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)->Project] -> AggregateNumRows[AggregateNumRows]
2026-01-19 02:12:53,108	WARNING resource_manager.py:134 -- ⚠️  Ray's object store is configured to use only 23.6% of available memory (2.0GiB out of 8.5GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
2026-01-19 02:12:53,108	INFO streaming_executor.py:661 -- [dataset]: A new progress UI is available. To enable, set `ray.data.DataContext.get_current().enable_rich_progress_bars = True` and `ray.data.DataContext.get_current().use_ray_tqdm = False`.
Running Dataset dataset_4_0.: 0.00 row [00:00, ? row/s]                                                                                                 2026-01-19 02:12:53,144  WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up.                                                                            | 0.00/1.00 [00:00<?, ? row/s]
2026-01-19 02:12:53,156)WARNING utils.py:33 -- Truncating long operator name to 100 characters. To disable this behavior, set `ray.data.DataContext.get_current().DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`.                                                             | 0.00/1.00 [00:00<?, ? row/s]
✔️  Dataset dataset_4_0 execution finished in 37.60 seconds: 100%|██████████████████████████████████████████████████| 1.00/1.00 [00:37<00:00, 37.6s/ row]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 140/140 [00:37<00:00, 3.73 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 20.0/20.0 [00:37<00:00, 1.88s/ row]
- Map(decode_annotation)->...->Project: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20.0 [00:37<00:
- AggregateNumRows: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 8.0B object store: 100%|██████| 1.00/1.00 [00:37<00:00, 37.6s/ row]
2026-01-19 02:13:30,695 INFO streaming_executor.py:302 -- ✔️  Dataset dataset_4_0 execution finished in 37.60 seconds
2026-01-19 02:13:30,703 INFO logging.py:397 -- Registered dataset logger for dataset dataset_3_0
2026-01-19 02:13:30,725 INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_3_0. Full logs are in /tmp/ray/session_2026-01-19_02-12-36_655885_24877/logs/ray-data
2026-01-19 02:13:30,725	INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_3_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)]
Running Dataset dataset_3_0.: 0.00 row [00:00, ? row/s]                                                                                                 2026-01-19 02:13:30,742  WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up.                                                                            | 0.00/1.00 [00:00<?, ? row/s]
✔️  Dataset dataset_3_0 execution finished in 36.65 seconds: 100%|██████████████████████████████████████████████████| 20.0/20.0 [00:36<00:00, 1.83s/ row]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 140/140 [00:36<00:00, 3.82 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 20.0/20.0 [00:36<00:00, 1.83s/ row]
- Map(decode_annotation)->Map(read_images): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 10.5MiB object store: 100%|█| 20.0/20.0 [00
2026-01-19 02:14:07,383 INFO streaming_executor.py:302 -- ✔️  Dataset dataset_3_0 execution finished in 36.65 seconds
train count: 16
test count: 4

After:

2026-01-19 02:19:40,794	INFO worker.py:2007 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
/Users/xxx/miniforge3/envs/clion-ray-ce/lib/python3.10/site-packages/ray/_private/worker.py:2055: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
RayContext(dashboard_url='127.0.0.1:8265', python_version='3.10.19', ray_version='3.0.0.dev0', ray_commit='{{RAY_COMMIT_SHA}}')
2026-01-19 02:19:49,261	INFO logging.py:397 -- Registered dataset logger for dataset dataset_4_0
2026-01-19 02:19:49,280	INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2026-01-19_02-19-37_385772_25956/logs/ray-data
2026-01-19 02:19:49,280	INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)]
2026-01-19 02:19:49,288	WARNING resource_manager.py:134 -- ⚠️  Ray's object store is configured to use only 24.5% of available memory (2.0GiB out of 8.2GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
2026-01-19 02:19:49,288	INFO streaming_executor.py:661 -- [dataset]: A new progress UI is available. To enable, set `ray.data.DataContext.get_current().enable_rich_progress_bars = True` and `ray.data.DataContext.get_current().use_ray_tqdm = False`.
Running Dataset dataset_4_0.: 0.00 row [00:00, ? row/s]                                                                                                 2026-01-19 02:19:49,318  WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up.                                                                            | 0.00/1.00 [00:00<?, ? row/s]
✔️  Dataset dataset_4_0 execution finished in 37.63 seconds: 100%|██████████████████████████████████████████████████| 20.0/20.0 [00:37<00:00, 1.88s/ row]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 140/140 [00:37<00:00, 3.72 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|██████████████| 20.0/20.0 [00:37<00:00, 1.88s/ row]
- Map(decode_annotation)->Map(read_images): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 10.5MiB object store: 100%|█| 20.0/20.0 [00
2026-01-19 02:20:26,910 INFO streaming_executor.py:302 -- ✔️  Dataset dataset_4_0 execution finished in 37.63 seconds
train count: 16
test count: 4

Signed-off-by: yaommen <myanstu@163.com>
@myandpr myandpr requested a review from a team as a code owner January 18, 2026 18:44

@gemini-code-assist (bot) left a comment

Code Review

This pull request effectively addresses the issue of redundant dataset reads in train_test_split and split_proportionately by optimizing how the dataset length is calculated. The changes prioritize using metadata (_meta_count()) and materializing the dataset once when necessary, falling back to count() only as a last resort. This significantly improves performance by avoiding multiple full passes over the data. The modifications to _validate_test_size_int to accept a precomputed length are also a good enhancement. I have one suggestion to improve code clarity.

@ray-gardener ray-gardener bot added data Ray Data-related issues community-contribution Contributed by the community labels Jan 18, 2026
Signed-off-by: yaommen <myanstu@163.com>

github-actions bot commented Feb 2, 2026

This pull request has been automatically marked as stale because it has not had any activity for 14 days. It will be closed in another 14 days if no further activity occurs. Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Feb 2, 2026
@myandpr myandpr removed the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Feb 2, 2026
Comment on lines +2381 to +2384

            ds = ds.materialize()
            dataset_length = ds._meta_count()
            if dataset_length is None:
                dataset_length = ds.count()
Contributor

Actually can't we just call ds.count() here? Cause materializing defeats the purpose of this optimization

@myandpr (Member, Author) commented Feb 7, 2026

Good question. The reason for materializing here is to avoid re-executing the pipeline.

If _meta_count() is unavailable and we call ds.count(), the plan executes once for the count, and then split_at_indices() executes it again to produce the splits (the original issue was exactly this redundant read). By materializing once, we get the row count from metadata and then split from the already-computed snapshot, avoiding a second upstream read.

To make this concrete, I temporarily applied the suggested change (diff below) and reran the repro script demo_torch_detection_like_issue.py from this PR description. The logs show two executions:

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index ecd0ddce08..9ddc971ee9 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2378,10 +2378,7 @@ class Dataset:
         ds = self
         dataset_length = ds._meta_count()
         if dataset_length is None:
-            ds = ds.materialize()
-            dataset_length = ds._meta_count()
-            if dataset_length is None:
-                dataset_length = ds.count()
+            dataset_length = ds.count()
         cumulative_proportions = np.cumsum(proportions)
         split_indices = [
             int(dataset_length * proportion) for proportion in cumulative_proportions
@@ -2470,8 +2467,7 @@ class Dataset:
         else:
             ds_length = ds._meta_count()
             if ds_length is None:
-                ds = ds.materialize()
-                ds_length = ds._meta_count()
+                ds_length = ds.count()
             ds_length = self._validate_test_size_int(test_size, ds, ds_length=ds_length)
             return ds.split_at_indices([ds_length - test_size])

2026-02-07 22:36:45,085	INFO worker.py:2007 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
/Users/xx/work/community/ray/python/ray/_private/worker.py:2055: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
2026-02-07 22:36:56,331	INFO logging.py:397 -- Registered dataset logger for dataset dataset_4_0
2026-02-07 22:36:56,367	INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2026-02-07_22-36-39_967896_16444/logs/ray-data
2026-02-07 22:36:56,367	INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)->Project] -> AggregateNumRows[AggregateNumRows]
2026-02-07 22:36:56,375	WARNING resource_manager.py:134 -- ⚠️  Ray's object store is configured to use only 23.6% of available memory (2.0GiB out of 8.5GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
2026-02-07 22:36:56,376	INFO streaming_executor.py:661 -- [dataset]: A new progress UI is available. To enable, set `ray.data.DataContext.get_current().enable_rich_progress_bars = True` and `ray.data.DataContext.get_current().use_ray_tqdm = False`.
Running Dataset dataset_4_0.: 0.00 row [00:00, ? row/s]                                                                                    2026-02-07 22:36:56,402  WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up.                                                  | 0.00/1.00 [00:00<?, ? row/s]
2026-02-07 22:36:56,417)WARNING utils.py:33 -- Truncating long operator name to 100 characters. To disable this behavior, set `ray.data.DataContext.get_current().DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`.                                   | 0.00/1.00 [00:00<?, ? row/s]
✔️  Dataset dataset_4_0 execution finished in 33.39 seconds: 100%|█████████████████████████████████████| 1.00/1.00 [00:33<00:00, 33.4s/ row]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 140/140 [00:33<00:00, 4.19 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20.0 [00:33<00:00, 1.67s/ row]
- Map(decode_annotation)->...->Project: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20
- AggregateNumRows: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 8.0B object store: 100%|█| 1.00/1.00 [00:33<00:00, 33.
2026-02-07 22:37:29,762 INFO streaming_executor.py:302 -- ✔️  Dataset dataset_4_0 execution finished in 33.39 seconds
2026-02-07 22:37:29,766 INFO logging.py:397 -- Registered dataset logger for dataset dataset_3_0
2026-02-07 22:37:29,791 INFO streaming_executor.py:182 -- Starting execution of Dataset dataset_3_0. Full logs are in /tmp/ray/session_2026-02-07_22-36-39_967896_16444/logs/ray-data
2026-02-07 22:37:29,791	INFO streaming_executor.py:183 -- Execution plan of Dataset dataset_3_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadBinary] -> LimitOperator[limit=20] -> TaskPoolMapOperator[Map(decode_annotation)->Map(read_images)]
Running Dataset dataset_3_0.: 0.00 row [00:00, ? row/s]                                                                                    2026-02-07 22:37:29,812  WARNING resource_manager.py:791 -- Cluster resources are not enough to run any task from TaskPoolMapOperator[ReadBinary]. The job may hang forever unless the cluster scales up.                                                  | 0.00/1.00 [00:00<?, ? row/s]
✔️  Dataset dataset_3_0 execution finished in 38.60 seconds: : 20.0 row [00:38, 1.93s/ row]                                    0<?, ? row/s]
- ReadBinary: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 140/140 [00:38<00:00, 3.63 row/s]
- limit=20: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.0/20.0 [00:38<00:00, 1.93s/ row]
- Map(decode_annotation)->Map(read_images): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: 100%|█| 20.
2026-02-07 22:38:08,398 INFO streaming_executor.py:302 -- ✔️  Dataset dataset_3_0 execution finished in 38.60 seconds
train count: 16
test count: 4
  1. a run with ... -> AggregateNumRows (from count()), and
  2. another run of the upstream pipeline for the split (... -> Map(read_images)).

So this change regresses to double reads. That’s why the current materialize -> split path is intentional: it executes once, reuses metadata, and avoids the second upstream read.
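
For illustration only, a rough standalone sketch (not the PR's code) contrasting the two strategies on a small lazy pipeline; it assumes the public Ray Data API plus the private _meta_count() helper discussed above:

import ray

# A lazy pipeline: nothing executes until a consuming operation runs.
ds = ray.data.range(1000).map(lambda row: {"id": row["id"] * 2})

# Strategy A (count-then-split): count() runs the upstream pipeline once,
# and split_at_indices() runs it again when the splits are consumed.
n = ds.count()
train_a, test_a = ds.split_at_indices([int(n * 0.8)])

# Strategy B (materialize-then-split): execute once, read the row count from the
# materialized snapshot's metadata, and split the already-computed blocks.
materialized = ds.materialize()
n = materialized._meta_count() or materialized.count()
train_b, test_b = materialized.split_at_indices([int(n * 0.8)])

With strategy A the upstream map stages should appear in two separate executions in the Ray Data logs, which is the double read visible above; strategy B executes the pipeline once.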


Labels

community-contribution Contributed by the community data Ray Data-related issues

Development

Successfully merging this pull request may close these issues.

[Data] Dataset.train_test_split reads dataset twice

2 participants