diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index caeb27a20a30a..8fdea060e5657 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import torch +import pyarrow from ray.air._internal.device_manager import get_torch_device_manager_by_context from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed @@ -292,3 +293,227 @@ def consume_prefix_in_state_dict_if_present_not_in_place( metadata[newkey] = metadata.pop(key) return state_dict + + +def convert_ndarray_list_to_torch_tensor_list( + ndarrays: Union[List[np.ndarray], Dict[str, List[np.ndarray]]], + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, +) -> Union[List[torch.Tensor], Dict[str, List[torch.Tensor]]]: + """Convert a list of NumPy ndarrays or dict of lists of ndarrays to Torch Tensors. + + Args: + ndarrays: A list of NumPy ndarrays or a dict mapping column names to lists of + ndarrays that we wish to convert to Torch Tensors. + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: The device on which the tensor(s) should be placed; if None, the Torch + tensor(s) will be constructed on the CPU. + + Returns: A list of Torch Tensors or a dict mapping column names to lists of Tensors. 
+ """ + if isinstance(ndarrays, list): + # Single column case - list of ndarrays + if isinstance(dtypes, dict): + if len(dtypes) != 1: + raise ValueError( + "When constructing a single-column batch, only a single dtype " + f"should be given, instead got: {dtypes}" + ) + dtypes = next(iter(dtypes.values())) + return [ + convert_ndarray_batch_to_torch_tensor_batch( + ndarray, dtypes=dtypes, device=device + ) + for ndarray in ndarrays + ] + else: + # Multi-column case - dict of lists of ndarrays + return { + col_name: [ + convert_ndarray_batch_to_torch_tensor_batch( + ndarray, + dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes, + device=device, + ) + for ndarray in col_ndarrays + ] + for col_name, col_ndarrays in ndarrays.items() + } + + +def arrow_batch_to_tensors( + batch: pyarrow.Table, + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, + combine_chunks: bool = False, +) -> Dict[str, List[torch.Tensor]]: + """Convert PyArrow batch to PyTorch tensors. + + Args: + batch: PyArrow batch to convert + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: Optional device to place tensors on + combine_chunks: If True, combine chunks in Arrow batch before converting to + tensors. 
+ + Returns: + A dictionary of column name to list of tensors + """ + from ray.data._internal.arrow_ops import transform_pyarrow + + if combine_chunks: + numpy_batch = transform_pyarrow.table_to_numpy_dict_combined( + batch, + zero_copy_only=False, + ) + return convert_ndarray_batch_to_torch_tensor_batch( + numpy_batch, + dtypes=dtypes, + device=device, + ) + else: + numpy_list = transform_pyarrow.table_to_numpy_dict_chunked( + batch, + zero_copy_only=False, + ) + return convert_ndarray_list_to_torch_tensor_list( + numpy_list, + dtypes=dtypes, + device=device, + ) + + +def numpy_batch_to_torch_tensors( + batch: Dict[str, np.ndarray], + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, +) -> Dict[str, List[torch.Tensor]]: + """Convert a dictionary of numpy arrays to PyTorch tensors. + + Args: + batch: Dictionary mapping column names to numpy arrays + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: Optional device to place tensors on + + Returns: + A dictionary of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + batch, + dtypes=dtypes, + device=device, + ) + + +@torch.no_grad() +def concat_tensors_to_device( + tensor_list: List[torch.Tensor], + device: str, +) -> torch.Tensor: + """Stack list of tensors into a contiguous GPU tensor. 
+ + Args: + tensor_list: List of tensors to stack + device: The device to move tensors to + + Returns: + A contiguous tensor on the target device + """ + # Assumes tensors have the same shape/dtype + assert tensor_list, "Cannot stack empty list of tensors" + assert all( + isinstance(t, torch.Tensor) for t in tensor_list + ), "All items must be torch.Tensor" + assert all( + t.dtype == tensor_list[0].dtype for t in tensor_list + ), "All tensors must have the same dtype" + assert all( + t.shape[1:] == tensor_list[0].shape[1:] for t in tensor_list + ), "All tensors must have the same shape[1:]" + + first = tensor_list[0] + dtype = first.dtype + shape_tail = first.shape[1:] + total_rows = sum(t.shape[0] for t in tensor_list) + + # Allocate an empty Tensor on device + result = torch.empty((total_rows, *shape_tail), dtype=dtype, device=device) + + row_start = 0 + for t in tensor_list: + row_end = row_start + t.shape[0] + if t.is_pinned(): + # Perform non-blocking transfer if the tensor is pinned + result[row_start:row_end].copy_(t, non_blocking=True) + else: + # Perform blocking transfer if the tensor is not pinned + result[row_start:row_end].copy_(t) + row_start = row_end + + assert isinstance(result, torch.Tensor), "Result must be a torch.Tensor" + return result + + +@torch.no_grad() +def move_tensors_to_device( + batch: Union[ + torch.Tensor, + List[torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, List[torch.Tensor]], + ], + device: Optional[str] = None, +) -> Union[torch.Tensor, Dict[str, torch.Tensor],]: + """Move tensors to the specified device. + + Args: + batch: A tensor or collection of tensors to move to device. Can be: + - A single tensor + - A list of tensors + - A dict mapping keys to tensors + - A dict mapping keys to lists of tensors + device: The device to move tensors to. If None, tensors are not moved. + + Returns: + The input tensors moved to the specified device, maintaining the same structure. 
+ Note that any lists of tensors will be concatenated into a single tensor. + """ + if device is None: + return batch + + if isinstance(batch, dict): + for k, v in batch.items(): + if isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v): + batch[k] = concat_tensors_to_device(v, device=device) + elif isinstance(v, torch.Tensor): + if v.is_pinned(): + batch[k] = v.to(device=device, non_blocking=True) + else: + batch[k] = v.to(device=device) + elif isinstance(batch, list) and all(isinstance(t, torch.Tensor) for t in batch): + batch = concat_tensors_to_device(batch, device=device) + else: + assert isinstance(batch, torch.Tensor), "Batch must be a Tensor" + if batch.is_pinned(): + batch = batch.to(device=device, non_blocking=True) + else: + batch = batch.to(device=device) + + if isinstance(batch, dict): + assert all(isinstance(v, torch.Tensor) for v in batch.values()), ( + "All values in dict must be tensors, got: " + f"{[type(v) for v in batch.values()]}" + ) + else: + assert isinstance( + batch, torch.Tensor + ), "Batch must be a Tensor or dict of Tensors" + + return batch diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 2dc4f867337c3..a6a88d4ed7107 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -643,6 +643,63 @@ def concat_and_sort( return take_table(ret, indices) +def table_to_numpy_dict_combined( + table: "pyarrow.Table", + *, + zero_copy_only: bool = False, +) -> Dict[str, np.ndarray]: + """Convert a PyArrow table to a dictionary of numpy arrays. + + Args: + table: The PyArrow table to convert. + zero_copy_only: Whether to only use zero-copy transfers. + + Returns: + A dictionary of numpy arrays. 
+ """ + + numpy_batch = {} + for col_name in table.column_names: + col = table[col_name] + if isinstance(col, pyarrow.ChunkedArray): + combined_array = combine_chunked_array(col) + numpy_batch[col_name] = to_numpy( + combined_array, zero_copy_only=zero_copy_only + ) + else: + numpy_batch[col_name] = to_numpy(col, zero_copy_only=zero_copy_only) + return numpy_batch + + +def table_to_numpy_dict_chunked( + table: "pyarrow.Table", + *, + zero_copy_only: bool = False, +) -> Dict[str, List[np.ndarray]]: + """Convert a PyArrow table to a dictionary of lists of numpy arrays. + + Args: + table: The PyArrow table to convert. + zero_copy_only: Whether to only use zero-copy transfers. + + Returns: + A dictionary mapping column names to either: + - A list of numpy arrays (for chunked columns) + - A single numpy array (for non-chunked columns) + """ + + numpy_batch = {} + for col_name in table.column_names: + col = table[col_name] + if isinstance(col, pyarrow.ChunkedArray): + numpy_batch[col_name] = [ + to_numpy(chunk, zero_copy_only=zero_copy_only) for chunk in col.chunks + ] + else: + numpy_batch[col_name] = to_numpy(col, zero_copy_only=zero_copy_only) + return numpy_batch + + def to_numpy( array: Union["pyarrow.Array", "pyarrow.ChunkedArray"], *, diff --git a/python/ray/data/collate_fn.py b/python/ray/data/collate_fn.py new file mode 100644 index 0000000000000..e7bb960707a1b --- /dev/null +++ b/python/ray/data/collate_fn.py @@ -0,0 +1,199 @@ +import abc +from typing import Dict, Generic, Optional, TypeVar, Union, List, TYPE_CHECKING, Any + +import numpy as np + +from ray.data.block import DataBatch +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import torch + import pyarrow + import pandas + + from ray.data.dataset import CollatedData + + +DataBatchType = TypeVar("DataBatchType", bound=DataBatch) + + +@DeveloperAPI +class CollateFn(Generic[DataBatchType]): + """Abstract interface for collate_fn for iter_torch_batches. 
See doc-string of + `collate_fn` for more details. + """ + + @abc.abstractmethod + def __call__(self, batch: DataBatchType) -> "CollatedData": + """Convert a batch of data to collated format. + + Args: + batch: The input batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]): + """Collate function that takes pyarrow.Table as the input batch type.""" + + def __call__(self, batch: "pyarrow.Table") -> "CollatedData": + """Convert a batch of pyarrow.Table to collated format. + + Args: + batch: The input pyarrow.Table batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]): + """Collate function that takes a dictionary of numpy arrays as the input batch type.""" + + def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": + """Convert a batch of numpy arrays to collated format. + + Args: + batch: The input dictionary of numpy arrays batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]): + """Collate function that takes a pandas.DataFrame as the input batch type.""" + + def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": + """Convert a batch of pandas.DataFrame to collated format. + + Args: + batch: The input pandas.DataFrame batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... 
+ + +@DeveloperAPI +class DefaultArrowCollateFn(ArrowBatchCollateFn): + """Default collate function for converting Arrow batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: "pyarrow.Table") -> Dict[str, List["torch.Tensor"]]: + """Convert an Arrow batch to PyTorch tensors. + + Args: + batch: PyArrow Table to convert + + Returns: + A dictionary of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + arrow_batch_to_tensors, + ) + + combine_chunks = self.device == "cpu" + return arrow_batch_to_tensors( + batch, dtypes=self.dtypes, device=None, combine_chunks=combine_chunks + ) + + +@DeveloperAPI +class DefaultNumpyCollateFn(NumpyBatchCollateFn): + """Default collate function for converting Numpy batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List["torch.Tensor"]]: + """Convert a Numpy batch to PyTorch tensors. 
+ + Args: + batch: The input dictionary of numpy arrays batch to collate + + Returns: + A dictionary of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + numpy_batch_to_torch_tensors, + ) + + return numpy_batch_to_torch_tensors( + batch, dtypes=self.dtypes, device=self.device + ) + + +@DeveloperAPI +class DefaultPandasCollateFn(PandasBatchCollateFn): + """Default collate function for converting Pandas batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: "pandas.DataFrame") -> "torch.Tensor": + """Convert a Pandas batch to PyTorch tensors. + + Args: + batch: The input pandas.DataFrame batch to collate + + Returns: + A PyTorch tensor containing the collated batch. + """ + from ray.air._internal.torch_utils import ( + convert_pandas_to_torch_tensor, + ) + + return convert_pandas_to_torch_tensor( + batch, dtypes=self.dtypes, device=self.device + ) + + +@DeveloperAPI +def default_finalize_fn( + batch: Union[Dict[str, List["torch.Tensor"]], Any], device: Optional[str] = None +) -> Union[Dict[str, "torch.Tensor"], Any]: + """Default finalize function for moving PyTorch tensors to device. + + Args: + batch: Input batch to move to device. Can be: + - Dictionary mapping column names to lists of tensors + - Any other type supported by move_tensors_to_device + device: Target device to move tensors to + + Returns: + Batch with tensors moved to the target device. 
Type matches input type: + - If input is Dict[str, List[torch.Tensor]], returns Dict[str, torch.Tensor] + - Otherwise returns the same type as input with tensors moved to device + """ + from ray.air._internal.torch_utils import ( + move_tensors_to_device, + ) + + return move_tensors_to_device(batch, device=device) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 7bf8aab67e3a2..390717c6c1452 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -13,6 +13,7 @@ TypeVar, Union, ) +import warnings import numpy as np @@ -25,6 +26,14 @@ from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.context import DataContext from ray.util.annotations import PublicAPI +from ray.data.collate_fn import ( + CollateFn, + ArrowBatchCollateFn, + NumpyBatchCollateFn, + PandasBatchCollateFn, + DefaultArrowCollateFn, + default_finalize_fn, +) if TYPE_CHECKING: import tensorflow as tf @@ -242,7 +251,9 @@ def iter_torch_batches( batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: str = "auto", - collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None, + collate_fn: Optional[ + Union[Callable[[Dict[str, np.ndarray]], "CollatedData"], CollateFn] + ] = None, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, @@ -301,17 +312,26 @@ def iter_torch_batches( Dataset is passed to Ray Train and ``collate_fn`` is not provided. Otherwise, defaults to CPU. You can't use this parameter with ``collate_fn``. - collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. - When this parameter is specified, the user should manually handle the - host to device data transfer outside of ``collate_fn``. - This is useful for further processing the data after it has been - batched. 
Potential use cases include collating along a dimension other - than the first, padding sequences of various lengths, or generally - handling batches of different length tensors. If not provided, the - default collate function is used which simply converts the batch of - numpy arrays to a batch of PyTorch tensors. This API is still - experimental and is subject to change. You can't use this parameter in - conjunction with ``dtypes`` or ``device``. + collate_fn: [Alpha] A function to customize how data batches are collated + before being passed to the model. This is useful for last-mile data + formatting such as padding, masking, or packaging tensors into custom + data structures. If not provided, `iter_torch_batches` automatically + converts batches to `torch.Tensor`s and moves them to the device + assigned to the current worker. The input to `collate_fn` may be: + + 1. dict of np.ndarray, where you should provide a function or a + `NumpyBatchCollateFn` subclass that takes in a dict of Numpy arrays + 2. pd.DataFrame, where you should provide a callable class that + subclasses `PandasBatchCollateFn` + 3. pyarrow.Table, where you should provide a callable class that + subclasses `ArrowBatchCollateFn` (recommended for best performance) + + The output can be any type. If the output is a `torch.Tensor`, + `dict[str, torch.Tensor]`, or `list/tuple[torch.Tensor]`, it will be + automatically moved to the current worker's device. For other types, + you must handle device transfer manually in your training loop. + Note: This function is called in a multi-threaded context; avoid using + thread-unsafe code. drop_last: Whether to drop the last batch if it's incomplete. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the @@ -327,9 +347,6 @@ def iter_torch_batches( An iterable over Torch Tensor batches. 
""" - from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, - ) from ray.train.torch import get_device if collate_fn is not None and (dtypes is not None or device != "auto"): @@ -342,37 +359,35 @@ def iter_torch_batches( if device == "auto": # Use the appropriate device for Ray Train, or falls back to CPU if # Ray Train is not being used. - device = get_device() + device = get_device().type + finalize_fn = None if collate_fn is None: - # The default collate_fn handles formatting and Tensor creation. - # Here, we set device=None to defer host to device data transfer - # to the subsequent finalize_fn. - def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): - return convert_ndarray_batch_to_torch_tensor_batch( - batch, - dtypes=dtypes, - device=None, - ) - - # The default finalize_fn handles the host to device data transfer. - # This is executed in a 1-thread pool separately from collate_fn - # to allow independent parallelism of these steps. 
- def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]): - if device is not None: - if isinstance(batch, dict): - for k, t in batch.items(): - batch[k] = t.to(device=device) - else: - batch = batch.to(device=device) - return batch - + collate_fn = DefaultArrowCollateFn( + dtypes=dtypes, + device=device, + ) + finalize_fn = default_finalize_fn + batch_format = "pyarrow" + elif isinstance(collate_fn, ArrowBatchCollateFn): + batch_format = "pyarrow" + elif isinstance(collate_fn, NumpyBatchCollateFn): + batch_format = "numpy" + elif isinstance(collate_fn, PandasBatchCollateFn): + batch_format = "pandas" + elif callable(collate_fn): + batch_format = "numpy" + warnings.warn( + "Passing a callable to `iter_torch_batches` is deprecated and suggest using `ArrowBatchCollateFn` for the best performance.", + DeprecationWarning, + ) else: - finalize_fn = None + raise ValueError(f"Unsupported collate function: {type(collate_fn)}") return self.iter_batches( prefetch_batches=prefetch_batches, batch_size=batch_size, + batch_format=batch_format, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 6fdeb254de710..d6734e5a2745d 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -1,12 +1,14 @@ import sys import threading -from typing import Dict +from typing import Dict, Optional from unittest.mock import MagicMock, patch import numpy as np import pyarrow as pa import pytest import torch +import pyarrow +import pandas as pd import ray @@ -196,6 +198,100 @@ def collate_fn(batch: Dict[str, np.ndarray]): ), iter_batches_calls_kwargs +@pytest.fixture +def custom_collate_fns(): + """Fixture that provides both Arrow and Numpy custom collate functions.""" + from ray.data.iterator import ( + ArrowBatchCollateFn, + NumpyBatchCollateFn, + PandasBatchCollateFn, + ) + + 
class CustomArrowBatchCollateFn(ArrowBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: pyarrow.Table) -> torch.Tensor: + """Add 5 to the "id" column at the Arrow level.""" + modified_batch = pyarrow.Table.from_arrays( + [pyarrow.compute.add(batch["id"], 5)], names=["id"] + ) + from ray.air._internal.torch_utils import ( + arrow_batch_to_tensors, + ) + + return arrow_batch_to_tensors( + modified_batch, + dtypes=self.dtypes, + device=self.device, + combine_chunks=self.device == "cpu", + )["id"] + + class CustomNumpyBatchCollateFn(NumpyBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: + """Add 5 to the "id" array.""" + modified_batch = {"id": batch["id"] + 5} + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + modified_batch, dtypes=self.dtypes, device=self.device + )["id"] + + class CustomPandasBatchCollateFn(PandasBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: pd.DataFrame) -> torch.Tensor: + """Add 5 to the "id" column.""" + modified_batch = pd.DataFrame({"id": batch["id"] + 5}) + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + modified_batch.to_dict("series"), dtypes=self.dtypes, device=self.device + )["id"] + + return { + "arrow": CustomArrowBatchCollateFn(device="cpu"), + "numpy": CustomNumpyBatchCollateFn(device="cpu"), + "pandas": 
CustomPandasBatchCollateFn(device="cpu"), + } + + +@pytest.mark.parametrize("collate_type", ["arrow", "numpy", "pandas"]) +def test_custom_batch_collate_fn( + ray_start_regular_shared, custom_collate_fns, collate_type +): + """Tests that custom batch collate functions can be used to modify + the batch before it is converted to a PyTorch tensor.""" + ds = ray.data.range(5) + it = ds.iterator() + for batch in it.iter_torch_batches(collate_fn=custom_collate_fns[collate_type]): + assert isinstance(batch, torch.Tensor) + assert batch.tolist() == list(range(5, 10)) + + @pytest.fixture(params=["regular", "chunked"]) def null_array_table(request): """Fixture that returns a PyArrow table with either a regular or chunked null array.""" diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index 1b5263360f743..7c45f0b2bff2b 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -1,13 +1,15 @@ # Standard library imports import logging import time -from typing import Any, Dict, Tuple, Iterator, Generator, Optional +from typing import Dict, Tuple, Iterator, Generator, Optional, Union, Type # Third-party imports import torch +import pyarrow import ray import ray.data import ray.train +from ray.data.iterator import ArrowBatchCollateFn # Local imports from config import BenchmarkConfig @@ -188,36 +190,52 @@ def create_batch_iterator( raise -class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): - """Factory for creating Ray DataLoader for image classification tasks. 
+class CustomArrowCollateFn(ArrowBatchCollateFn): + """Custom collate function for converting Arrow batches to PyTorch tensors.""" - Features: - - Distributed file reading with round-robin worker distribution - - Device transfer and error handling for data batches - - Configurable row limits per worker for controlled processing - - Performance monitoring and logging - """ + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. - def __init__(self, benchmark_config: BenchmarkConfig): - super().__init__(benchmark_config) + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + self.dtypes = dtypes + self.device = device - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert Ray data batch to PyTorch tensors on the appropriate device. + def __call__(self, batch: "pyarrow.Table") -> Tuple[torch.Tensor, torch.Tensor]: + """Convert an Arrow batch to PyTorch tensors. 
Args: - batch: Dictionary with 'image' and 'label' numpy arrays + batch: PyArrow Table to convert Returns: - Tuple of (image_tensor, label_tensor) on the target device + Tuple of (image_tensor, label_tensor) """ from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, + arrow_batch_to_tensors, + move_tensors_to_device, ) - device = ray.train.torch.get_device() - batch = convert_ndarray_batch_to_torch_tensor_batch(batch, device=device) + tensors = arrow_batch_to_tensors( + batch, dtypes=self.dtypes, device=None, combine_chunks=self.device == "cpu" + ) + tensors = move_tensors_to_device(tensors, self.device) + return tensors["image"], tensors["label"] + + +class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): + """Factory for creating Ray DataLoader for image classification tasks.""" + + def __init__(self, benchmark_config: BenchmarkConfig): + super().__init__(benchmark_config) - return batch["image"], batch["label"] + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + return CustomArrowCollateFn class ImageClassificationMockDataLoaderFactory(BaseDataLoaderFactory): diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index 96b914b5e5a38..4b3e756e8f7db 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -1,9 +1,8 @@ from abc import abstractmethod -from typing import Any, Dict, Tuple +from typing import Any, Dict, Type -import torch import ray.train -from ray.data import Dataset +from ray.data.iterator import ArrowBatchCollateFn from constants import DatasetKey from config import BenchmarkConfig, RayDataConfig @@ -28,24 +27,16 @@ def __init__(self, benchmark_config: BenchmarkConfig) -> None: data_context.retried_io_errors.append("AWS Error ACCESS_DENIED") @abstractmethod - def get_ray_datasets(self) -> Dict[str, Dataset]: - """Get the Ray datasets for 
training and validation. - - Returns: - Dict with "train" and "val" Dataset objects - """ + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + """Return the collate function class. Must be implemented by subclass.""" pass - @abstractmethod - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Get the collate function for the dataloader. + def get_train_dataloader(self): + """Get the training dataloader. Returns: - A function that takes a batch and returns a tuple of tensors. + Iterator of training batches """ - pass - - def get_train_dataloader(self): ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN) self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator @@ -58,13 +49,20 @@ def get_train_dataloader(self): if dataloader_config.local_buffer_shuffle_size > 0 else None ), - collate_fn=self.collate_fn, + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, drop_last=True, ) ) def get_val_dataloader(self): + """Get the validation dataloader. + + Returns: + Iterator of validation batches + """ ds_iterator = ray.train.get_dataset_shard(DatasetKey.VALID) self._ray_ds_iterators[DatasetKey.VALID] = ds_iterator @@ -72,7 +70,9 @@ def get_val_dataloader(self): return iter( ds_iterator.iter_torch_batches( batch_size=dataloader_config.validation_batch_size, - collate_fn=self.collate_fn, + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, drop_last=True, )