From 35edb58a4de02d9237846e56e40bb8b93965ed88 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 01:12:49 +0000 Subject: [PATCH 01/30] WIP: Handle non-contiguous Tensors GPU transfer Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 102 ++++++ .../_internal/arrow_ops/transform_pyarrow.py | 57 ++++ python/ray/data/iterator.py | 322 ++++++++++++++++-- 3 files changed, 449 insertions(+), 32 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index caeb27a20a30..95f08e50859b 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import torch +import pyarrow from ray.air._internal.device_manager import get_torch_device_manager_by_context from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed @@ -292,3 +293,104 @@ def consume_prefix_in_state_dict_if_present_not_in_place( metadata[newkey] = metadata.pop(key) return state_dict + + +def convert_ndarray_list_to_torch_tensor_list( + ndarrays: Union[List[np.ndarray], Dict[str, List[np.ndarray]]], + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, +) -> Union[List[torch.Tensor], Dict[str, List[torch.Tensor]]]: + """Convert a list of NumPy ndarrays or dict of lists of ndarrays to Torch Tensors. + + Args: + ndarrays: A list of NumPy ndarrays or a dict mapping column names to lists of + ndarrays that we wish to convert to Torch Tensors. + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: The device on which the tensor(s) should be placed; if None, the Torch + tensor(s) will be constructed on the CPU. + + Returns: A list of Torch Tensors or a dict mapping column names to lists of Tensors. + """ + if isinstance(ndarrays, list): + # Single column case - list of ndarrays + if isinstance(dtypes, dict): + if len(dtypes) != 1: + raise ValueError( + "When constructing a single-column batch, only a single dtype " + f"should be given, instead got: {dtypes}" + ) + dtypes = next(iter(dtypes.values())) + return [ + convert_ndarray_batch_to_torch_tensor_batch( + ndarray, dtypes=dtypes, device=device + ) + for ndarray in ndarrays + ] + else: + # Multi-column case - dict of lists of ndarrays + return { + col_name: [ + convert_ndarray_batch_to_torch_tensor_batch( + ndarray, + dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes, + device=device, + ) + for ndarray in col_ndarrays + ] + for col_name, col_ndarrays in ndarrays.items() + } + + +def arrow_table_to_gpu_tensors( + batch: pyarrow.Table, + combine_chunks: bool = True, + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, +) -> Union[ + "torch.Tensor", + List["torch.Tensor"], + Dict[str, "torch.Tensor"], + Dict[str, List["torch.Tensor"]], +]: + """Convert PyArrow table to PyTorch tensors. + + Args: + batch: PyArrow table to convert + combine_chunks: Whether to combine chunks or keep separate + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: Optional device to place tensors on + + Returns: + PyTorch tensors converted from the Arrow table, can be: + - A single tensor + - A list of tensors + - A dict of column name to tensor + - A dict of column name to list of tensors + """ + + from ray.data._internal.arrow_ops import transform_pyarrow + + if combine_chunks: + numpy_batch = transform_pyarrow.table_to_numpy_dict_combined( + batch, + zero_copy_only=False, + ) + result = convert_ndarray_batch_to_torch_tensor_batch( + numpy_batch, + dtypes=dtypes, + device=device, + ) + else: + numpy_list = transform_pyarrow.table_to_numpy_dict_chunked( + batch, + zero_copy_only=False, + ) + result = convert_ndarray_list_to_torch_tensor_list( + numpy_list, + dtypes=dtypes, + device=device, + ) + + return result diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 1e6a8ae7a91d..1329c1bcc441 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -583,6 +583,63 @@ def concat_and_sort( return take_table(ret, indices) +def table_to_numpy_dict_combined( + table: "pyarrow.Table", + *, + zero_copy_only: bool = False, +) -> Dict[str, np.ndarray]: + """Convert a PyArrow table to a dictionary of numpy arrays. + + Args: + table: The PyArrow table to convert. + zero_copy_only: Whether to only use zero-copy transfers. + + Returns: + A dictionary of numpy arrays. + """ + + numpy_batch = {} + for col_name in table.column_names: + col = table[col_name] + if isinstance(col, pyarrow.ChunkedArray): + combined_array = combine_chunked_array(col) + numpy_batch[col_name] = combined_array.to_numpy( + zero_copy_only=zero_copy_only + ) + else: + numpy_batch[col_name] = col.to_numpy(zero_copy_only=zero_copy_only) + return numpy_batch + + +def table_to_numpy_dict_chunked( + table: "pyarrow.Table", + *, + zero_copy_only: bool = False, +) -> Dict[str, List[np.ndarray]]: + """Convert a PyArrow table to a dictionary of lists of numpy arrays. + + Args: + table: The PyArrow table to convert. + zero_copy_only: Whether to only use zero-copy transfers. + + Returns: + A dictionary mapping column names to either: + - A list of numpy arrays (for chunked columns) + - A single numpy array (for non-chunked columns) + """ + + numpy_batch = {} + for col_name in table.column_names: + col = table[col_name] + if isinstance(col, pyarrow.ChunkedArray): + numpy_batch[col_name] = [ + chunk.to_numpy(zero_copy_only=zero_copy_only) for chunk in col.chunks + ] + else: + numpy_batch[col_name] = col.to_numpy(zero_copy_only=zero_copy_only) + return numpy_batch + + def to_numpy( array: Union["pyarrow.Array", "pyarrow.ChunkedArray"], *, diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 7bf8aab67e3a..ddd77d2cbed7 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -12,9 +12,12 @@ Tuple, TypeVar, Union, + Generic, ) import numpy as np +import pyarrow +import torch from ray.data._internal.block_batching.iter_batches import iter_batches from ray.data._internal.execution.interfaces import RefBundle @@ -56,6 +59,270 @@ def __iter__(self): return self.iterator_gen() +class CollateFn(Generic[T]): + """A function that converts a DataBatch to a CollatedData.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + self.dtypes = dtypes + self.device = device + + @abc.abstractmethod + def __call__(self, batch: T) -> "CollatedData": + """Convert a batch of data to collated format. + + Args: + batch: The input batch to collate + + Returns: + The collated data in the format expected by the model + """ + ... + + +class ArrowBatchCollateFn(CollateFn[pyarrow.Table]): + """Collate function for converting Arrow tables to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def __call__(self, batch: pyarrow.Table) -> "CollatedData": + """Convert a PyArrow table to PyTorch tensors. + + Args: + batch: PyArrow table to convert + + Returns: + Collated PyTorch tensors + """ + ... + + +class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]): + """Collate function for converting Numpy batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def _numpy_batch_to_torch_tensors( + self, + batch: Dict[str, np.ndarray], + device: Optional[str] = None, + ) -> Union["torch.Tensor", Dict[str, "torch.Tensor"]]: + """Convert a dictionary of numpy arrays to PyTorch tensors. + + Args: + batch: Dictionary mapping column names to numpy arrays + + Returns: + Either a single PyTorch tensor or a dict mapping column names to tensors + """ + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + batch, + dtypes=self.dtypes, + device=device, + ) + + def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": + """Convert a Numpy batch to PyTorch tensors. + + Args: + batch: Numpy batch to convert + + Returns: + Collated PyTorch tensors + """ + ... + + +class DefaultCollateFn(ArrowBatchCollateFn): + """Default collate function for converting Arrow batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def __call__( + self, batch: pyarrow.Table + ) -> Union[ + "torch.Tensor", + List["torch.Tensor"], + Dict[str, "torch.Tensor"], + Dict[str, List["torch.Tensor"]], + ]: + """Convert an Arrow batch to PyTorch tensors. + + Args: + batch: PyArrow Table to convert + + Returns: + Collated PyTorch tensors, can be: + - A single tensor + - A list of tensors + - A dict of column name to tensor + - A dict of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + arrow_table_to_gpu_tensors, + ) + + combine_chunks = self.device is None or self.device == "cpu" + return arrow_table_to_gpu_tensors( + batch, combine_chunks=combine_chunks, dtypes=self.dtypes, device=self.device + ) + + +class DefaultFinalizeFn: + """Default finalize function for moving PyTorch tensors to device.""" + + def __init__( + self, + device: Optional[str] = None, + ): + """Initialize the finalize function. + + Args: + device: Optional device to place tensors on + """ + self.device = device + + @staticmethod + def _concat_tensors_to_device( + tensor_list: List["torch.Tensor"], + device: str, + ) -> "torch.Tensor": + """Stack list of tensors into a contiguous GPU tensor. + + Args: + tensor_list: List of tensors to stack + + Returns: + A contiguous tensor on the target device + """ + # Assumes tensors have the same shape/dtype + assert tensor_list, "Cannot stack empty list of tensors" + assert all( + isinstance(t, torch.Tensor) for t in tensor_list + ), "All items must be torch.Tensor" + assert all( + t.dtype == tensor_list[0].dtype for t in tensor_list + ), "All tensors must have the same dtype" + assert all( + t.shape[1:] == tensor_list[0].shape[1:] for t in tensor_list + ), "All tensors must have the same shape[1:]" + + first = tensor_list[0] + dtype = first.dtype + shape_tail = first.shape[1:] + total_rows = sum(t.shape[0] for t in tensor_list) + + # Allocate an empty Tensor on device + result = torch.empty((total_rows, *shape_tail), dtype=dtype, device=device) + + row_start = 0 + for t in tensor_list: + row_end = row_start + t.shape[0] + if t.is_pinned(): + # Perform non-blocking transfer if the tensor is pinned + result[row_start:row_end].copy_(t, non_blocking=True) + else: + # Perform blocking transfer if the tensor is not pinned + result[row_start:row_end].copy_(t) + row_start = row_end + + return result + + @torch.no_grad() + def __call__( + self, + batch: Union[ + "torch.Tensor", + List["torch.Tensor"], + Dict[str, "torch.Tensor"], + Dict[str, List["torch.Tensor"]], + ], + ) -> Union[ + "torch.Tensor", + List["torch.Tensor"], + Dict[str, "torch.Tensor"], + Dict[str, List["torch.Tensor"]], + ]: + """Move tensors to device. + + Args: + batch: Tensor or collection of tensors to move to device + + Returns: + Tensor or collection of tensors moved to the target device + """ + if self.device is None: + return batch + + if isinstance(batch, dict): + for k, v in batch.items(): + if isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v): + batch[k] = self._concat_tensors_to_device(v, device=self.device) + elif isinstance(v, torch.Tensor): + if v.is_pinned(): + batch[k] = v.to(device=self.device, non_blocking=True) + else: + batch[k] = v.to(device=self.device) + elif isinstance(batch, list) and all( + isinstance(t, torch.Tensor) for t in batch + ): + batch = self._concat_tensors_to_device(batch, device=self.device) + else: + assert isinstance(batch, torch.Tensor), "Batch must be a Tensor" + if batch.is_pinned(): + batch = batch.to(device=self.device, non_blocking=True) + else: + batch = batch.to(device=self.device) + + return batch + + @PublicAPI class DataIterator(abc.ABC): """An iterator for reading records from a :class:`~Dataset`. @@ -242,7 +509,9 @@ def iter_torch_batches( batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: str = "auto", - collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None, + collate_fn: Optional[ + Union[Callable[[Dict[str, np.ndarray]], "CollatedData"], CollateFn] + ] = None, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, @@ -301,17 +570,17 @@ def iter_torch_batches( Dataset is passed to Ray Train and ``collate_fn`` is not provided. Otherwise, defaults to CPU. You can't use this parameter with ``collate_fn``. - collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. + collate_fn: A function to convert a PyArrow Table or Numpy batch to PyTorch tensors. When this parameter is specified, the user should manually handle the host to device data transfer outside of ``collate_fn``. This is useful for further processing the data after it has been batched. Potential use cases include collating along a dimension other than the first, padding sequences of various lengths, or generally handling batches of different length tensors. If not provided, the - default collate function is used which simply converts the batch of - numpy arrays to a batch of PyTorch tensors. This API is still - experimental and is subject to change. You can't use this parameter in - conjunction with ``dtypes`` or ``device``. + default collate function is used which simply converts the batch to + PyTorch tensors. This API is still experimental and is subject to + change. You can't use this parameter in conjunction with ``dtypes`` + or ``device``. drop_last: Whether to drop the last batch if it's incomplete. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the @@ -327,9 +596,6 @@ def iter_torch_batches( An iterable over Torch Tensor batches. """ - from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, - ) from ray.train.torch import get_device if collate_fn is not None and (dtypes is not None or device != "auto"): @@ -344,35 +610,27 @@ def iter_torch_batches( # Ray Train is not being used. device = get_device() + finalize_fn = None if collate_fn is None: - # The default collate_fn handles formatting and Tensor creation. - # Here, we set device=None to defer host to device data transfer - # to the subsequent finalize_fn. - def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): - return convert_ndarray_batch_to_torch_tensor_batch( - batch, - dtypes=dtypes, - device=None, - ) - - # The default finalize_fn handles the host to device data transfer. - # This is executed in a 1-thread pool separately from collate_fn - # to allow independent parallelism of these steps. - def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]): - if device is not None: - if isinstance(batch, dict): - for k, t in batch.items(): - batch[k] = t.to(device=device) - else: - batch = batch.to(device=device) - return batch - + collate_fn = DefaultCollateFn( + dtypes=dtypes, + device=device, + ) + finalize_fn = DefaultFinalizeFn(device=device) + batch_format = "pyarrow" + elif isinstance(collate_fn, ArrowBatchCollateFn): + batch_format = "pyarrow" + elif isinstance(collate_fn, NumpyBatchCollateFn): + batch_format = "numpy" + elif callable(collate_fn): + batch_format = "numpy" else: - finalize_fn = None + raise ValueError(f"Unsupported collate function: {type(collate_fn)}") return self.iter_batches( prefetch_batches=prefetch_batches, batch_size=batch_size, + batch_format=batch_format, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, From f5d83ea3f528b46c615e47813e077cfd749f0a25 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 04:43:49 +0000 Subject: [PATCH 02/30] Lint Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index ddd77d2cbed7..4580c752601a 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -17,7 +17,6 @@ import numpy as np import pyarrow -import torch from ray.data._internal.block_batching.iter_batches import iter_batches from ray.data._internal.execution.interfaces import RefBundle @@ -217,6 +216,8 @@ def __call__( class DefaultFinalizeFn: """Default finalize function for moving PyTorch tensors to device.""" + import torch + def __init__( self, device: Optional[str] = None, @@ -241,6 +242,8 @@ def _concat_tensors_to_device( Returns: A contiguous tensor on the target device """ + import torch + # Assumes tensors have the same shape/dtype assert tensor_list, "Cannot stack empty list of tensors" assert all( @@ -297,6 +300,8 @@ def __call__( Returns: Tensor or collection of tensors moved to the target device """ + import torch + if self.device is None: return batch From 309b72f350d8084e91ceffb34299aeec77f8c10b Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 05:29:32 +0000 Subject: [PATCH 03/30] Lint Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 4580c752601a..fc141a82549d 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -216,8 +216,6 @@ def __call__( class DefaultFinalizeFn: """Default finalize function for moving PyTorch tensors to device.""" - import torch - def __init__( self, device: Optional[str] = None, @@ -277,7 +275,6 @@ def _concat_tensors_to_device( return result - @torch.no_grad() def __call__( self, batch: Union[ From 1c465ddefd48f8e3c283f695757f167c0696c506 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 16:43:15 +0000 Subject: [PATCH 04/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 134 +++++++++++++++++++++++- python/ray/data/iterator.py | 114 +++----------------- 2 files changed, 143 insertions(+), 105 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 95f08e50859b..4593ddee15e7 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -342,9 +342,8 @@ def convert_ndarray_list_to_torch_tensor_list( } -def arrow_table_to_gpu_tensors( +def arrow_table_to_tensors( batch: pyarrow.Table, - combine_chunks: bool = True, dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, ) -> Union[ @@ -357,7 +356,6 @@ def arrow_table_to_gpu_tensors( Args: batch: PyArrow table to convert - combine_chunks: Whether to combine chunks or keep separate dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype will be inferred from the NumPy ndarray data. device: Optional device to place tensors on @@ -372,6 +370,15 @@ def arrow_table_to_gpu_tensors( from ray.data._internal.arrow_ops import transform_pyarrow + combine_chunks: bool = ( + dtypes is None + or not any( + dtype is object for dtype in dtypes.values() if isinstance(dtypes, dict) + ) + or device is None + or device == "cpu" + ) + if combine_chunks: numpy_batch = transform_pyarrow.table_to_numpy_dict_combined( batch, @@ -394,3 +401,124 @@ def arrow_table_to_gpu_tensors( ) return result + + +def numpy_batch_to_torch_tensors( + batch: Dict[str, np.ndarray], + dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, + device: Optional[str] = None, +) -> Union["torch.Tensor", Dict[str, "torch.Tensor"]]: + """Convert a dictionary of numpy arrays to PyTorch tensors. + + Args: + batch: Dictionary mapping column names to numpy arrays + + Returns: + Either a single PyTorch tensor or a dict mapping column names to tensors + """ + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + batch, + dtypes=dtypes, + device=device, + ) + + +def concat_tensors_to_device( + tensor_list: List["torch.Tensor"], + device: str, +) -> "torch.Tensor": + """Stack list of tensors into a contiguous GPU tensor. + + Args: + tensor_list: List of tensors to stack + + Returns: + A contiguous tensor on the target device + """ + # Assumes tensors have the same shape/dtype + assert tensor_list, "Cannot stack empty list of tensors" + assert all( + isinstance(t, torch.Tensor) for t in tensor_list + ), "All items must be torch.Tensor" + assert all( + t.dtype == tensor_list[0].dtype for t in tensor_list + ), "All tensors must have the same dtype" + assert all( + t.shape[1:] == tensor_list[0].shape[1:] for t in tensor_list + ), "All tensors must have the same shape[1:]" + + first = tensor_list[0] + dtype = first.dtype + shape_tail = first.shape[1:] + total_rows = sum(t.shape[0] for t in tensor_list) + + # Allocate an empty Tensor on device + result = torch.empty((total_rows, *shape_tail), dtype=dtype, device=device) + + row_start = 0 + for t in tensor_list: + row_end = row_start + t.shape[0] + if t.is_pinned(): + # Perform non-blocking transfer if the tensor is pinned + result[row_start:row_end].copy_(t, non_blocking=True) + else: + # Perform blocking transfer if the tensor is not pinned + result[row_start:row_end].copy_(t) + row_start = row_end + + return result + + +def move_tensors_to_device( + batch: Union[ + torch.Tensor, + List[torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, List[torch.Tensor]], + ], + device: Optional[str] = None, +) -> Union[ + torch.Tensor, + List[torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, List[torch.Tensor]], +]: + """Move tensors to the specified device. + + Args: + batch: A tensor or collection of tensors to move to device. Can be: + - A single tensor + - A list of tensors + - A dict mapping keys to tensors + - A dict mapping keys to lists of tensors + device: The device to move tensors to. If None, tensors are not moved. + + Returns: + The input tensors moved to the specified device, maintaining the same structure. + """ + if device is None: + return batch + + if isinstance(batch, dict): + for k, v in batch.items(): + if isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v): + batch[k] = concat_tensors_to_device(v, device=device) + elif isinstance(v, torch.Tensor): + if v.is_pinned(): + batch[k] = v.to(device=device, non_blocking=True) + else: + batch[k] = v.to(device=device) + elif isinstance(batch, list) and all(isinstance(t, torch.Tensor) for t in batch): + batch = concat_tensors_to_device(batch, device=device) + else: + assert isinstance(batch, torch.Tensor), "Batch must be a Tensor" + if batch.is_pinned(): + batch = batch.to(device=device, non_blocking=True) + else: + batch = batch.to(device=device) + + return batch diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index fc141a82549d..46da6d2fbfa3 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -26,7 +26,7 @@ from ray.data._internal.stats import DatasetStats, StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.context import DataContext -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI if TYPE_CHECKING: import tensorflow as tf @@ -58,6 +58,7 @@ def __iter__(self): return self.iterator_gen() +@DeveloperAPI class CollateFn(Generic[T]): """A function that converts a DataBatch to a CollatedData.""" @@ -88,6 +89,7 @@ def __call__(self, batch: T) -> "CollatedData": ... +@DeveloperAPI class ArrowBatchCollateFn(CollateFn[pyarrow.Table]): """Collate function for converting Arrow tables to PyTorch tensors.""" @@ -116,6 +118,7 @@ def __call__(self, batch: pyarrow.Table) -> "CollatedData": ... +@DeveloperAPI class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]): """Collate function for converting Numpy batches to PyTorch tensors.""" @@ -132,29 +135,6 @@ def __init__( """ super().__init__(dtypes=dtypes, device=device) - def _numpy_batch_to_torch_tensors( - self, - batch: Dict[str, np.ndarray], - device: Optional[str] = None, - ) -> Union["torch.Tensor", Dict[str, "torch.Tensor"]]: - """Convert a dictionary of numpy arrays to PyTorch tensors. - - Args: - batch: Dictionary mapping column names to numpy arrays - - Returns: - Either a single PyTorch tensor or a dict mapping column names to tensors - """ - from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, - ) - - return convert_ndarray_batch_to_torch_tensor_batch( - batch, - dtypes=self.dtypes, - device=device, - ) - def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": """Convert a Numpy batch to PyTorch tensors. @@ -167,6 +147,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": ... +@DeveloperAPI class DefaultCollateFn(ArrowBatchCollateFn): """Default collate function for converting Arrow batches to PyTorch tensors.""" @@ -204,15 +185,13 @@ def __call__( - A dict of column name to list of tensors """ from ray.air._internal.torch_utils import ( - arrow_table_to_gpu_tensors, + arrow_table_to_tensors, ) - combine_chunks = self.device is None or self.device == "cpu" - return arrow_table_to_gpu_tensors( - batch, combine_chunks=combine_chunks, dtypes=self.dtypes, device=self.device - ) + return arrow_table_to_tensors(batch, dtypes=self.dtypes, device=self.device) +@DeveloperAPI class DefaultFinalizeFn: """Default finalize function for moving PyTorch tensors to device.""" @@ -227,54 +206,6 @@ def __init__( """ self.device = device - @staticmethod - def _concat_tensors_to_device( - tensor_list: List["torch.Tensor"], - device: str, - ) -> "torch.Tensor": - """Stack list of tensors into a contiguous GPU tensor. - - Args: - tensor_list: List of tensors to stack - - Returns: - A contiguous tensor on the target device - """ - import torch - - # Assumes tensors have the same shape/dtype - assert tensor_list, "Cannot stack empty list of tensors" - assert all( - isinstance(t, torch.Tensor) for t in tensor_list - ), "All items must be torch.Tensor" - assert all( - t.dtype == tensor_list[0].dtype for t in tensor_list - ), "All tensors must have the same dtype" - assert all( - t.shape[1:] == tensor_list[0].shape[1:] for t in tensor_list - ), "All tensors must have the same shape[1:]" - - first = tensor_list[0] - dtype = first.dtype - shape_tail = first.shape[1:] - total_rows = sum(t.shape[0] for t in tensor_list) - - # Allocate an empty Tensor on device - result = torch.empty((total_rows, *shape_tail), dtype=dtype, device=device) - - row_start = 0 - for t in tensor_list: - row_end = row_start + t.shape[0] - if t.is_pinned(): - # Perform non-blocking transfer if the tensor is pinned - result[row_start:row_end].copy_(t, non_blocking=True) - else: - # Perform blocking transfer if the tensor is not pinned - result[row_start:row_end].copy_(t) - row_start = row_end - - return result - def __call__( self, batch: Union[ @@ -297,32 +228,11 @@ def __call__( Returns: Tensor or collection of tensors moved to the target device """ - import torch - - if self.device is None: - return batch - - if isinstance(batch, dict): - for k, v in batch.items(): - if isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v): - batch[k] = self._concat_tensors_to_device(v, device=self.device) - elif isinstance(v, torch.Tensor): - if v.is_pinned(): - batch[k] = v.to(device=self.device, non_blocking=True) - else: - batch[k] = v.to(device=self.device) - elif isinstance(batch, list) and all( - isinstance(t, torch.Tensor) for t in batch - ): - batch = self._concat_tensors_to_device(batch, device=self.device) - else: - assert isinstance(batch, torch.Tensor), "Batch must be a Tensor" - if batch.is_pinned(): - batch = batch.to(device=self.device, non_blocking=True) - else: - batch = batch.to(device=self.device) + from ray.air._internal.torch_utils import ( + move_tensors_to_device, + ) - return batch + return move_tensors_to_device(batch, device=self.device) @PublicAPI From d90b8da57c69a09d896911e0893bcd6d4bdb86ab Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 17:39:48 +0000 Subject: [PATCH 05/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 40 ++++++++++--------- python/ray/data/tests/test_iterator.py | 51 +++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 18 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 4593ddee15e7..853a8b09460e 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -347,10 +347,10 @@ def arrow_table_to_tensors( dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, ) -> Union[ - "torch.Tensor", - List["torch.Tensor"], - Dict[str, "torch.Tensor"], - Dict[str, List["torch.Tensor"]], + torch.Tensor, + List[torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, List[torch.Tensor]], ]: """Convert PyArrow table to PyTorch tensors. @@ -367,16 +367,18 @@ def arrow_table_to_tensors( - A dict of column name to tensor - A dict of column name to list of tensors """ - from ray.data._internal.arrow_ops import transform_pyarrow + # Handle both single dtype and dict of dtypes + has_object_dtype = False + if dtypes is not None: + if isinstance(dtypes, dict): + has_object_dtype = any(dtype is object for dtype in dtypes.values()) + else: + has_object_dtype = dtypes is object + combine_chunks: bool = ( - dtypes is None - or not any( - dtype is object for dtype in dtypes.values() if isinstance(dtypes, dict) - ) - or device is None - or device == "cpu" + dtypes is None or not has_object_dtype or device is None or device == "cpu" ) if combine_chunks: @@ -384,7 +386,7 @@ def arrow_table_to_tensors( batch, zero_copy_only=False, ) - result = convert_ndarray_batch_to_torch_tensor_batch( + return convert_ndarray_batch_to_torch_tensor_batch( numpy_batch, dtypes=dtypes, device=device, @@ -394,24 +396,25 @@ def arrow_table_to_tensors( batch, zero_copy_only=False, ) - result = convert_ndarray_list_to_torch_tensor_list( + return convert_ndarray_list_to_torch_tensor_list( numpy_list, dtypes=dtypes, device=device, ) - return result - def numpy_batch_to_torch_tensors( batch: Dict[str, np.ndarray], dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, -) -> Union["torch.Tensor", Dict[str, "torch.Tensor"]]: +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """Convert a dictionary of numpy arrays to PyTorch tensors. Args: batch: Dictionary mapping column names to numpy arrays + dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype + will be inferred from the NumPy ndarray data. + device: Optional device to place tensors on Returns: Either a single PyTorch tensor or a dict mapping column names to tensors @@ -428,13 +431,14 @@ def numpy_batch_to_torch_tensors( def concat_tensors_to_device( - tensor_list: List["torch.Tensor"], + tensor_list: List[torch.Tensor], device: str, -) -> "torch.Tensor": +) -> torch.Tensor: """Stack list of tensors into a contiguous GPU tensor. Args: tensor_list: List of tensors to stack + device: The device to move tensors to Returns: A contiguous tensor on the target device diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 25af9a878f12..fbf14a6de79a 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch +import pyarrow import ray @@ -195,6 +196,56 @@ def collate_fn(batch: Dict[str, np.ndarray]): ), iter_batches_calls_kwargs +@pytest.fixture +def custom_collate_fns(): + """Fixture that provides both Arrow and Numpy custom collate functions.""" + from ray.data.iterator import ArrowBatchCollateFn, NumpyBatchCollateFn + + class CustomArrowBatchCollateFn(ArrowBatchCollateFn): + def __call__(self, batch: pyarrow.Table) -> torch.Tensor: + """Add 5 to the "id" column at the Arrow level.""" + modified_batch = pyarrow.Table.from_arrays( + [pyarrow.compute.add(batch["id"], 5)], names=["id"] + ) + from ray.air._internal.torch_utils import ( + arrow_table_to_tensors, + ) + + return arrow_table_to_tensors( + modified_batch, dtypes=self.dtypes, device=self.device + )["id"] + + class CustomNumpyBatchCollateFn(NumpyBatchCollateFn): + def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: + """Add 5 to the "id" array.""" + modified_batch = {"id": batch["id"] + 5} + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + modified_batch, dtypes=self.dtypes, device=self.device + )["id"] + + return { + "arrow": CustomArrowBatchCollateFn(), + "numpy": CustomNumpyBatchCollateFn(), + } + + +@pytest.mark.parametrize("collate_type", ["arrow", "numpy"]) +def test_custom_batch_collate_fn( + ray_start_regular_shared, custom_collate_fns, collate_type +): + """Tests that custom batch collate functions can be used to modify + the batch before it is converted to a PyTorch tensor.""" + ds = ray.data.range(5) + it = ds.iterator() + for batch in it.iter_torch_batches(collate_fn=custom_collate_fns[collate_type]): + assert isinstance(batch, torch.Tensor) + assert batch.tolist() == list(range(5, 10)) + + def test_iterator_to_materialized_dataset(ray_start_regular_shared): """Tests that `DataIterator.materialize` fully consumes the iterator and returns a `MaterializedDataset` view of the data From 700e7fec5e1eec9cf0903bf4ce9dc64655b83e0f Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 18:06:02 +0000 Subject: [PATCH 06/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 8 ++++---- python/ray/data/tests/test_iterator.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 46da6d2fbfa3..d679865bd716 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -16,7 +16,6 @@ ) import numpy as np -import pyarrow from ray.data._internal.block_batching.iter_batches import iter_batches from ray.data._internal.execution.interfaces import RefBundle @@ -31,6 +30,7 @@ if TYPE_CHECKING: import tensorflow as tf import torch + import pyarrow from ray.data.dataset import ( CollatedData, @@ -90,7 +90,7 @@ def __call__(self, batch: T) -> "CollatedData": @DeveloperAPI -class ArrowBatchCollateFn(CollateFn[pyarrow.Table]): +class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]): """Collate function for converting Arrow tables to PyTorch tensors.""" def __init__( @@ -106,7 +106,7 @@ def __init__( """ super().__init__(dtypes=dtypes, device=device) - def __call__(self, batch: pyarrow.Table) -> "CollatedData": + def __call__(self, batch: "pyarrow.Table") -> "CollatedData": """Convert a PyArrow table to PyTorch tensors. Args: @@ -165,7 +165,7 @@ def __init__( super().__init__(dtypes=dtypes, device=device) def __call__( - self, batch: pyarrow.Table + self, batch: "pyarrow.Table" ) -> Union[ "torch.Tensor", List["torch.Tensor"], diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index fbf14a6de79a..54e2e2a8490c 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -1,6 +1,6 @@ import sys import threading -from typing import Dict +from typing import Dict, Optional from unittest.mock import MagicMock, patch import numpy as np @@ -202,6 +202,13 @@ def custom_collate_fns(): from ray.data.iterator import ArrowBatchCollateFn, NumpyBatchCollateFn class CustomArrowBatchCollateFn(ArrowBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + super().__init__(dtypes=dtypes, device=device) + def __call__(self, batch: pyarrow.Table) -> torch.Tensor: """Add 5 to the "id" column at the Arrow level.""" modified_batch = pyarrow.Table.from_arrays( @@ -216,6 +223,13 @@ def __call__(self, batch: pyarrow.Table) -> torch.Tensor: )["id"] class CustomNumpyBatchCollateFn(NumpyBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + super().__init__(dtypes=dtypes, device=device) + def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: """Add 5 to the "id" array.""" modified_batch = {"id": batch["id"] + 5} From 3b90c2ce20c5910e81a37484f02ae2fd1b236966 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 18:29:54 +0000 Subject: [PATCH 07/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 2 ++ python/ray/data/iterator.py | 39 +++++++++++++++++++++++-- python/ray/data/tests/test_iterator.py | 29 ++++++++++++++++-- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 853a8b09460e..672c51ae21eb 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -430,6 +430,7 @@ def numpy_batch_to_torch_tensors( ) +@torch.no_grad() def concat_tensors_to_device( tensor_list: List[torch.Tensor], device: str, @@ -477,6 +478,7 @@ def concat_tensors_to_device( return result +@torch.no_grad() def move_tensors_to_device( batch: Union[ torch.Tensor, diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index d679865bd716..492977a39630 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -31,6 +31,7 @@ import tensorflow as tf import torch import pyarrow + import pandas from ray.data.dataset import ( CollatedData, @@ -58,8 +59,11 @@ def __iter__(self): return self.iterator_gen() +DataBatchType = TypeVar("DataBatchType", bound=DataBatch) + + @DeveloperAPI -class CollateFn(Generic[T]): +class CollateFn(Generic[DataBatchType]): """A function that converts a DataBatch to a CollatedData.""" def __init__( @@ -77,7 +81,7 @@ def __init__( self.device = device @abc.abstractmethod - def __call__(self, batch: T) -> "CollatedData": + def __call__(self, batch: DataBatchType) -> "CollatedData": """Convert a batch of data to collated format. Args: @@ -147,6 +151,35 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": ... +@DeveloperAPI +class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]): + """Collate function for converting Pandas batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": + """Convert a Pandas batch to PyTorch tensors. + + Args: + batch: Pandas batch to convert + + Returns: + Collated PyTorch tensors + """ + ... + + @DeveloperAPI class DefaultCollateFn(ArrowBatchCollateFn): """Default collate function for converting Arrow batches to PyTorch tensors.""" @@ -534,6 +567,8 @@ def iter_torch_batches( batch_format = "pyarrow" elif isinstance(collate_fn, NumpyBatchCollateFn): batch_format = "numpy" + elif isinstance(collate_fn, PandasBatchCollateFn): + batch_format = "pandas" elif callable(collate_fn): batch_format = "numpy" else: diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 54e2e2a8490c..f5ce79dbf876 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -7,6 +7,7 @@ import pytest import torch import pyarrow +import pandas as pd import ray @@ -199,7 +200,11 @@ def collate_fn(batch: Dict[str, np.ndarray]): @pytest.fixture def custom_collate_fns(): """Fixture that provides both Arrow and Numpy custom collate functions.""" - from ray.data.iterator import ArrowBatchCollateFn, NumpyBatchCollateFn + from ray.data.iterator import ( + ArrowBatchCollateFn, + NumpyBatchCollateFn, + PandasBatchCollateFn, + ) class CustomArrowBatchCollateFn(ArrowBatchCollateFn): def __init__( @@ -241,13 +246,33 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: modified_batch, dtypes=self.dtypes, device=self.device )["id"] + class CustomPandasBatchCollateFn(PandasBatchCollateFn): + def __init__( + self, + dtypes: Optional[Dict[str, torch.dtype]] = None, + device: Optional[str] = None, + ): + super().__init__(dtypes=dtypes, device=device) + + def __call__(self, batch: pd.DataFrame) -> torch.Tensor: + """Add 5 to the "id" column.""" + modified_batch = pd.DataFrame({"id": batch["id"] + 5}) + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + modified_batch.to_dict("series"), dtypes=self.dtypes, device=self.device + )["id"] + return { "arrow": CustomArrowBatchCollateFn(), "numpy": CustomNumpyBatchCollateFn(), + "pandas": CustomPandasBatchCollateFn(), } -@pytest.mark.parametrize("collate_type", ["arrow", "numpy"]) +@pytest.mark.parametrize("collate_type", ["arrow", "numpy", "pandas"]) def test_custom_batch_collate_fn( ray_start_regular_shared, custom_collate_fns, collate_type ): From 678cf7d803777dd592eb4ab39eb4bc4fbf820ea8 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 23 Apr 2025 22:48:50 +0000 Subject: [PATCH 08/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 76 +++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 492977a39630..bd6673afe6fd 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -181,7 +181,7 @@ def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": @DeveloperAPI -class DefaultCollateFn(ArrowBatchCollateFn): +class DefaultArrowCollateFn(ArrowBatchCollateFn): """Default collate function for converting Arrow batches to PyTorch tensors.""" def __init__( @@ -224,6 +224,76 @@ def __call__( return arrow_table_to_tensors(batch, dtypes=self.dtypes, device=self.device) +@DeveloperAPI +class DefaultNumpyCollateFn(NumpyBatchCollateFn): + """Default collate function for converting Numpy batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": + """Convert a Pandas batch to PyTorch tensors. + + Args: + batch: Pandas batch to convert + + Returns: + Collated PyTorch tensors + """ + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + return convert_ndarray_batch_to_torch_tensor_batch( + batch, dtypes=self.dtypes, device=self.device + ) + + +@DeveloperAPI +class DefaultPandasCollateFn(PandasBatchCollateFn): + """Default collate function for converting Pandas batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. + + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) + + def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": + """Convert a Pandas batch to PyTorch tensors. + + Args: + batch: Pandas batch to convert + + Returns: + Collated PyTorch tensors + """ + from ray.air._internal.torch_utils import ( + convert_pandas_batch_to_torch_tensor_batch, + ) + + return convert_pandas_batch_to_torch_tensor_batch( + batch, dtypes=self.dtypes, device=self.device + ) + + @DeveloperAPI class DefaultFinalizeFn: """Default finalize function for moving PyTorch tensors to device.""" @@ -557,12 +627,12 @@ def iter_torch_batches( finalize_fn = None if collate_fn is None: - collate_fn = DefaultCollateFn( + collate_fn = DefaultNumpyCollateFn( dtypes=dtypes, device=device, ) finalize_fn = DefaultFinalizeFn(device=device) - batch_format = "pyarrow" + batch_format = "numpy" elif isinstance(collate_fn, ArrowBatchCollateFn): batch_format = "pyarrow" elif isinstance(collate_fn, NumpyBatchCollateFn): From ed5c31d91b39f38395cae58c91a295adcb890d54 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 03:42:15 +0000 Subject: [PATCH 09/30] Handle Arrow Array null types in to_numpy Signed-off-by: Srinath Krishnamachari --- .../data/_internal/arrow_ops/transform_pyarrow.py | 4 ++++ python/ray/data/tests/test_iterator.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 1e6a8ae7a91d..fd9be87c8e49 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -594,8 +594,12 @@ def to_numpy( import pyarrow as pa if isinstance(array, pa.Array): + if pa.types.is_null(array.type): + return np.full(len(array), np.nan, dtype=np.float32) return array.to_numpy(zero_copy_only=zero_copy_only) elif isinstance(array, pa.ChunkedArray): + if pa.types.is_null(array.type): + return np.full(array.length(), np.nan, dtype=np.float32) if PYARROW_VERSION >= MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY: return array.to_numpy(zero_copy_only=zero_copy_only) else: diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 25af9a878f12..94c5e08a552d 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import numpy as np +import pyarrow as pa import pytest import torch @@ -195,6 +196,20 @@ def collate_fn(batch: Dict[str, np.ndarray]): ), iter_batches_calls_kwargs +def test_torch_conversion_null_type(ray_start_regular_shared): + """Test iter_torch_batches with a PyArrow table containing null type arrays.""" + table = pa.table({"fruit_apple": pa.array([None, None, None], type=pa.null())}) + + ds = ray.data.from_arrow(table) + it = ds.iterator() + for batch in it.iter_torch_batches(): + assert isinstance(batch, dict) + assert "fruit_apple" in batch + assert isinstance(batch["fruit_apple"], torch.Tensor) + assert torch.isnan(batch["fruit_apple"]).all() + assert batch["fruit_apple"].shape == (3,) + + def test_iterator_to_materialized_dataset(ray_start_regular_shared): """Tests that `DataIterator.materialize` fully consumes the iterator and returns a `MaterializedDataset` view of the data From 2f3933d5bde9db4cde8f7023851ab07c48b63bb6 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 04:08:18 +0000 Subject: [PATCH 10/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/tests/test_iterator.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 94c5e08a552d..6fdeb254de71 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -196,11 +196,29 @@ def collate_fn(batch: Dict[str, np.ndarray]): ), iter_batches_calls_kwargs -def test_torch_conversion_null_type(ray_start_regular_shared): - """Test iter_torch_batches with a PyArrow table containing null type arrays.""" - table = pa.table({"fruit_apple": pa.array([None, None, None], type=pa.null())}) +@pytest.fixture(params=["regular", "chunked"]) +def null_array_table(request): + """Fixture that returns a PyArrow table with either a regular or chunked null array.""" + if request.param == "regular": + # Regular array + return pa.table({"fruit_apple": pa.array([None, None, None], type=pa.null())}) + else: + # Chunked array + return pa.table( + { + "fruit_apple": pa.chunked_array( + [ + pa.array([None], type=pa.null()), + pa.array([None, None], type=pa.null()), + ] + ) + } + ) - ds = ray.data.from_arrow(table) + +def test_torch_conversion_null_type(ray_start_regular_shared, null_array_table): + """Test iter_torch_batches with a PyArrow table containing null type arrays.""" + ds = ray.data.from_arrow(null_array_table) it = ds.iterator() for batch in it.iter_torch_batches(): assert isinstance(batch, dict) From 1ac0f4a620137e2842f1fa8bfd600c2148178d17 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 02:58:16 +0000 Subject: [PATCH 11/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 8 ++++---- python/ray/data/iterator.py | 12 ++++++------ python/ray/data/tests/test_iterator.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 672c51ae21eb..08e77311cf35 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -342,7 +342,7 @@ def convert_ndarray_list_to_torch_tensor_list( } -def arrow_table_to_tensors( +def arrow_batch_to_tensors( batch: pyarrow.Table, dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, @@ -352,16 +352,16 @@ def arrow_table_to_tensors( Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], ]: - """Convert PyArrow table to PyTorch tensors. + """Convert PyArrow batch to PyTorch tensors. Args: - batch: PyArrow table to convert + batch: PyArrow batch to convert dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype will be inferred from the NumPy ndarray data. device: Optional device to place tensors on Returns: - PyTorch tensors converted from the Arrow table, can be: + PyTorch tensors converted from the Arrow batch, can be: - A single tensor - A list of tensors - A dict of column name to tensor diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index bd6673afe6fd..3d257fe4ab69 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -218,10 +218,10 @@ def __call__( - A dict of column name to list of tensors """ from ray.air._internal.torch_utils import ( - arrow_table_to_tensors, + arrow_batch_to_tensors, ) - return arrow_table_to_tensors(batch, dtypes=self.dtypes, device=self.device) + return arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=self.device) @DeveloperAPI @@ -251,10 +251,10 @@ def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": Collated PyTorch tensors """ from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, + numpy_batch_to_torch_tensors, ) - return convert_ndarray_batch_to_torch_tensor_batch( + return numpy_batch_to_torch_tensors( batch, dtypes=self.dtypes, device=self.device ) @@ -627,12 +627,12 @@ def iter_torch_batches( finalize_fn = None if collate_fn is None: - collate_fn = DefaultNumpyCollateFn( + collate_fn = DefaultArrowCollateFn( dtypes=dtypes, device=device, ) finalize_fn = DefaultFinalizeFn(device=device) - batch_format = "numpy" + batch_format = "pyarrow" elif isinstance(collate_fn, ArrowBatchCollateFn): batch_format = "pyarrow" elif isinstance(collate_fn, NumpyBatchCollateFn): diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index f5ce79dbf876..b746a3c18a17 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -220,10 +220,10 @@ def __call__(self, batch: pyarrow.Table) -> torch.Tensor: [pyarrow.compute.add(batch["id"], 5)], names=["id"] ) from ray.air._internal.torch_utils import ( - arrow_table_to_tensors, + arrow_batch_to_tensors, ) - return arrow_table_to_tensors( + return arrow_batch_to_tensors( modified_batch, dtypes=self.dtypes, device=self.device )["id"] From 2a0b3b3447b660f416da535d30ba2f292d2644d4 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 04:45:21 +0000 Subject: [PATCH 12/30] Lint Signed-off-by: Srinath Krishnamachari --- python/ray/data/tests/test_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 4c3df97a94bd..52ddde4cd460 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -285,7 +285,7 @@ def test_custom_batch_collate_fn( assert isinstance(batch, torch.Tensor) assert batch.tolist() == list(range(5, 10)) - + @pytest.fixture(params=["regular", "chunked"]) def null_array_table(request): """Fixture that returns a PyArrow table with either a regular or chunked null array.""" From b1d1835ed4dc6e374496c193ae531fdda3be7659 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 16:29:49 +0000 Subject: [PATCH 13/30] Fixes Signed-off-by: Srinath Krishnamachari --- .../ray/data/_internal/arrow_ops/transform_pyarrow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 5774d07df1d6..f280df8292cf 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -603,11 +603,11 @@ def table_to_numpy_dict_combined( col = table[col_name] if isinstance(col, pyarrow.ChunkedArray): combined_array = combine_chunked_array(col) - numpy_batch[col_name] = combined_array.to_numpy( - zero_copy_only=zero_copy_only + numpy_batch[col_name] = to_numpy( + combined_array, zero_copy_only=zero_copy_only ) else: - numpy_batch[col_name] = col.to_numpy(zero_copy_only=zero_copy_only) + numpy_batch[col_name] = to_numpy(col, zero_copy_only=zero_copy_only) return numpy_batch @@ -633,10 +633,10 @@ def table_to_numpy_dict_chunked( col = table[col_name] if isinstance(col, pyarrow.ChunkedArray): numpy_batch[col_name] = [ - chunk.to_numpy(zero_copy_only=zero_copy_only) for chunk in col.chunks + to_numpy(chunk, zero_copy_only=zero_copy_only) for chunk in col.chunks ] else: - numpy_batch[col_name] = col.to_numpy(zero_copy_only=zero_copy_only) + numpy_batch[col_name] = to_numpy(col, zero_copy_only=zero_copy_only) return numpy_batch From 2bb01f319b50e2c6157b68f88cc5d47cb80dcef5 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 18:21:58 +0000 Subject: [PATCH 14/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 3d257fe4ab69..357034ddd08e 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -221,7 +221,7 @@ def __call__( arrow_batch_to_tensors, ) - return arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=self.device) + return arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=None) @DeveloperAPI @@ -254,9 +254,7 @@ def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": numpy_batch_to_torch_tensors, ) - return numpy_batch_to_torch_tensors( - batch, dtypes=self.dtypes, device=self.device - ) + return numpy_batch_to_torch_tensors(batch, dtypes=self.dtypes, device=None) @DeveloperAPI @@ -290,7 +288,7 @@ def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": ) return convert_pandas_batch_to_torch_tensor_batch( - batch, dtypes=self.dtypes, device=self.device + batch, dtypes=self.dtypes, device=None ) From f0e8a25f4c5a7118f761ad70a5f82103b3dadfa1 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 22:05:27 +0000 Subject: [PATCH 15/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 13 +-- .../benchmark/image_classification/factory.py | 88 +++++++++++++++---- .../benchmark/ray_dataloader_factory.py | 43 +++------ 3 files changed, 81 insertions(+), 63 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 08e77311cf35..b7c1487af47d 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -369,18 +369,7 @@ def arrow_batch_to_tensors( """ from ray.data._internal.arrow_ops import transform_pyarrow - # Handle both single dtype and dict of dtypes - has_object_dtype = False - if dtypes is not None: - if isinstance(dtypes, dict): - has_object_dtype = any(dtype is object for dtype in dtypes.values()) - else: - has_object_dtype = dtypes is object - - combine_chunks: bool = ( - dtypes is None or not has_object_dtype or device is None or device == "cpu" - ) - + combine_chunks: bool = device is None or device == "cpu" if combine_chunks: numpy_batch = transform_pyarrow.table_to_numpy_dict_combined( batch, diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index 1b5263360f74..7fd9b8151815 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -1,10 +1,11 @@ # Standard library imports import logging import time -from typing import Any, Dict, Tuple, Iterator, Generator, Optional +from typing import Dict, Tuple, Iterator, Generator, Optional, Union # Third-party imports import torch +import pyarrow import ray import ray.data import ray.train @@ -188,36 +189,85 @@ def create_batch_iterator( raise -class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): - """Factory for creating Ray DataLoader for image classification tasks. +class CustomArrowCollateFn(ray.data.iterator.ArrowBatchCollateFn): + """Custom collate function for converting Arrow batches to PyTorch tensors.""" - Features: - - Distributed file reading with round-robin worker distribution - - Device transfer and error handling for data batches - - Configurable row limits per worker for controlled processing - - Performance monitoring and logging - """ + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + """Initialize the collate function. - def __init__(self, benchmark_config: BenchmarkConfig): - super().__init__(benchmark_config) + Args: + dtypes: Optional torch dtype(s) for the tensors + device: Optional device to place tensors on + """ + super().__init__(dtypes=dtypes, device=device) - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert Ray data batch to PyTorch tensors on the appropriate device. + def __call__(self, batch: "pyarrow.Table") -> Tuple[torch.Tensor, torch.Tensor]: + """Convert an Arrow batch to PyTorch tensors. Args: - batch: Dictionary with 'image' and 'label' numpy arrays + batch: PyArrow Table to convert Returns: - Tuple of (image_tensor, label_tensor) on the target device + Tuple of (image_tensor, label_tensor) """ from ray.air._internal.torch_utils import ( - convert_ndarray_batch_to_torch_tensor_batch, + arrow_batch_to_tensors, + move_tensors_to_device, + ) + + tensors = arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=None) + tensors = move_tensors_to_device(tensors, self.device) + return tensors["image"], tensors["label"] + + +class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): + """Factory for creating Ray DataLoader for image classification tasks.""" + + def __init__(self, benchmark_config: BenchmarkConfig): + super().__init__(benchmark_config) + + def get_train_dataloader(self): + """Get the training dataloader. + + Returns: + Iterator of training batches + """ + ds_iterator = self._ray_ds_iterators["train"] = ray.train.get_dataset_shard( + "train" + ) + dataloader_config = self.get_dataloader_config() + return iter( + ds_iterator.iter_torch_batches( + batch_size=dataloader_config.train_batch_size, + local_shuffle_buffer_size=( + dataloader_config.local_buffer_shuffle_size + if dataloader_config.local_buffer_shuffle_size > 0 + else None + ), + collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), + prefetch_batches=dataloader_config.prefetch_batches, + ) ) - device = ray.train.torch.get_device() - batch = convert_ndarray_batch_to_torch_tensor_batch(batch, device=device) + def get_val_dataloader(self): + """Get the validation dataloader. - return batch["image"], batch["label"] + Returns: + Iterator of validation batches + """ + ds_iterator = self._ray_ds_iterators["val"] = ray.train.get_dataset_shard("val") + dataloader_config = self.get_dataloader_config() + return iter( + ds_iterator.iter_torch_batches( + batch_size=dataloader_config.validation_batch_size, + collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), + prefetch_batches=dataloader_config.prefetch_batches, + ) + ) class ImageClassificationMockDataLoaderFactory(BaseDataLoaderFactory): diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index 28bb774fff7e..efa76bb7b49d 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -1,7 +1,6 @@ from abc import abstractmethod -from typing import Any, Dict, Tuple +from typing import Any, Dict -import torch import ray.train from ray.data import Dataset @@ -32,42 +31,22 @@ def get_ray_datasets(self) -> Dict[str, Dataset]: pass @abstractmethod - def collate_fn(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - """Get the collate function for the dataloader. + def get_train_dataloader(self): + """Get the training dataloader. Returns: - A function that takes a batch and returns a tuple of tensors. + Iterator of training batches """ pass - def get_train_dataloader(self): - ds_iterator = self._ray_ds_iterators["train"] = ray.train.get_dataset_shard( - "train" - ) - dataloader_config = self.get_dataloader_config() - return iter( - ds_iterator.iter_torch_batches( - batch_size=dataloader_config.train_batch_size, - local_shuffle_buffer_size=( - dataloader_config.local_buffer_shuffle_size - if dataloader_config.local_buffer_shuffle_size > 0 - else None - ), - collate_fn=self.collate_fn, - prefetch_batches=dataloader_config.prefetch_batches, - ) - ) - + @abstractmethod def get_val_dataloader(self): - ds_iterator = self._ray_ds_iterators["val"] = ray.train.get_dataset_shard("val") - dataloader_config = self.get_dataloader_config() - return iter( - ds_iterator.iter_torch_batches( - batch_size=dataloader_config.validation_batch_size, - collate_fn=self.collate_fn, - prefetch_batches=dataloader_config.prefetch_batches, - ) - ) + """Get the validation dataloader. + + Returns: + Iterator of validation batches + """ + pass def get_metrics(self) -> Dict[str, Any]: metrics = {} From 62198958de28d90103ab428e52adc764f20ca166 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 04:41:29 +0000 Subject: [PATCH 16/30] Train release test: Enable multiprocess spawn (CUDA compatability) Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/config.py | 4 ++-- .../train_tests/benchmark/torch_dataloader_factory.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/release/train_tests/benchmark/config.py b/release/train_tests/benchmark/config.py index e2d9a0bf6d6e..524252fdb51c 100644 --- a/release/train_tests/benchmark/config.py +++ b/release/train_tests/benchmark/config.py @@ -39,8 +39,8 @@ class BenchmarkConfig(BaseModel): max_failures: int = 0 task: str = "image_classification" - locality_with_output: bool = False - actor_locality_enabled: bool = False + locality_with_output: bool = True + actor_locality_enabled: bool = True enable_shard_locality: bool = True # Data diff --git a/release/train_tests/benchmark/torch_dataloader_factory.py b/release/train_tests/benchmark/torch_dataloader_factory.py index e68486187873..4ca01fc7e613 100644 --- a/release/train_tests/benchmark/torch_dataloader_factory.py +++ b/release/train_tests/benchmark/torch_dataloader_factory.py @@ -1,5 +1,6 @@ from typing import Dict, Iterator, Tuple import logging +import multiprocessing from abc import ABC, abstractmethod import torch @@ -14,6 +15,16 @@ logger = ContextLoggerAdapter(logging.getLogger(__name__)) +# Set multiprocessing start method to 'spawn' for CUDA compatibility +if torch.cuda.is_available(): + try: + multiprocessing.set_start_method("spawn", force=True) + logger.info( + "Set multiprocessing start method to 'spawn' for CUDA compatibility" + ) + except RuntimeError: + logger.info("Multiprocessing start method already set") + class TorchDataLoaderFactory(BaseDataLoaderFactory, ABC): """Factory for creating PyTorch DataLoaders.""" From 817b5be9cb5b6a8a1d8ee410f77c633a9fa4b5c3 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 05:48:11 +0000 Subject: [PATCH 17/30] Fixes Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/train_benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/release/train_tests/benchmark/train_benchmark.py b/release/train_tests/benchmark/train_benchmark.py index 89c2a1f277fe..8b70b03975c3 100644 --- a/release/train_tests/benchmark/train_benchmark.py +++ b/release/train_tests/benchmark/train_benchmark.py @@ -288,7 +288,8 @@ def get_metrics(self) -> Dict[str, float]: num_workers = ray.train.get_context().get_world_size() train_time = ( self._metrics["train/step"].get() - + self._metrics["train/iter_first_batch"].get() + # Exclude the time it takes to get the first batch. + # + self._metrics["train/iter_first_batch"].get() + self._metrics["train/iter_batch"].get() ) if train_time > 0: @@ -301,7 +302,8 @@ def get_metrics(self) -> Dict[str, float]: validation_time = ( self._metrics["validation/step"].get() - + self._metrics["validation/iter_first_batch"].get() + # Exclude the time it takes to get the first batch. + # + self._metrics["validation/iter_first_batch"].get() + self._metrics["validation/iter_batch"].get() ) if validation_time > 0: From 9bc89a0d3ebb79e32cdbff95d593bc194c8a5353 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 06:39:13 +0000 Subject: [PATCH 18/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 4 +++- python/ray/data/iterator.py | 5 ++++- python/ray/data/tests/test_iterator.py | 11 +++++++---- .../benchmark/image_classification/factory.py | 4 +++- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index b7c1487af47d..7b72ee3fbc71 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -346,6 +346,7 @@ def arrow_batch_to_tensors( batch: pyarrow.Table, dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, + combine_chunks: bool = False, ) -> Union[ torch.Tensor, List[torch.Tensor], @@ -359,6 +360,8 @@ def arrow_batch_to_tensors( dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype will be inferred from the NumPy ndarray data. device: Optional device to place tensors on + combine_chunks: If True, combine chunks in Arrow batch before converting to + tensors. Returns: PyTorch tensors converted from the Arrow batch, can be: @@ -369,7 +372,6 @@ def arrow_batch_to_tensors( """ from ray.data._internal.arrow_ops import transform_pyarrow - combine_chunks: bool = device is None or device == "cpu" if combine_chunks: numpy_batch = transform_pyarrow.table_to_numpy_dict_combined( batch, diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 357034ddd08e..5c142f236cf0 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -221,7 +221,10 @@ def __call__( arrow_batch_to_tensors, ) - return arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=None) + combine_chunks = self.device == "cpu" + return arrow_batch_to_tensors( + batch, dtypes=self.dtypes, device=None, combine_chunks=combine_chunks + ) @DeveloperAPI diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index 52ddde4cd460..a04d6f87b7d0 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -225,7 +225,10 @@ def __call__(self, batch: pyarrow.Table) -> torch.Tensor: ) return arrow_batch_to_tensors( - modified_batch, dtypes=self.dtypes, device=self.device + modified_batch, + dtypes=self.dtypes, + device=self.device, + combine_chunks=self.device == "cpu", )["id"] class CustomNumpyBatchCollateFn(NumpyBatchCollateFn): @@ -267,9 +270,9 @@ def __call__(self, batch: pd.DataFrame) -> torch.Tensor: )["id"] return { - "arrow": CustomArrowBatchCollateFn(), - "numpy": CustomNumpyBatchCollateFn(), - "pandas": CustomPandasBatchCollateFn(), + "arrow": CustomArrowBatchCollateFn(device="cpu"), + "numpy": CustomNumpyBatchCollateFn(device="cpu"), + "pandas": CustomPandasBatchCollateFn(device="cpu"), } diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index 7fd9b8151815..c8be8a9834ad 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -219,7 +219,9 @@ def __call__(self, batch: "pyarrow.Table") -> Tuple[torch.Tensor, torch.Tensor]: move_tensors_to_device, ) - tensors = arrow_batch_to_tensors(batch, dtypes=self.dtypes, device=None) + tensors = arrow_batch_to_tensors( + batch, dtypes=self.dtypes, device=None, combine_chunks=self.device == "cpu" + ) tensors = move_tensors_to_device(tensors, self.device) return tensors["image"], tensors["label"] From e2333c305abd7fdaadff3dea5c31e1b3099c41ca Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 19:05:22 +0000 Subject: [PATCH 19/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/config.py | 4 ++-- .../benchmark/torch_dataloader_factory.py | 20 +++++++++---------- .../train_tests/benchmark/train_benchmark.py | 6 ++---- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/release/train_tests/benchmark/config.py b/release/train_tests/benchmark/config.py index 524252fdb51c..e2d9a0bf6d6e 100644 --- a/release/train_tests/benchmark/config.py +++ b/release/train_tests/benchmark/config.py @@ -39,8 +39,8 @@ class BenchmarkConfig(BaseModel): max_failures: int = 0 task: str = "image_classification" - locality_with_output: bool = True - actor_locality_enabled: bool = True + locality_with_output: bool = False + actor_locality_enabled: bool = False enable_shard_locality: bool = True # Data diff --git a/release/train_tests/benchmark/torch_dataloader_factory.py b/release/train_tests/benchmark/torch_dataloader_factory.py index 4ca01fc7e613..a4d1fc346e05 100644 --- a/release/train_tests/benchmark/torch_dataloader_factory.py +++ b/release/train_tests/benchmark/torch_dataloader_factory.py @@ -1,6 +1,5 @@ from typing import Dict, Iterator, Tuple import logging -import multiprocessing from abc import ABC, abstractmethod import torch @@ -15,16 +14,6 @@ logger = ContextLoggerAdapter(logging.getLogger(__name__)) -# Set multiprocessing start method to 'spawn' for CUDA compatibility -if torch.cuda.is_available(): - try: - multiprocessing.set_start_method("spawn", force=True) - logger.info( - "Set multiprocessing start method to 'spawn' for CUDA compatibility" - ) - except RuntimeError: - logger.info("Multiprocessing start method already set") - class TorchDataLoaderFactory(BaseDataLoaderFactory, ABC): """Factory for creating PyTorch DataLoaders.""" @@ -71,6 +60,15 @@ def __init__( f"across {num_gpus} GPUs" ) + import torch.multiprocessing as mp + + if torch.cuda.is_available(): + try: + mp.set_start_method("spawn", force=True) + logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility") + except RuntimeError: + logger.info("Multiprocessing start method already set") + def _get_device(self) -> torch.device: """Get the device for the current worker using Ray Train's device management.""" try: diff --git a/release/train_tests/benchmark/train_benchmark.py b/release/train_tests/benchmark/train_benchmark.py index 8b70b03975c3..89c2a1f277fe 100644 --- a/release/train_tests/benchmark/train_benchmark.py +++ b/release/train_tests/benchmark/train_benchmark.py @@ -288,8 +288,7 @@ def get_metrics(self) -> Dict[str, float]: num_workers = ray.train.get_context().get_world_size() train_time = ( self._metrics["train/step"].get() - # Exclude the time it takes to get the first batch. - # + self._metrics["train/iter_first_batch"].get() + + self._metrics["train/iter_first_batch"].get() + self._metrics["train/iter_batch"].get() ) if train_time > 0: @@ -302,8 +301,7 @@ def get_metrics(self) -> Dict[str, float]: validation_time = ( self._metrics["validation/step"].get() - # Exclude the time it takes to get the first batch. - # + self._metrics["validation/iter_first_batch"].get() + + self._metrics["validation/iter_first_batch"].get() + self._metrics["validation/iter_batch"].get() ) if validation_time > 0: From 0b2debe57779c37c040aecdecf0d5b0f2d3e5b6d Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 19:33:52 +0000 Subject: [PATCH 20/30] Lint Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/torch_dataloader_factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/release/train_tests/benchmark/torch_dataloader_factory.py b/release/train_tests/benchmark/torch_dataloader_factory.py index a4d1fc346e05..e38682a9d049 100644 --- a/release/train_tests/benchmark/torch_dataloader_factory.py +++ b/release/train_tests/benchmark/torch_dataloader_factory.py @@ -65,7 +65,9 @@ def __init__( if torch.cuda.is_available(): try: mp.set_start_method("spawn", force=True) - logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility") + logger.info( + "Set multiprocessing start method to 'spawn' for CUDA compatibility" + ) except RuntimeError: logger.info("Multiprocessing start method already set") From 2ed36657777de7946aacd901045a7b20803a750e Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Fri, 25 Apr 2025 19:33:52 +0000 Subject: [PATCH 21/30] Fixes Signed-off-by: Srinath Krishnamachari --- .../benchmark/torch_dataloader_factory.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/release/train_tests/benchmark/torch_dataloader_factory.py b/release/train_tests/benchmark/torch_dataloader_factory.py index e38682a9d049..ce31c86f25ac 100644 --- a/release/train_tests/benchmark/torch_dataloader_factory.py +++ b/release/train_tests/benchmark/torch_dataloader_factory.py @@ -15,6 +15,18 @@ logger = ContextLoggerAdapter(logging.getLogger(__name__)) +if torch.cuda.is_available(): + import torch.multiprocessing as mp + + try: + mp.set_start_method("spawn", force=True) + logger.info( + "Set multiprocessing start method to 'spawn' for CUDA compatibility" + ) + except RuntimeError: + logger.info("Multiprocessing start method already set") + + class TorchDataLoaderFactory(BaseDataLoaderFactory, ABC): """Factory for creating PyTorch DataLoaders.""" @@ -60,17 +72,6 @@ def __init__( f"across {num_gpus} GPUs" ) - import torch.multiprocessing as mp - - if torch.cuda.is_available(): - try: - mp.set_start_method("spawn", force=True) - logger.info( - "Set multiprocessing start method to 'spawn' for CUDA compatibility" - ) - except RuntimeError: - logger.info("Multiprocessing start method already set") - def _get_device(self) -> torch.device: """Get the device for the current worker using Ray Train's device management.""" try: From e4a760b0283915192e245e0c09263518a73fdc0f Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 24 Apr 2025 23:01:55 +0000 Subject: [PATCH 22/30] Addressed review comments Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 17 +- python/ray/data/collate_fn.py | 198 ++++++++++++++ python/ray/data/iterator.py | 329 +++--------------------- python/ray/data/tests/test_iterator.py | 9 +- 4 files changed, 241 insertions(+), 312 deletions(-) create mode 100644 python/ray/data/collate_fn.py diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 7b72ee3fbc71..cbf23c31dcab 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -347,12 +347,7 @@ def arrow_batch_to_tensors( dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, combine_chunks: bool = False, -) -> Union[ - torch.Tensor, - List[torch.Tensor], - Dict[str, torch.Tensor], - Dict[str, List[torch.Tensor]], -]: +) -> Dict[str, List[torch.Tensor]]: """Convert PyArrow batch to PyTorch tensors. Args: @@ -364,11 +359,7 @@ def arrow_batch_to_tensors( tensors. Returns: - PyTorch tensors converted from the Arrow batch, can be: - - A single tensor - - A list of tensors - - A dict of column name to tensor - - A dict of column name to list of tensors + A dictionary of column name to list of tensors """ from ray.data._internal.arrow_ops import transform_pyarrow @@ -398,7 +389,7 @@ def numpy_batch_to_torch_tensors( batch: Dict[str, np.ndarray], dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None, device: Optional[str] = None, -) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: +) -> Dict[str, List[torch.Tensor]]: """Convert a dictionary of numpy arrays to PyTorch tensors. Args: @@ -408,7 +399,7 @@ def numpy_batch_to_torch_tensors( device: Optional device to place tensors on Returns: - Either a single PyTorch tensor or a dict mapping column names to tensors + A dictionary of column name to list of tensors """ from ray.air._internal.torch_utils import ( convert_ndarray_batch_to_torch_tensor_batch, diff --git a/python/ray/data/collate_fn.py b/python/ray/data/collate_fn.py new file mode 100644 index 000000000000..3df6e804ebb3 --- /dev/null +++ b/python/ray/data/collate_fn.py @@ -0,0 +1,198 @@ +import abc +from typing import Dict, Generic, Optional, TypeVar, Union, List, TYPE_CHECKING, Any + +import numpy as np + +from ray.data.block import DataBatch +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import torch + import pyarrow + import pandas + + from ray.data.dataset import CollatedData + + +DataBatchType = TypeVar("DataBatchType", bound=DataBatch) + + +@DeveloperAPI +class CollateFn(Generic[DataBatchType]): + """Abstract interface for collate_fn for iter_torch_batches. See doc-string of + `collate_fn` for more details. + """ + + @abc.abstractmethod + def __call__(self, batch: DataBatchType) -> "CollatedData": + """Convert a batch of data to collated format. + + Args: + batch: The input batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]): + """Collate function that takes pyarrow.Table as the input batch type.""" + + def __call__(self, batch: "pyarrow.Table") -> "CollatedData": + """Convert a batch of pyarrow.Table to collated format. + + Args: + batch: The input pyarrow.Table batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]): + """Collate function that takes a dictionary of numpy arrays as the input batch type.""" + + def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": + """Convert a batch of numpy arrays to collated format. + + Args: + batch: The input dictionary of numpy arrays batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]): + """Collate function that takes a pandas.DataFrame as the input batch type.""" + + def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": + """Convert a batch of pandas.DataFrame to collated format. + + Args: + batch: The input pandas.DataFrame batch to collate. + + Returns: + The collated data in the format expected by the model. + """ + ... + + +@DeveloperAPI +class DefaultArrowCollateFn(ArrowBatchCollateFn): + """Default collate function for converting Arrow batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: "pyarrow.Table") -> Dict[str, List["torch.Tensor"]]: + """Convert an Arrow batch to PyTorch tensors. + + Args: + batch: PyArrow Table to convert + + Returns: + A dictionary of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + arrow_batch_to_tensors, + ) + + combine_chunks = self.device.type == "cpu" + return arrow_batch_to_tensors( + batch, dtypes=self.dtypes, device=None, combine_chunks=combine_chunks + ) + + +@DeveloperAPI +class DefaultNumpyCollateFn(NumpyBatchCollateFn): + """Default collate function for converting Numpy batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List["torch.Tensor"]]: + """Convert a Numpy batch to PyTorch tensors. + + Args: + batch: The input dictionary of numpy arrays batch to collate + + Returns: + A dictionary of column name to list of tensors + """ + from ray.air._internal.torch_utils import ( + numpy_batch_to_torch_tensors, + ) + + return numpy_batch_to_torch_tensors( + batch, dtypes=self.dtypes, device=self.device + ) + + +@DeveloperAPI +class DefaultPandasCollateFn(PandasBatchCollateFn): + """Default collate function for converting Pandas batches to PyTorch tensors.""" + + def __init__( + self, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: Optional[str] = None, + ): + self.dtypes = dtypes + self.device = device + + def __call__(self, batch: "pandas.DataFrame") -> "torch.Tensor": + """Convert a Pandas batch to PyTorch tensors. + + Args: + batch: The input pandas.DataFrame batch to collate + + Returns: + A PyTorch tensor containing the collated batch. + """ + from ray.air._internal.torch_utils import ( + convert_pandas_to_torch_tensor, + ) + + return convert_pandas_to_torch_tensor( + batch, dtypes=self.dtypes, device=self.device + ) + + +def default_finalize_fn( + batch: Union[Dict[str, List["torch.Tensor"]], Any], device: Optional[str] = None +) -> Union[Dict[str, "torch.Tensor"], Any]: + """Default finalize function for moving PyTorch tensors to device. + + Args: + batch: Input batch to move to device. Can be: + - Dictionary mapping column names to lists of tensors + - Any other type supported by move_tensors_to_device + device: Target device to move tensors to + + Returns: + Batch with tensors moved to the target device. Type matches input type: + - If input is Dict[str, List[torch.Tensor]], returns Dict[str, torch.Tensor] + - Otherwise returns the same type as input with tensors moved to device + """ + from ray.air._internal.torch_utils import ( + move_tensors_to_device, + ) + + return move_tensors_to_device(batch, device=device) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 5c142f236cf0..8439c978bb3c 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -12,8 +12,8 @@ Tuple, TypeVar, Union, - Generic, ) +import warnings import numpy as np @@ -25,13 +25,19 @@ from ray.data._internal.stats import DatasetStats, StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.context import DataContext -from ray.util.annotations import PublicAPI, DeveloperAPI +from ray.util.annotations import PublicAPI +from ray.data.collate_fn import ( + CollateFn, + ArrowBatchCollateFn, + NumpyBatchCollateFn, + PandasBatchCollateFn, + DefaultArrowCollateFn, + default_finalize_fn, +) if TYPE_CHECKING: import tensorflow as tf import torch - import pyarrow - import pandas from ray.data.dataset import ( CollatedData, @@ -59,286 +65,6 @@ def __iter__(self): return self.iterator_gen() -DataBatchType = TypeVar("DataBatchType", bound=DataBatch) - - -@DeveloperAPI -class CollateFn(Generic[DataBatchType]): - """A function that converts a DataBatch to a CollatedData.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - self.dtypes = dtypes - self.device = device - - @abc.abstractmethod - def __call__(self, batch: DataBatchType) -> "CollatedData": - """Convert a batch of data to collated format. - - Args: - batch: The input batch to collate - - Returns: - The collated data in the format expected by the model - """ - ... - - -@DeveloperAPI -class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]): - """Collate function for converting Arrow tables to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__(self, batch: "pyarrow.Table") -> "CollatedData": - """Convert a PyArrow table to PyTorch tensors. - - Args: - batch: PyArrow table to convert - - Returns: - Collated PyTorch tensors - """ - ... - - -@DeveloperAPI -class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]): - """Collate function for converting Numpy batches to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData": - """Convert a Numpy batch to PyTorch tensors. - - Args: - batch: Numpy batch to convert - - Returns: - Collated PyTorch tensors - """ - ... - - -@DeveloperAPI -class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]): - """Collate function for converting Pandas batches to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": - """Convert a Pandas batch to PyTorch tensors. - - Args: - batch: Pandas batch to convert - - Returns: - Collated PyTorch tensors - """ - ... - - -@DeveloperAPI -class DefaultArrowCollateFn(ArrowBatchCollateFn): - """Default collate function for converting Arrow batches to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__( - self, batch: "pyarrow.Table" - ) -> Union[ - "torch.Tensor", - List["torch.Tensor"], - Dict[str, "torch.Tensor"], - Dict[str, List["torch.Tensor"]], - ]: - """Convert an Arrow batch to PyTorch tensors. - - Args: - batch: PyArrow Table to convert - - Returns: - Collated PyTorch tensors, can be: - - A single tensor - - A list of tensors - - A dict of column name to tensor - - A dict of column name to list of tensors - """ - from ray.air._internal.torch_utils import ( - arrow_batch_to_tensors, - ) - - combine_chunks = self.device == "cpu" - return arrow_batch_to_tensors( - batch, dtypes=self.dtypes, device=None, combine_chunks=combine_chunks - ) - - -@DeveloperAPI -class DefaultNumpyCollateFn(NumpyBatchCollateFn): - """Default collate function for converting Numpy batches to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": - """Convert a Pandas batch to PyTorch tensors. - - Args: - batch: Pandas batch to convert - - Returns: - Collated PyTorch tensors - """ - from ray.air._internal.torch_utils import ( - numpy_batch_to_torch_tensors, - ) - - return numpy_batch_to_torch_tensors(batch, dtypes=self.dtypes, device=None) - - -@DeveloperAPI -class DefaultPandasCollateFn(PandasBatchCollateFn): - """Default collate function for converting Pandas batches to PyTorch tensors.""" - - def __init__( - self, - dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, - device: Optional[str] = None, - ): - """Initialize the collate function. - - Args: - dtypes: Optional torch dtype(s) for the tensors - device: Optional device to place tensors on - """ - super().__init__(dtypes=dtypes, device=device) - - def __call__(self, batch: "pandas.DataFrame") -> "CollatedData": - """Convert a Pandas batch to PyTorch tensors. - - Args: - batch: Pandas batch to convert - - Returns: - Collated PyTorch tensors - """ - from ray.air._internal.torch_utils import ( - convert_pandas_batch_to_torch_tensor_batch, - ) - - return convert_pandas_batch_to_torch_tensor_batch( - batch, dtypes=self.dtypes, device=None - ) - - -@DeveloperAPI -class DefaultFinalizeFn: - """Default finalize function for moving PyTorch tensors to device.""" - - def __init__( - self, - device: Optional[str] = None, - ): - """Initialize the finalize function. - - Args: - device: Optional device to place tensors on - """ - self.device = device - - def __call__( - self, - batch: Union[ - "torch.Tensor", - List["torch.Tensor"], - Dict[str, "torch.Tensor"], - Dict[str, List["torch.Tensor"]], - ], - ) -> Union[ - "torch.Tensor", - List["torch.Tensor"], - Dict[str, "torch.Tensor"], - Dict[str, List["torch.Tensor"]], - ]: - """Move tensors to device. - - Args: - batch: Tensor or collection of tensors to move to device - - Returns: - Tensor or collection of tensors moved to the target device - """ - from ray.air._internal.torch_utils import ( - move_tensors_to_device, - ) - - return move_tensors_to_device(batch, device=self.device) - - @PublicAPI class DataIterator(abc.ABC): """An iterator for reading records from a :class:`~Dataset`. @@ -586,17 +312,24 @@ def iter_torch_batches( Dataset is passed to Ray Train and ``collate_fn`` is not provided. Otherwise, defaults to CPU. You can't use this parameter with ``collate_fn``. - collate_fn: A function to convert a PyArrow Table or Numpy batch to PyTorch tensors. - When this parameter is specified, the user should manually handle the - host to device data transfer outside of ``collate_fn``. - This is useful for further processing the data after it has been - batched. Potential use cases include collating along a dimension other - than the first, padding sequences of various lengths, or generally - handling batches of different length tensors. If not provided, the - default collate function is used which simply converts the batch to - PyTorch tensors. This API is still experimental and is subject to - change. You can't use this parameter in conjunction with ``dtypes`` - or ``device``. + collate_fn: [Alpha] A function to customize how data batches are collated + before being passed to the model. This is useful for last-mile data + formatting such as padding, masking, or packaging tensors into custom + data structures. If not provided, `iter_torch_batches` automatically + converts batches to `torch.Tensor`s and moves them to the device + assigned to the current worker. The input to `collate_fn` may be: + (1) dict of np.ndarray, where you should provide a function that + takes in a dict of Numpy arrays + (2) pd.DataFrame, where you should provide a callable class that + subclasses `PandasCollateFn` + (3) pyarrow.Table, where you should provide a callable class that + subclasses `ArrowCollateFn` (recommended for best performance) + The output can be any type. If the output is a `torch.Tensor`, + `dict[str, torch.Tensor]`, or `list/tuple[torch.Tensor]`, it will be + automatically moved to the current worker's device. For other types, + you must handle device transfer manually in your training loop. + Note: This function is called in a multi-threaded context; avoid using + thread-unsafe code. drop_last: Whether to drop the last batch if it's incomplete. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the @@ -632,7 +365,7 @@ def iter_torch_batches( dtypes=dtypes, device=device, ) - finalize_fn = DefaultFinalizeFn(device=device) + finalize_fn = default_finalize_fn batch_format = "pyarrow" elif isinstance(collate_fn, ArrowBatchCollateFn): batch_format = "pyarrow" @@ -642,6 +375,10 @@ def iter_torch_batches( batch_format = "pandas" elif callable(collate_fn): batch_format = "numpy" + warnings.warn( + "Passing a callable to `iter_torch_batches` is deprecated and suggest using `ArrowBatchCollateFn` for the best performance.", + DeprecationWarning, + ) else: raise ValueError(f"Unsupported collate function: {type(collate_fn)}") diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py index a04d6f87b7d0..d6734e5a2745 100644 --- a/python/ray/data/tests/test_iterator.py +++ b/python/ray/data/tests/test_iterator.py @@ -213,7 +213,8 @@ def __init__( dtypes: Optional[Dict[str, torch.dtype]] = None, device: Optional[str] = None, ): - super().__init__(dtypes=dtypes, device=device) + self.dtypes = dtypes + self.device = device def __call__(self, batch: pyarrow.Table) -> torch.Tensor: """Add 5 to the "id" column at the Arrow level.""" @@ -237,7 +238,8 @@ def __init__( dtypes: Optional[Dict[str, torch.dtype]] = None, device: Optional[str] = None, ): - super().__init__(dtypes=dtypes, device=device) + self.dtypes = dtypes + self.device = device def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor: """Add 5 to the "id" array.""" @@ -256,7 +258,8 @@ def __init__( dtypes: Optional[Dict[str, torch.dtype]] = None, device: Optional[str] = None, ): - super().__init__(dtypes=dtypes, device=device) + self.dtypes = dtypes + self.device = device def __call__(self, batch: pd.DataFrame) -> torch.Tensor: """Add 5 to the "id" column.""" From 5d4d7a1896d4a65f454b250d22a313c27b3df6df Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Mon, 28 Apr 2025 20:32:58 +0000 Subject: [PATCH 23/30] Misc fixes Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/image_classification/factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index c8be8a9834ad..f96f7aa3b4be 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -251,7 +251,7 @@ def get_train_dataloader(self): else None ), collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), - prefetch_batches=dataloader_config.prefetch_batches, + prefetch_batches=dataloader_config.ray_data_prefetch_batches, ) ) @@ -267,7 +267,7 @@ def get_val_dataloader(self): ds_iterator.iter_torch_batches( batch_size=dataloader_config.validation_batch_size, collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), - prefetch_batches=dataloader_config.prefetch_batches, + prefetch_batches=dataloader_config.ray_data_prefetch_batches, ) ) From 02a8e7c162162c689480272c297b3690ba2de96a Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Mon, 28 Apr 2025 22:53:51 +0000 Subject: [PATCH 24/30] Lint Signed-off-by: Srinath Krishnamachari --- python/ray/data/collate_fn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/data/collate_fn.py b/python/ray/data/collate_fn.py index 3df6e804ebb3..d1e31c716f9f 100644 --- a/python/ray/data/collate_fn.py +++ b/python/ray/data/collate_fn.py @@ -175,6 +175,7 @@ def __call__(self, batch: "pandas.DataFrame") -> "torch.Tensor": ) +@DeveloperAPI def default_finalize_fn( batch: Union[Dict[str, List["torch.Tensor"]], Any], device: Optional[str] = None ) -> Union[Dict[str, "torch.Tensor"], Any]: From 7d905a47742a5aca7871d93fe676b0ad6a88fd7b Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 30 Apr 2025 05:48:18 +0000 Subject: [PATCH 25/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- .../benchmark/image_classification/factory.py | 45 +++---------------- .../benchmark/ray_dataloader_factory.py | 45 ++++++++++++++----- 2 files changed, 38 insertions(+), 52 deletions(-) diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index f96f7aa3b4be..f4174c7fdb29 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -1,7 +1,7 @@ # Standard library imports import logging import time -from typing import Dict, Tuple, Iterator, Generator, Optional, Union +from typing import Dict, Tuple, Iterator, Generator, Optional, Union, Type # Third-party imports import torch @@ -9,6 +9,7 @@ import ray import ray.data import ray.train +from ray.data.iterator import ArrowBatchCollateFn # Local imports from config import BenchmarkConfig @@ -189,7 +190,7 @@ def create_batch_iterator( raise -class CustomArrowCollateFn(ray.data.iterator.ArrowBatchCollateFn): +class CustomArrowCollateFn(ArrowBatchCollateFn): """Custom collate function for converting Arrow batches to PyTorch tensors.""" def __init__( @@ -232,44 +233,8 @@ class ImageClassificationRayDataLoaderFactory(RayDataLoaderFactory): def __init__(self, benchmark_config: BenchmarkConfig): super().__init__(benchmark_config) - def get_train_dataloader(self): - """Get the training dataloader. - - Returns: - Iterator of training batches - """ - ds_iterator = self._ray_ds_iterators["train"] = ray.train.get_dataset_shard( - "train" - ) - dataloader_config = self.get_dataloader_config() - return iter( - ds_iterator.iter_torch_batches( - batch_size=dataloader_config.train_batch_size, - local_shuffle_buffer_size=( - dataloader_config.local_buffer_shuffle_size - if dataloader_config.local_buffer_shuffle_size > 0 - else None - ), - collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), - prefetch_batches=dataloader_config.ray_data_prefetch_batches, - ) - ) - - def get_val_dataloader(self): - """Get the validation dataloader. - - Returns: - Iterator of validation batches - """ - ds_iterator = self._ray_ds_iterators["val"] = ray.train.get_dataset_shard("val") - dataloader_config = self.get_dataloader_config() - return iter( - ds_iterator.iter_torch_batches( - batch_size=dataloader_config.validation_batch_size, - collate_fn=CustomArrowCollateFn(device=ray.train.torch.get_device()), - prefetch_batches=dataloader_config.ray_data_prefetch_batches, - ) - ) + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + return CustomArrowCollateFn class ImageClassificationMockDataLoaderFactory(BaseDataLoaderFactory): diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index bef87c62b1ca..6b0973a96d64 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -1,8 +1,8 @@ from abc import abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Type import ray.train -from ray.data import Dataset +from ray.data.iterator import ArrowBatchCollateFn from constants import DatasetKey from config import BenchmarkConfig, RayDataConfig @@ -27,31 +27,52 @@ def __init__(self, benchmark_config: BenchmarkConfig) -> None: data_context.retried_io_errors.append("AWS Error ACCESS_DENIED") @abstractmethod - def get_ray_datasets(self) -> Dict[str, Dataset]: - """Get the Ray datasets for training and validation. - - Returns: - Dict with "train" and "val" Dataset objects - """ + def _get_collate_fn_cls(self) -> Type[ArrowBatchCollateFn]: + """Return the collate function class. Must be implemented by subclass.""" pass - @abstractmethod def get_train_dataloader(self): """Get the training dataloader. Returns: Iterator of training batches """ - pass + ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN) + self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator + dataloader_config = self.get_dataloader_config() + return iter( + ds_iterator.iter_torch_batches( + batch_size=dataloader_config.train_batch_size, + local_shuffle_buffer_size=( + dataloader_config.local_buffer_shuffle_size + if dataloader_config.local_buffer_shuffle_size > 0 + else None + ), + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), + prefetch_batches=dataloader_config.ray_data_prefetch_batches, + ) + ) - @abstractmethod def get_val_dataloader(self): """Get the validation dataloader. Returns: Iterator of validation batches """ - pass + ds_iterator = ray.train.get_dataset_shard(DatasetKey.VALID) + self._ray_ds_iterators[DatasetKey.VALID] = ds_iterator + dataloader_config = self.get_dataloader_config() + return iter( + ds_iterator.iter_torch_batches( + batch_size=dataloader_config.validation_batch_size, + collate_fn=self._get_collate_fn_cls()( + device=ray.train.torch.get_device() + ), + prefetch_batches=dataloader_config.ray_data_prefetch_batches, + ) + ) def get_metrics(self) -> Dict[str, Any]: metrics = {} From 9415923965d2943e72a6ed4e62bf8cd07dea8c5f Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 30 Apr 2025 05:58:36 +0000 Subject: [PATCH 26/30] Fixes Signed-off-by: Srinath Krishnamachari --- release/train_tests/benchmark/ray_dataloader_factory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index 6b0973a96d64..4b3e756e8f7d 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -39,6 +39,7 @@ def get_train_dataloader(self): """ ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN) self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator + dataloader_config = self.get_dataloader_config() return iter( ds_iterator.iter_torch_batches( @@ -52,6 +53,7 @@ def get_train_dataloader(self): device=ray.train.torch.get_device() ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, + drop_last=True, ) ) @@ -63,6 +65,7 @@ def get_val_dataloader(self): """ ds_iterator = ray.train.get_dataset_shard(DatasetKey.VALID) self._ray_ds_iterators[DatasetKey.VALID] = ds_iterator + dataloader_config = self.get_dataloader_config() return iter( ds_iterator.iter_torch_batches( @@ -71,6 +74,7 @@ def get_val_dataloader(self): device=ray.train.torch.get_device() ), prefetch_batches=dataloader_config.ray_data_prefetch_batches, + drop_last=True, ) ) From f41065e2f5d6a6f83aeeb8bc48271813c483593b Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 30 Apr 2025 06:00:18 +0000 Subject: [PATCH 27/30] Fixes Signed-off-by: Srinath Krishnamachari --- .../benchmark/torch_dataloader_factory.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/release/train_tests/benchmark/torch_dataloader_factory.py b/release/train_tests/benchmark/torch_dataloader_factory.py index 86e20e334032..6b361c83767c 100644 --- a/release/train_tests/benchmark/torch_dataloader_factory.py +++ b/release/train_tests/benchmark/torch_dataloader_factory.py @@ -16,18 +16,6 @@ logger = ContextLoggerAdapter(logging.getLogger(__name__)) -if torch.cuda.is_available(): - import torch.multiprocessing as mp - - try: - mp.set_start_method("spawn", force=True) - logger.info( - "Set multiprocessing start method to 'spawn' for CUDA compatibility" - ) - except RuntimeError: - logger.info("Multiprocessing start method already set") - - class TorchDataLoaderFactory(BaseDataLoaderFactory, ABC): """Factory for creating PyTorch DataLoaders.""" @@ -155,6 +143,7 @@ def get_train_dataloader(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: timeout=timeout, drop_last=True, worker_init_fn=self.worker_init_fn if num_workers > 0 else None, + multiprocessing_context="spawn", ) return self.create_batch_iterator(dataloader, device) @@ -208,5 +197,6 @@ def get_val_dataloader(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: timeout=timeout, drop_last=False, worker_init_fn=self.worker_init_fn if num_workers > 0 else None, + multiprocessing_context="spawn", ) return self.create_batch_iterator(dataloader, device) From cd601b6aabb0b396fefbc9e9ba7673da698a8b2a Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Wed, 30 Apr 2025 22:06:39 +0000 Subject: [PATCH 28/30] Misc Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/iterator.py | 14 ++++++++------ .../benchmark/image_classification/factory.py | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 8439c978bb3c..c8ee85564fdc 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -318,12 +318,14 @@ def iter_torch_batches( data structures. If not provided, `iter_torch_batches` automatically converts batches to `torch.Tensor`s and moves them to the device assigned to the current worker. The input to `collate_fn` may be: - (1) dict of np.ndarray, where you should provide a function that - takes in a dict of Numpy arrays - (2) pd.DataFrame, where you should provide a callable class that - subclasses `PandasCollateFn` - (3) pyarrow.Table, where you should provide a callable class that - subclasses `ArrowCollateFn` (recommended for best performance) + + 1. dict of np.ndarray, where you should provide a function that + takes in a dict of Numpy arrays + 2. pd.DataFrame, where you should provide a callable class that + subclasses `PandasCollateFn` + 3. pyarrow.Table, where you should provide a callable class that + subclasses `ArrowCollateFn` (recommended for best performance) + The output can be any type. If the output is a `torch.Tensor`, `dict[str, torch.Tensor]`, or `list/tuple[torch.Tensor]`, it will be automatically moved to the current worker's device. For other types, diff --git a/release/train_tests/benchmark/image_classification/factory.py b/release/train_tests/benchmark/image_classification/factory.py index f4174c7fdb29..7c45f0b2bff2 100644 --- a/release/train_tests/benchmark/image_classification/factory.py +++ b/release/train_tests/benchmark/image_classification/factory.py @@ -204,7 +204,8 @@ def __init__( dtypes: Optional torch dtype(s) for the tensors device: Optional device to place tensors on """ - super().__init__(dtypes=dtypes, device=device) + self.dtypes = dtypes + self.device = device def __call__(self, batch: "pyarrow.Table") -> Tuple[torch.Tensor, torch.Tensor]: """Convert an Arrow batch to PyTorch tensors. From ab4c78de955b2f99540b99a67b747e9c021ffd33 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 1 May 2025 18:14:46 +0000 Subject: [PATCH 29/30] Address CI failures Signed-off-by: Srinath Krishnamachari --- python/ray/air/_internal/torch_utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index cbf23c31dcab..8fdea060e565 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -457,6 +457,7 @@ def concat_tensors_to_device( result[row_start:row_end].copy_(t) row_start = row_end + assert isinstance(result, torch.Tensor), "Result must be a torch.Tensor" return result @@ -469,12 +470,7 @@ def move_tensors_to_device( Dict[str, List[torch.Tensor]], ], device: Optional[str] = None, -) -> Union[ - torch.Tensor, - List[torch.Tensor], - Dict[str, torch.Tensor], - Dict[str, List[torch.Tensor]], -]: +) -> Union[torch.Tensor, Dict[str, torch.Tensor],]: """Move tensors to the specified device. Args: @@ -487,6 +483,7 @@ def move_tensors_to_device( Returns: The input tensors moved to the specified device, maintaining the same structure. + Note that any lists of tensors will be concatenated into a single tensor. """ if device is None: return batch @@ -509,4 +506,14 @@ def move_tensors_to_device( else: batch = batch.to(device=device) + if isinstance(batch, dict): + assert all(isinstance(v, torch.Tensor) for v in batch.values()), ( + "All values in dict must be tensors, got: " + f"{[type(v) for v in batch.values()]}" + ) + else: + assert isinstance( + batch, torch.Tensor + ), "Batch must be a Tensor or dict of Tensors" + return batch From 9de36d2f111b51a557b2426c32333fefab57a0f6 Mon Sep 17 00:00:00 2001 From: Srinath Krishnamachari Date: Thu, 1 May 2025 22:29:00 +0000 Subject: [PATCH 30/30] CI Fixes Signed-off-by: Srinath Krishnamachari --- python/ray/data/collate_fn.py | 2 +- python/ray/data/iterator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/data/collate_fn.py b/python/ray/data/collate_fn.py index d1e31c716f9f..e7bb960707a1 100644 --- a/python/ray/data/collate_fn.py +++ b/python/ray/data/collate_fn.py @@ -109,7 +109,7 @@ def __call__(self, batch: "pyarrow.Table") -> Dict[str, List["torch.Tensor"]]: arrow_batch_to_tensors, ) - combine_chunks = self.device.type == "cpu" + combine_chunks = self.device == "cpu" return arrow_batch_to_tensors( batch, dtypes=self.dtypes, device=None, combine_chunks=combine_chunks ) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index c8ee85564fdc..390717c6c145 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -359,7 +359,7 @@ def iter_torch_batches( if device == "auto": # Use the appropriate device for Ray Train, or falls back to CPU if # Ray Train is not being used. - device = get_device() + device = get_device().type finalize_fn = None if collate_fn is None: