def convert_ndarray_list_to_torch_tensor_list(
    ndarrays: Dict[str, List[np.ndarray]],
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    device: Optional[Union[str, "torch.device"]] = None,
) -> Dict[str, List[torch.Tensor]]:
    """Convert a dict mapping column names to lists of ndarrays to Torch Tensors.

    Args:
        ndarrays: A dict mapping column names to lists of ndarrays that we wish
            to convert to Torch Tensors.
        dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the
            dtype will be inferred from the NumPy ndarray data. When a dict is
            given it is indexed by column name.
        device: The device on which the tensor(s) should be placed; if None, the
            Torch tensor(s) will be constructed on the CPU.

    Returns:
        A dict mapping column names to lists of Tensors, one tensor per input
        ndarray and in the same order.
    """
    return {
        col_name: [
            convert_ndarray_batch_to_torch_tensor_batch(
                ndarray,
                dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
                device=device,
            )
            for ndarray in col_ndarrays
        ]
        for col_name, col_ndarrays in ndarrays.items()
    }


def arrow_batch_to_tensors(
    batch: "pyarrow.Table",
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    combine_chunks: bool = False,
) -> Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]]]:
    """Convert PyArrow batch to PyTorch tensors.

    Args:
        batch: PyArrow batch to convert
        dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the
            dtype will be inferred from the NumPy ndarray data.
        combine_chunks: If True, combine chunks in Arrow batch before converting
            to tensors.

    Returns:
        If ``combine_chunks`` is True, a dictionary of column name to the value
        produced by ``convert_ndarray_batch_to_torch_tensor_batch`` for that
        column (a single tensor per column). Otherwise, a dictionary of column
        name to a list of tensors, one per chunk; for non-chunked columns the
        list contains a single tensor.
    """
    from ray.data._internal.arrow_block import ArrowBlockAccessor
    from ray.data._internal.arrow_ops import transform_pyarrow

    if combine_chunks:
        # The "numpy" batch format yields one array per column, so each column
        # is converted on its own rather than as a list of per-chunk tensors.
        numpy_batch = ArrowBlockAccessor(batch).to_batch_format("numpy")
        return {
            col_name: convert_ndarray_batch_to_torch_tensor_batch(
                col_array,
                dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
            )
            for col_name, col_array in numpy_batch.items()
        }

    # Keep the chunks separate: every chunk becomes its own tensor.
    numpy_list = transform_pyarrow.table_to_numpy_dict_chunked(batch)
    return convert_ndarray_list_to_torch_tensor_list(numpy_list, dtypes=dtypes)


@torch.no_grad()
def concat_tensors_to_device(
    tensor_sequence: Sequence[torch.Tensor],
    device: Optional[Union[str, "torch.device"]] = None,
    non_blocking: bool = False,
) -> torch.Tensor:
    """Concatenate tensors along dim 0 into one tensor allocated on ``device``.

    The output buffer is allocated once with ``torch.empty`` on the target
    device and each input is copied into its row range, so no intermediate
    concatenated tensor is built on the source device.

    Args:
        tensor_sequence: Non-empty sequence of tensors to concatenate. All
            tensors must share a dtype, have at least one dimension, and agree
            on every dimension except the first.
        device: The device to allocate the result on. ``None`` is passed
            through to ``torch.empty`` unchanged.
        non_blocking: If True, perform device transfer without forcing a
            synchronization.

    Returns:
        A tensor on the target device whose first dimension is the sum of the
        inputs' first dimensions.

    Raises:
        AssertionError: If the sequence is empty, contains a non-tensor, mixes
            dtypes, contains a 0-dim tensor, or mixes trailing shapes.
    """
    assert tensor_sequence, "Cannot stack empty sequence of tensors"
    assert all(
        isinstance(t, torch.Tensor) for t in tensor_sequence
    ), "All items must be torch.Tensor"

    first = tensor_sequence[0]
    dtype = first.dtype
    assert all(
        t.dtype == dtype for t in tensor_sequence
    ), "All tensors must have the same dtype"
    # 0-dim tensors have no row axis; without this check they pass the
    # shape[1:] comparison below and then fail with an opaque IndexError.
    assert all(
        t.dim() >= 1 for t in tensor_sequence
    ), "All tensors must have at least one dimension"

    shape_tail = first.shape[1:]
    assert all(
        t.shape[1:] == shape_tail for t in tensor_sequence
    ), "All tensors must have the same shape[1:]"

    total_rows = sum(t.shape[0] for t in tensor_sequence)

    # Allocate an empty Tensor on device and fill it slice by slice.
    result = torch.empty((total_rows, *shape_tail), dtype=dtype, device=device)

    row_start = 0
    for t in tensor_sequence:
        row_end = row_start + t.shape[0]
        result[row_start:row_end].copy_(t, non_blocking=non_blocking)
        row_start = row_end

    return result


def _is_non_string_sequence(obj: Any) -> bool:
    """Return True if ``obj`` is a ``Sequence`` other than ``str``/``bytes``."""
    return isinstance(obj, ABCSequence) and not isinstance(obj, (str, bytes))


@torch.no_grad()
def move_tensors_to_device(
    batch: "TensorBatchType",
    device: Optional[Union[str, "torch.device"]] = None,
    non_blocking: bool = False,
) -> "TensorBatchReturnType":
    """Move tensors to the specified device.

    Args:
        batch: A tensor or collection of tensors to move to device. Can be:
            - A single tensor: moved as is.
            - A sequence of tensors: concatenated along dim 0 into a single
              tensor during the transfer.
            - A sequence of sequences of tensors: each inner sequence is
              concatenated during the transfer and the results are returned
              as a tuple.
            - A mapping (e.g., dict) of keys to tensors or sequences of
              tensors: a new dict is returned; sequence values are
              concatenated during the transfer.
        device: The device to move tensors to. If None, ``batch`` is returned
            unchanged.
        non_blocking: If True, perform device transfer without forcing a
            synchronization.

    Returns:
        The input tensors moved to the specified device.

    Raises:
        TypeError: If ``batch`` (or a mapping value) is not one of the
            supported structures.
        AssertionError: Propagated from ``concat_tensors_to_device`` when a
            sequence is empty or its tensors disagree on dtype or shape.
    """
    if device is None:
        return batch

    if isinstance(batch, torch.Tensor):
        return batch.to(device=device, non_blocking=non_blocking)

    if isinstance(batch, Mapping):
        new_batch = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                new_batch[k] = v.to(device=device, non_blocking=non_blocking)
            elif _is_non_string_sequence(v):
                if not all(isinstance(t, torch.Tensor) for t in v):
                    raise TypeError(
                        f"Expected a sequence of torch.Tensor for key '{k}', "
                        f"but got sequence of types: {[type(t) for t in v]}"
                    )
                new_batch[k] = concat_tensors_to_device(
                    v, device=device, non_blocking=non_blocking
                )
            else:
                raise TypeError(
                    f"Unsupported type for key '{k}': expected torch.Tensor or "
                    f"sequence of torch.Tensor, but got {type(v)}"
                )
        return new_batch

    if _is_non_string_sequence(batch):
        # Flat sequence of tensors: merge into one tensor on the device.
        if all(isinstance(t, torch.Tensor) for t in batch):
            return concat_tensors_to_device(
                batch, device=device, non_blocking=non_blocking
            )

        # Sequence of sequences of tensors: merge each inner sequence.
        if all(
            _is_non_string_sequence(seq)
            and all(isinstance(t, torch.Tensor) for t in seq)
            for seq in batch
        ):
            return tuple(
                concat_tensors_to_device(seq, device=device, non_blocking=non_blocking)
                for seq in batch
            )

        sample_type = type(batch[0]) if batch else "empty"
        raise TypeError(
            f"Unsupported sequence structure. Got: {type(batch)} with inner type "
            f"{sample_type}"
        )

    raise TypeError(
        f"Batch must be one of {TensorBatchType} types, got {type(batch)}"
    )
DataBatchType = TypeVar("DataBatchType", bound=DataBatch)


TensorBatchType = Union[
    "torch.Tensor",
    Sequence["torch.Tensor"],
    # For nested sequences of tensors, the inner sequence of tensors is combined during
    # GPU transfer in `move_tensors_to_device`.
    Sequence[Sequence["torch.Tensor"]],
    Mapping[str, "torch.Tensor"],
    # For mapping (e.g., dict) of keys to sequences of tensors, the sequence of tensors
    # is combined during GPU transfer in `move_tensors_to_device`.
    Mapping[str, Sequence["torch.Tensor"]],
]


@DeveloperAPI
def is_tensor_batch_type(batch: Any) -> bool:
    """Check if a batch matches any of the TensorBatchType variants.

    This function checks if the input batch is one of the following types:
    1. A single torch.Tensor
    2. A non-empty sequence of torch.Tensors
    3. A non-empty sequence of non-empty sequences of torch.Tensors
    4. A mapping (e.g., dict) of keys to torch.Tensors
    5. A mapping (e.g., dict) of keys to non-empty sequences of torch.Tensors

    Empty sequences, ``str``/``bytes`` values, and sequences that mix tensors
    with nested sequences are not tensor batches: there is nothing to combine
    during device transfer, so such values are reported as non-tensor data.

    Args:
        batch: The input batch to check. Can be any type.

    Returns:
        bool: True if the batch matches any TensorBatchType variant, False otherwise.
    """
    import torch

    def _is_tensor_sequence(obj: Any) -> bool:
        # A non-empty, non-string sequence whose items are all tensors.
        return (
            isinstance(obj, Sequence)
            and not isinstance(obj, (str, bytes))
            and len(obj) > 0
            and all(isinstance(t, torch.Tensor) for t in obj)
        )

    if isinstance(batch, torch.Tensor):
        return True

    # A sequence of tensors or a sequence of sequences of tensors. The two
    # shapes are checked separately so that a mix of both is rejected.
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        if len(batch) == 0:
            return False
        return _is_tensor_sequence(batch) or all(
            _is_tensor_sequence(inner) for inner in batch
        )

    # A mapping (e.g., dict) of keys to torch.Tensors or a mapping (e.g., dict) of
    # keys to sequences of torch.Tensors
    if isinstance(batch, Mapping):
        return all(
            isinstance(v, torch.Tensor) or _is_tensor_sequence(v)
            for v in batch.values()
        )

    return False


TensorBatchReturnType = Union[
    "torch.Tensor",
    Tuple["torch.Tensor", ...],
    Dict[str, "torch.Tensor"],
]


@DeveloperAPI
class CollateFn(Generic[DataBatchType]):
    """Abstract interface for collate_fn for `iter_torch_batches`. See doc-string of
    `collate_fn` in `iter_torch_batches` API for more details.
    """

    # NOTE(review): this class does not use an ABC metaclass, so the
    # `abstractmethod` marker below is not enforced at instantiation time.
    @abc.abstractmethod
    def __call__(self, batch: DataBatchType) -> "CollatedData":
        """Convert a batch of data to collated format.

        Args:
            batch: The input batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]):
    """Collate function that takes pyarrow.Table as the input batch type.
    Arrow tables with chunked arrays can be efficiently transferred to GPUs without
    combining the chunks with the `arrow_batch_to_tensors` utility function.
    See `DefaultCollateFn` for example.
    """

    def __call__(self, batch: "pyarrow.Table") -> "CollatedData":
        """Convert a batch of pyarrow.Table to collated format.

        Args:
            batch: The input pyarrow.Table batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...
@DeveloperAPI
class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]):
    """Collate function that takes a dictionary of numpy arrays as the input batch type."""

    def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData":
        """Convert a batch of numpy arrays to collated format.

        Args:
            batch: The input dictionary of numpy arrays batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]):
    """Collate function that takes a pandas.DataFrame as the input batch type."""

    def __call__(self, batch: "pandas.DataFrame") -> "CollatedData":
        """Convert a batch of pandas.DataFrame to collated format.

        Args:
            batch: The input pandas.DataFrame batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class DefaultCollateFn(ArrowBatchCollateFn):
    """Default collate function for converting Arrow batches to PyTorch tensors."""

    def __init__(
        self,
        dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ):
        """Initialize the collate function.

        Args:
            dtypes: The torch dtype(s) for the created tensor(s); if None, the dtype
                will be inferred from the tensor data.
            device: The device on which the tensor should be placed. Can be a string
                (e.g. "cpu", "cuda:0") or a torch.device object. Strings are
                normalized to ``torch.device``; ``None`` is stored as is.
        """
        import torch

        super().__init__()
        self.dtypes = dtypes
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

    def __call__(
        self, batch: "pyarrow.Table"
    ) -> Union[Dict[str, "torch.Tensor"], Dict[str, List["torch.Tensor"]]]:
        """Convert an Arrow batch to PyTorch tensors.

        Args:
            batch: PyArrow Table to convert

        Returns:
            Dictionary mapping column names to tensors. When chunks are
            combined (CPU or no device configured) each value is whatever
            `arrow_batch_to_tensors` produces for a combined column; otherwise
            each value is a list of per-chunk tensors.
        """
        from ray.air._internal.torch_utils import (
            arrow_batch_to_tensors,
        )

        # For GPU transfer, we can skip the combining chunked arrays. This is because
        # we can convert the chunked arrays to corresponding numpy format and then to
        # Tensors and transfer the corresponding list of Tensors to GPU directly.
        # However, for CPU transfer, we need to combine the chunked arrays first
        # before converting to numpy format and then to Tensors.
        # `device` defaults to None; with no target device there is no later
        # transfer step to merge the chunks, so combine them here as for CPU
        # instead of failing on `None.type`.
        combine_chunks = self.device is None or self.device.type == "cpu"
        return arrow_batch_to_tensors(
            batch, dtypes=self.dtypes, combine_chunks=combine_chunks
        )
local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, @@ -301,17 +314,28 @@ def iter_torch_batches( Dataset is passed to Ray Train and ``collate_fn`` is not provided. Otherwise, defaults to CPU. You can't use this parameter with ``collate_fn``. - collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. - When this parameter is specified, the user should manually handle the - host to device data transfer outside of ``collate_fn``. - This is useful for further processing the data after it has been - batched. Potential use cases include collating along a dimension other - than the first, padding sequences of various lengths, or generally - handling batches of different length tensors. If not provided, the - default collate function is used which simply converts the batch of - numpy arrays to a batch of PyTorch tensors. This API is still - experimental and is subject to change. You can't use this parameter in - conjunction with ``dtypes`` or ``device``. + collate_fn: [Alpha] A function to customize how data batches are collated + before being passed to the model. This is useful for last-mile data + formatting such as padding, masking, or packaging tensors into custom + data structures. If not provided, `iter_torch_batches` automatically + converts batches to `torch.Tensor`s and moves them to the device + assigned to the current worker. The input to `collate_fn` may be: + + 1. pyarrow.Table, where you should provide a callable class that + subclasses `ArrowBatchCollateFn` (recommended for best performance). + Note that you should use util function `arrow_batch_to_tensors` to + convert the pyarrow.Table to a dictionary of non-contiguous tensor + batches. + 2. Dict[str, np.ndarray], where you should provide a callable class that + subclasses `NumpyBatchCollateFn` + 3. pd.DataFrame, where you should provide a callable class that + subclasses `PandasBatchCollateFn` + + The output can be any type. 
If the output is a `TensorBatchType`, it will be + automatically moved to the current worker's device. For other types, + you must handle device transfer manually in your training loop. + Note: This function is called in a multi-threaded context; avoid using + thread-unsafe code. drop_last: Whether to drop the last batch if it's incomplete. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the @@ -327,9 +351,6 @@ def iter_torch_batches( An iterable over Torch Tensor batches. """ - from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, - ) from ray.train.torch import get_device if collate_fn is not None and (dtypes is not None or device != "auto"): @@ -344,40 +365,74 @@ def iter_torch_batches( # Ray Train is not being used. device = get_device() - if collate_fn is None: - # The default collate_fn handles formatting and Tensor creation. - # Here, we set device=None to defer host to device data transfer - # to the subsequent finalize_fn. - def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): - return convert_ndarray_batch_to_torch_tensor_batch( - batch, - dtypes=dtypes, - device=None, - ) + from ray.air._internal.torch_utils import ( + move_tensors_to_device, + ) - # The default finalize_fn handles the host to device data transfer. - # This is executed in a 1-thread pool separately from collate_fn - # to allow independent parallelism of these steps. - def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]): - if device is not None: - if isinstance(batch, dict): - for k, t in batch.items(): - batch[k] = t.to(device=device) - else: - batch = batch.to(device=device) + # The default finalize_fn handles the host to device data transfer. + # This is executed in a 1-thread pool separately from collate_fn + # to allow independent parallelism of these steps. 
+ def default_finalize_fn( + batch: TensorBatchType, + ) -> Union[TensorBatchReturnType, Any]: + """Default finalize function for moving PyTorch tensors to device. If + batch is of type `TensorBatchType`, it will be automatically moved to the + current worker's device. For other types, you must handle device transfer + manually in your training loop. + + Args: + batch: Input batch to move to device. + + Returns: + Batch with tensors moved to the target device. + - If input is TensorBatchType, returns tensors moved to device + - Otherwise returns the same type as input without moving tensors + to device. + """ + if is_tensor_batch_type(batch): + return move_tensors_to_device(batch, device=device) + else: return batch + if collate_fn is None: + # The default collate_fn handles formatting and Tensor creation. + # Here, we defer host to device data transfer to the subsequent + # finalize_fn. + collate_fn = DefaultCollateFn( + dtypes=dtypes, + device=device, + ) + batch_format = "pyarrow" + elif isinstance(collate_fn, ArrowBatchCollateFn): + # The ArrowBatchCollateFn handles formatting and Tensor creation. + # Here, we defer host to device data transfer to the subsequent + # finalize_fn. + batch_format = "pyarrow" + elif isinstance(collate_fn, NumpyBatchCollateFn): + batch_format = "numpy" + elif isinstance(collate_fn, PandasBatchCollateFn): + batch_format = "pandas" + elif callable(collate_fn): + batch_format = "numpy" + warnings.warn( + "Passing a function to `iter_torch_batches(collate_fn)` is " + "deprecated in Ray 2.47. 
Please switch to using a callable class that " + "inherits from `ArrowBatchCollateFn`, `NumpyBatchCollateFn`, or " + "`PandasBatchCollateFn`.", + RayDeprecationWarning, + ) else: - finalize_fn = None + raise ValueError(f"Unsupported collate function: {type(collate_fn)}") return self.iter_batches( prefetch_batches=prefetch_batches, batch_size=batch_size, + batch_format=batch_format, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, _collate_fn=collate_fn, - _finalize_fn=finalize_fn, + _finalize_fn=default_finalize_fn, ) def iter_tf_batches( diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 6fdeb254de710..a1cf652a9e54f 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -188,11 +188,10 @@ def collate_fn(batch: Dict[str, np.ndarray]): assert isinstance(batch, torch.Tensor) assert batch.tolist() == list(range(5, 10)) - # When collate_fn is specified, check that`_finalize_fn` - # is not used in `DataIterator.iter_batches()`. + # Check that _finalize_fn is always used in `DataIterator.iter_batches()`. 
iter_batches_calls_kwargs = [a.kwargs for a in it.iter_batches.call_args_list] assert all( - kwargs["_finalize_fn"] is None for kwargs in iter_batches_calls_kwargs + kwargs["_finalize_fn"] is not None for kwargs in iter_batches_calls_kwargs ), iter_batches_calls_kwargs diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD index e9842bf7b2110..717a30c0413f0 100644 --- a/python/ray/train/BUILD +++ b/python/ray/train/BUILD @@ -493,6 +493,21 @@ py_test( ], ) +py_test( + name = "test_iter_torch_batches_gpu", + size = "medium", + srcs = ["tests/test_iter_torch_batches_gpu.py"], + tags = [ + "exclusive", + "gpu_only", + "team:ml", + ], + deps = [ + ":conftest", + ":train_lib", + ], +) + py_test( name = "test_gpu_auto_transfer", size = "medium", diff --git a/python/ray/train/tests/test_iter_torch_batches_gpu.py b/python/ray/train/tests/test_iter_torch_batches_gpu.py new file mode 100644 index 0000000000000..5f9e12b255631 --- /dev/null +++ b/python/ray/train/tests/test_iter_torch_batches_gpu.py @@ -0,0 +1,475 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +import torch + +import ray +from ray.air._internal.torch_utils import ( + arrow_batch_to_tensors, + convert_ndarray_batch_to_torch_tensor_batch, +) +from ray.data.iterator import ( + ArrowBatchCollateFn, + NumpyBatchCollateFn, + PandasBatchCollateFn, +) +from ray.train.torch import get_device + + +class BaseArrowBatchCollateFn(ArrowBatchCollateFn): + """Base class for Arrow batch collate functions that process and convert to tensors. + + This class provides common functionality for processing PyArrow tables and converting + them to PyTorch tensors. It handles device placement and dtype conversion. + + Attributes: + device: Optional device to place tensors on. Can be a string (e.g. "cpu", "cuda:0") + or a torch.device object. 
+ """ + + device: Optional[torch.device] + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__() + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + def _process_batch(self, batch: pa.Table) -> pa.Table: + """Process the batch by adding 5 to the id column. + + Args: + batch: Input PyArrow table containing an "id" column. + + Returns: + A new PyArrow table with modified "id" column and original "value" column. + """ + return pa.Table.from_arrays( + [pa.compute.add(batch["id"], 5), batch["id"]], + names=["id", "value"], + ) + + def _get_tensors(self, batch: pa.Table) -> Dict[str, torch.Tensor]: + """Convert batch to tensors. + + Args: + batch: Input PyArrow table to convert to tensors. + + Returns: + Dictionary mapping column names to PyTorch tensors. + """ + return arrow_batch_to_tensors( + batch, + combine_chunks=self.device.type == "cpu", + ) + + +class SingleTensorArrowBatchCollateFn(BaseArrowBatchCollateFn): + """Collate function that returns only the id column as a tensor.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pa.Table) -> torch.Tensor: + """Return only the id column as a tensor.""" + assert isinstance(batch, pa.Table) + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch)["id"] + + +class TupleArrowBatchCollateFn(BaseArrowBatchCollateFn): + """Collate function that returns id and value as a tuple of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pa.Table) -> Tuple[torch.Tensor, torch.Tensor]: + """Return id and value as a tuple of tensors.""" + assert isinstance(batch, pa.Table) + modified_batch = self._process_batch(batch) + return ( + self._get_tensors(modified_batch)["id"], + 
self._get_tensors(modified_batch)["value"], + ) + + +class DictArrowBatchCollateFn(BaseArrowBatchCollateFn): + """Collate function that returns id and value as a dictionary of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pa.Table) -> Dict[str, torch.Tensor]: + """Return id and value as a dictionary of tensors.""" + assert isinstance(batch, pa.Table) + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch) + + +class ListArrowBatchCollateFn(BaseArrowBatchCollateFn): + """Collate function that returns id and value as a list of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pa.Table) -> List[torch.Tensor]: + """Return id and value as a list of tensors.""" + assert isinstance(batch, pa.Table) + modified_batch = self._process_batch(batch) + tensors = self._get_tensors(modified_batch) + return [tensors["id"], tensors["value"]] + + +class BaseNumpyBatchCollateFn(NumpyBatchCollateFn): + """Base class for Numpy batch collate functions that process and convert to tensors. + + This class provides common functionality for processing Numpy arrays and converting + them to PyTorch tensors. It handles device placement and dtype conversion. + + Attributes: + device: Optional device to place tensors on. Can be a string (e.g. "cpu", "cuda:0") + or a torch.device object. + """ + + device: Optional[Union[str, torch.device]] + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__() + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + def _process_batch(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Process the batch by adding 5 to the id array. + + Args: + batch: Input dictionary containing numpy arrays. 
+ + Returns: + A new dictionary with modified "id" array and original "value" array. + """ + return {"id": batch["id"] + 5, "value": batch["id"]} + + def _get_tensors(self, batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: + """Convert batch to tensors. + + Args: + batch: Input dictionary of numpy arrays to convert to tensors. + + Returns: + Dictionary mapping column names to PyTorch tensors. + """ + return convert_ndarray_batch_to_torch_tensor_batch( + batch, dtypes=None, device=None + ) + + +class SingleTensorNumpyBatchCollateFn(BaseNumpyBatchCollateFn): + """Collate function that returns only the id array as a tensor.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: + """Return only the id array as a tensor.""" + assert isinstance(batch, dict) + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch)["id"] + + +class TupleNumpyBatchCollateFn(BaseNumpyBatchCollateFn): + """Collate function that returns id and value as a tuple of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__( + self, batch: Dict[str, np.ndarray] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Return id and value as a tuple of tensors.""" + assert isinstance(batch, dict) + modified_batch = self._process_batch(batch) + tensors = self._get_tensors(modified_batch) + return tensors["id"], tensors["value"] + + +class DictNumpyBatchCollateFn(BaseNumpyBatchCollateFn): + """Collate function that returns id and value as a dictionary of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: + """Return id and value as a dictionary of tensors.""" + assert isinstance(batch, 
dict) + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch) + + +class ListNumpyBatchCollateFn(BaseNumpyBatchCollateFn): + """Collate function that returns id and value as a list of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: Dict[str, np.ndarray]) -> List[torch.Tensor]: + """Return id and value as a list of tensors.""" + assert isinstance(batch, dict) + modified_batch = self._process_batch(batch) + tensors = self._get_tensors(modified_batch) + return [tensors["id"], tensors["value"]] + + +class BasePandasBatchCollateFn(PandasBatchCollateFn): + """Base class for Pandas batch collate functions that process and convert to tensors. + + This class provides common functionality for processing Pandas DataFrames and converting + them to PyTorch tensors. It handles device placement and dtype conversion. + + Attributes: + device: Optional device to place tensors on. Can be a string (e.g. "cpu", "cuda:0") + or a torch.device object. + """ + + device: Optional[str] + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__() + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + def _process_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """Process the batch by adding 5 to the id column. + + Args: + batch: Input Pandas DataFrame. + + Returns: + A new DataFrame with modified "id" column and original "value" column. + """ + return pd.DataFrame({"id": batch["id"] + 5, "value": batch["id"]}) + + def _get_tensors(self, batch: pd.DataFrame) -> Dict[str, torch.Tensor]: + """Convert batch to tensors. + + Args: + batch: Input Pandas DataFrame to convert to tensors. + + Returns: + Dictionary mapping column names to PyTorch tensors. 
+ """ + return convert_ndarray_batch_to_torch_tensor_batch( + batch.to_dict("series"), dtypes=None, device=None + ) + + +class SingleTensorPandasBatchCollateFn(BasePandasBatchCollateFn): + """Collate function that returns only the id column as a tensor.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pd.DataFrame) -> torch.Tensor: + """Return only the id column as a tensor.""" + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch)["id"] + + +class TuplePandasBatchCollateFn(BasePandasBatchCollateFn): + """Collate function that returns id and value as a tuple of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]: + """Return id and value as a tuple of tensors.""" + assert isinstance(batch, pd.DataFrame) + modified_batch = self._process_batch(batch) + tensors = self._get_tensors(modified_batch) + return tensors["id"], tensors["value"] + + +class DictPandasBatchCollateFn(BasePandasBatchCollateFn): + """Collate function that returns id and value as a dictionary of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pd.DataFrame) -> Dict[str, torch.Tensor]: + """Return id and value as a dictionary of tensors.""" + assert isinstance(batch, pd.DataFrame) + modified_batch = self._process_batch(batch) + return self._get_tensors(modified_batch) + + +class ListPandasBatchCollateFn(BasePandasBatchCollateFn): + """Collate function that returns id and value as a list of tensors.""" + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__(device) + + def __call__(self, batch: pd.DataFrame) -> List[torch.Tensor]: + """Return id 
and value as a list of tensors.""" + assert isinstance(batch, pd.DataFrame) + modified_batch = self._process_batch(batch) + tensors = self._get_tensors(modified_batch) + return [tensors["id"], tensors["value"]] + + +@pytest.fixture +def custom_collate_fns(): + """Fixture that provides both Arrow and Numpy custom collate functions.""" + + def _create_collate_fns(device): + return { + "arrow": { + "single": SingleTensorArrowBatchCollateFn(device=device), + "tuple": TupleArrowBatchCollateFn(device=device), + "dict": DictArrowBatchCollateFn(device=device), + "list": ListArrowBatchCollateFn(device=device), + }, + "numpy": { + "single": SingleTensorNumpyBatchCollateFn(device=device), + "tuple": TupleNumpyBatchCollateFn(device=device), + "dict": DictNumpyBatchCollateFn(device=device), + "list": ListNumpyBatchCollateFn(device=device), + }, + "pandas": { + "single": SingleTensorPandasBatchCollateFn(device=device), + "tuple": TuplePandasBatchCollateFn(device=device), + "dict": DictPandasBatchCollateFn(device=device), + "list": ListPandasBatchCollateFn(device=device), + }, + } + + return _create_collate_fns + + +@pytest.mark.parametrize( + "collate_type,return_type,device", + [ + ("arrow", "single", "cpu"), + ("arrow", "single", "cuda"), + ("arrow", "tuple", "cpu"), + ("arrow", "tuple", "cuda"), + ("arrow", "dict", "cpu"), + ("arrow", "dict", "cuda"), + ("arrow", "list", "cpu"), + ("arrow", "list", "cuda"), + ("numpy", "single", "cpu"), + ("numpy", "single", "cuda"), + ("numpy", "tuple", "cpu"), + ("numpy", "tuple", "cuda"), + ("numpy", "dict", "cpu"), + ("numpy", "dict", "cuda"), + ("numpy", "list", "cpu"), + ("numpy", "list", "cuda"), + ("pandas", "single", "cpu"), + ("pandas", "single", "cuda"), + ("pandas", "tuple", "cpu"), + ("pandas", "tuple", "cuda"), + ("pandas", "dict", "cpu"), + ("pandas", "dict", "cuda"), + ("pandas", "list", "cpu"), + ("pandas", "list", "cuda"), + ], +) +def test_custom_batch_collate_fn( + ray_start_4_cpus_2_gpus, custom_collate_fns, collate_type, 
return_type, device +): + """Tests that custom batch collate functions can be used to modify + the batch before it is converted to a PyTorch tensor.""" + # Skip GPU tests if CUDA is not available + if device != "cpu" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + # Get the actual device to use + if device == "cuda": + device = str(get_device()) + + ds = ray.data.range(5) + it = ds.iterator() + + collate_fns = custom_collate_fns(device) + collate_fn = ( + collate_fns[collate_type][return_type] + if return_type + else collate_fns[collate_type] + ) + + for batch in it.iter_torch_batches(collate_fn=collate_fn): + if return_type == "single": + assert isinstance(batch, torch.Tensor) + assert sorted(batch.tolist()) == list(range(5, 10)) + assert str(batch.device) == device + elif return_type == "dict": + assert isinstance(batch, dict) + assert sorted(batch["id"].tolist()) == list(range(5, 10)) + assert sorted(batch["value"].tolist()) == list(range(5)) + assert str(batch["id"].device) == device + assert str(batch["value"].device) == device + else: # tuple or list + assert isinstance(batch, torch.Tensor) + # For tuple/list return types, tensors are concatenated + # First 5 values: modified id values [5,6,7,8,9] + # Last 5 values: original values [0,1,2,3,4] + assert sorted(batch.tolist()) == list(range(10)) + assert str(batch.device) == device + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index 1b5263360f743..70a1e1884113a 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -1,13 +1,15 @@ # Standard library imports import logging import time -from typing import Any, Dict, Tuple, Iterator, Generator, Optional +from typing import Dict, Tuple, Iterator, Generator, Optional, 
Union, Type # Third-party imports import torch +import pyarrow import ray import ray.data import ray.train +from ray.data.iterator import ArrowBatchCollateFn # Local imports from config import BenchmarkConfig @@ -188,36 +190,50 @@ def create_batch_iterator( raise -class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): - """Factory for creating Ray DataLoader for image classification tasks. +class CustomArrowCollateFn(ArrowBatchCollateFn): + """Custom collate function for converting Arrow batches to PyTorch tensors.""" - Features: - - Distributed file reading with round-robin worker distribution - - Device transfer and error handling for data batches - - Configurable row limits per worker for controlled processing - - Performance monitoring and logging - """ + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. - def __init__(self, benchmark_config: BenchmarkConfig): - super().__init__(benchmark_config) + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + self.dtypes = dtypes + self.device = device - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert Ray data batch to PyTorch tensors on the appropriate device. + def __call__(self, batch: "pyarrow.Table") -> Tuple[torch.Tensor, torch.Tensor]: + """Convert an Arrow batch to PyTorch tensors. 
Args: - batch: Dictionary with 'image' and 'label' numpy arrays + batch: PyArrow Table to convert Returns: - Tuple of (image_tensor, label_tensor) on the target device + Tuple of (image_tensor, label_tensor) """ from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, + arrow_batch_to_tensors, ) - device = ray.train.torch.get_device() - batch = convert_ndarray_batch_to_torch_tensor_batch(batch, device=device) + tensors = arrow_batch_to_tensors( + batch, dtypes=self.dtypes, combine_chunks=self.device == "cpu" + ) + return tensors["image"], tensors["label"] + + +class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): + """Factory for creating Ray DataLoader for image classification tasks.""" + + def __init__(self, benchmark_config: BenchmarkConfig): + super().__init__(benchmark_config) - return batch["image"], batch["label"] + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + return CustomArrowCollateFn class ImageClassificationMockDataLoaderFactory(BaseDataLoaderFactory): diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index 96b914b5e5a38..4b3e756e8f7db 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -1,9 +1,8 @@ from abc import abstractmethod -from typing import Any, Dict, Tuple +from typing import Any, Dict, Type -import torch import ray.train -from ray.data import Dataset +from ray.data.iterator import ArrowBatchCollateFn from constants import DatasetKey from config import BenchmarkConfig, RayDataConfig @@ -28,24 +27,16 @@ def __init__(self, benchmark_config: BenchmarkConfig) -> None: data_context.retried_io_errors.append("AWS Error ACCESS_DENIED") @abstractmethod - def get_ray_datasets(self) -> Dict[str, Dataset]: - """Get the Ray datasets for training and validation. 
- - Returns: - Dict with "train" and "val" Dataset objects - """ + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + """Return the collate function class. Must be implemented by subclass.""" pass - @abstractmethod - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Get the collate function for the dataloader. + def get_train_dataloader(self): + """Get the training dataloader. Returns: - A function that takes a batch and returns a tuple of tensors. + Iterator of training batches """ - pass - - def get_train_dataloader(self): ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN) self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator @@ -58,13 +49,20 @@ def get_train_dataloader(self): if dataloader_config.local_buffer_shuffle_size > 0 else None ), - collate_fn=self.collate_fn, + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, drop_last=True, ) ) def get_val_dataloader(self): + """Get the validation dataloader. + + Returns: + Iterator of validation batches + """ ds_iterator = ray.train.get_dataset_shard(DatasetKey.VALID) self._ray_ds_iterators[DatasetKey.VALID] = ds_iterator @@ -72,7 +70,9 @@ def get_val_dataloader(self): return iter( ds_iterator.iter_torch_batches( batch_size=dataloader_config.validation_batch_size, - collate_fn=self.collate_fn, + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, drop_last=True, )