Skip to content

Handle non-contiguous Tensors based GPU transfer #52548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 46 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
35edb58
WIP: Handle non-contiguous Tensors GPU transfer
srinathk10 Apr 23, 2025
f5d83ea
Lint
srinathk10 Apr 23, 2025
0cd3a21
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 23, 2025
309b72f
Lint Fixes
srinathk10 Apr 23, 2025
1c465dd
Misc fixes
srinathk10 Apr 23, 2025
d90b8da
Misc fixes
srinathk10 Apr 23, 2025
700e7fe
Misc fixes
srinathk10 Apr 23, 2025
c29dd57
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 23, 2025
3b90c2c
Misc fixes
srinathk10 Apr 23, 2025
678cf7d
Misc Fixes
srinathk10 Apr 23, 2025
ed5c31d
Handle Arrow Array null types in to_numpy
srinathk10 Apr 24, 2025
2f3933d
Misc Fixes
srinathk10 Apr 24, 2025
1ac0f4a
Misc fixes
srinathk10 Apr 24, 2025
1b35c3b
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 24, 2025
fc413df
Merge branch 'master' into srinathk10-to_numpy-null-types
srinathk10 Apr 24, 2025
1360a85
Merge branch 'srinathk10-to_numpy-null-types' into srinathk10-chunked…
srinathk10 Apr 24, 2025
2a0b3b3
Lint
srinathk10 Apr 24, 2025
b1d1835
Fixes
srinathk10 Apr 24, 2025
30f86e0
Merge branch 'master' into srinathk10-to_numpy-null-types
srinathk10 Apr 24, 2025
d4960a6
Merge branch 'srinathk10-to_numpy-null-types' into srinathk10-chunked…
srinathk10 Apr 24, 2025
2bb01f3
Misc Fixes
srinathk10 Apr 24, 2025
859228f
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 24, 2025
f0e8a25
Misc Fixes
srinathk10 Apr 24, 2025
6219895
Train release test: Enable multiprocess spawn (CUDA compatibility)
srinathk10 Apr 25, 2025
817b5be
Fixes
srinathk10 Apr 25, 2025
9bc89a0
Misc Fixes
srinathk10 Apr 25, 2025
0cb8a67
Merge branch 'srinathk10-train-release-test-fixes' into srinathk10-ch…
srinathk10 Apr 25, 2025
e2333c3
Misc fixes
srinathk10 Apr 25, 2025
c22efa9
Merge branch 'master' into srinathk10-train-release-test-fixes
srinathk10 Apr 25, 2025
0b2debe
Lint
srinathk10 Apr 25, 2025
2ed3665
Fixes
srinathk10 Apr 25, 2025
c8e6947
Merge branch 'master' into srinathk10-train-release-test-fixes
srinathk10 Apr 25, 2025
cc77476
Merge branch 'srinathk10-train-release-test-fixes' into srinathk10-ch…
srinathk10 Apr 25, 2025
d2e626d
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 25, 2025
e4a760b
Addressed review comments
srinathk10 Apr 24, 2025
acf2be2
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 28, 2025
5d4d7a1
Misc fixes
srinathk10 Apr 28, 2025
02a8e7c
Lint
srinathk10 Apr 28, 2025
13ddd5c
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 28, 2025
f6e3c7f
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 30, 2025
7d905a4
Misc Fixes
srinathk10 Apr 30, 2025
7db43e6
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 30, 2025
9415923
Fixes
srinathk10 Apr 30, 2025
f41065e
Fixes
srinathk10 Apr 30, 2025
cd601b6
Misc Fixes
srinathk10 Apr 30, 2025
403ef84
Merge branch 'master' into srinathk10-chunked-gpu-transfer
srinathk10 Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions python/ray/air/_internal/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import torch
import pyarrow

from ray.air._internal.device_manager import get_torch_device_manager_by_context
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
Expand Down Expand Up @@ -292,3 +293,238 @@ def consume_prefix_in_state_dict_if_present_not_in_place(
metadata[newkey] = metadata.pop(key)

return state_dict


def convert_ndarray_list_to_torch_tensor_list(
    ndarrays: Union[List[np.ndarray], Dict[str, List[np.ndarray]]],
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    device: Optional[str] = None,
) -> Union[List[torch.Tensor], Dict[str, List[torch.Tensor]]]:
    """Convert a list of NumPy ndarrays or dict of lists of ndarrays to Torch Tensors.

    Each ndarray is converted independently via
    ``convert_ndarray_batch_to_torch_tensor_batch``, preserving the input's
    list/dict structure.

    Args:
        ndarrays: A list of NumPy ndarrays or a dict mapping column names to lists of
            ndarrays that we wish to convert to Torch Tensors.
        dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype
            will be inferred from the NumPy ndarray data.
        device: The device on which the tensor(s) should be placed; if None, the Torch
            tensor(s) will be constructed on the CPU.

    Returns: A list of Torch Tensors or a dict mapping column names to lists of Tensors.
    """
    if isinstance(ndarrays, list):
        # Single-column case: a plain list of ndarrays. A dict of dtypes is
        # accepted only if it carries exactly one entry.
        single_dtype = dtypes
        if isinstance(dtypes, dict):
            if len(dtypes) != 1:
                raise ValueError(
                    "When constructing a single-column batch, only a single dtype "
                    f"should be given, instead got: {dtypes}"
                )
            (single_dtype,) = dtypes.values()
        return [
            convert_ndarray_batch_to_torch_tensor_batch(
                arr, dtypes=single_dtype, device=device
            )
            for arr in ndarrays
        ]

    # Multi-column case: dict mapping column name -> list of ndarrays.
    def _dtype_for(col_name: str):
        # Per-column dtype when a dict was given; shared dtype otherwise.
        return dtypes[col_name] if isinstance(dtypes, dict) else dtypes

    converted: Dict[str, List[torch.Tensor]] = {}
    for col_name, col_arrays in ndarrays.items():
        converted[col_name] = [
            convert_ndarray_batch_to_torch_tensor_batch(
                arr, dtypes=_dtype_for(col_name), device=device
            )
            for arr in col_arrays
        ]
    return converted


def arrow_batch_to_tensors(
    batch: pyarrow.Table,
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    device: Optional[str] = None,
) -> Union[
    torch.Tensor,
    List[torch.Tensor],
    Dict[str, torch.Tensor],
    Dict[str, List[torch.Tensor]],
]:
    """Convert PyArrow batch to PyTorch tensors.

    Chunked (per-chunk) conversion is used only when an ``object`` dtype is
    requested AND a non-CPU device is targeted; otherwise column chunks are
    combined into single contiguous ndarrays before conversion.

    Args:
        batch: PyArrow batch to convert
        dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype
            will be inferred from the NumPy ndarray data.
        device: Optional device to place tensors on

    Returns:
        PyTorch tensors converted from the Arrow batch, can be:
        - A single tensor
        - A list of tensors
        - A dict of column name to tensor
        - A dict of column name to list of tensors
    """
    from ray.data._internal.arrow_ops import transform_pyarrow

    # Does any requested dtype equal the sentinel `object` dtype?
    if isinstance(dtypes, dict):
        contains_object = any(dt is object for dt in dtypes.values())
    else:
        contains_object = dtypes is object

    # De Morgan inverse of: dtypes is None or not has_object_dtype
    # or device is None or device == "cpu".
    keep_chunks = (
        dtypes is not None
        and contains_object
        and device is not None
        and device != "cpu"
    )

    if keep_chunks:
        per_chunk_numpy = transform_pyarrow.table_to_numpy_dict_chunked(
            batch,
            zero_copy_only=False,
        )
        return convert_ndarray_list_to_torch_tensor_list(
            per_chunk_numpy,
            dtypes=dtypes,
            device=device,
        )

    combined_numpy = transform_pyarrow.table_to_numpy_dict_combined(
        batch,
        zero_copy_only=False,
    )
    return convert_ndarray_batch_to_torch_tensor_batch(
        combined_numpy,
        dtypes=dtypes,
        device=device,
    )


def numpy_batch_to_torch_tensors(
    batch: Dict[str, np.ndarray],
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    device: Optional[str] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Convert a dictionary of numpy arrays to PyTorch tensors.

    Thin wrapper around ``convert_ndarray_batch_to_torch_tensor_batch``.

    Args:
        batch: Dictionary mapping column names to numpy arrays
        dtypes: A (dict of) Torch dtype(s) for the created tensors; if None, the dtype
            will be inferred from the NumPy ndarray data.
        device: Optional device to place tensors on

    Returns:
        Either a single PyTorch tensor or a dict mapping column names to tensors
    """
    # NOTE: the previous version re-imported
    # convert_ndarray_batch_to_torch_tensor_batch from
    # ray.air._internal.torch_utils — this module itself. The function is
    # already in scope, so call it directly.
    return convert_ndarray_batch_to_torch_tensor_batch(
        batch,
        dtypes=dtypes,
        device=device,
    )


@torch.no_grad()
def concat_tensors_to_device(
    tensor_list: List[torch.Tensor],
    device: str,
) -> torch.Tensor:
    """Concatenate tensors along dim 0 into one contiguous tensor on `device`.

    All inputs must share dtype and trailing shape (``shape[1:]``); the rows
    are copied one input at a time into a pre-allocated output buffer.

    Args:
        tensor_list: List of tensors to stack
        device: The device to move tensors to

    Returns:
        A contiguous tensor on the target device
    """
    # Validate the inputs up front; the copy loop below assumes homogeneity.
    assert tensor_list, "Cannot stack empty list of tensors"
    assert all(
        isinstance(t, torch.Tensor) for t in tensor_list
    ), "All items must be torch.Tensor"
    assert all(
        t.dtype == tensor_list[0].dtype for t in tensor_list
    ), "All tensors must have the same dtype"
    assert all(
        t.shape[1:] == tensor_list[0].shape[1:] for t in tensor_list
    ), "All tensors must have the same shape[1:]"

    head = tensor_list[0]
    num_rows = sum(t.shape[0] for t in tensor_list)

    # Pre-allocate the destination buffer on the target device.
    out = torch.empty(
        (num_rows, *head.shape[1:]), dtype=head.dtype, device=device
    )

    offset = 0
    for t in tensor_list:
        # Async copy only overlaps when the source host memory is pinned;
        # non_blocking=False is the default blocking copy otherwise.
        out[offset : offset + t.shape[0]].copy_(t, non_blocking=t.is_pinned())
        offset += t.shape[0]

    return out


@torch.no_grad()
def move_tensors_to_device(
    batch: Union[
        torch.Tensor,
        List[torch.Tensor],
        Dict[str, torch.Tensor],
        Dict[str, List[torch.Tensor]],
    ],
    device: Optional[str] = None,
) -> Union[
    torch.Tensor,
    List[torch.Tensor],
    Dict[str, torch.Tensor],
    Dict[str, List[torch.Tensor]],
]:
    """Move tensors to the specified device.

    Note that lists of tensors are NOT kept as lists: they are concatenated
    along dim 0 into a single contiguous tensor on the target device (the
    previous docstring incorrectly claimed the structure was maintained).
    Dict inputs are mutated in place and also returned.

    Args:
        batch: A tensor or collection of tensors to move to device. Can be:
            - A single tensor
            - A list of tensors
            - A dict mapping keys to tensors
            - A dict mapping keys to lists of tensors
        device: The device to move tensors to. If None, tensors are not moved
            and `batch` is returned unchanged.

    Returns:
        The tensors moved to the specified device:
        - A single tensor -> tensor on device
        - A list of tensors -> one concatenated tensor on device
        - A dict of tensors / lists of tensors -> dict of tensors on device
          (list values concatenated into single tensors)
    """
    if device is None:
        return batch

    if isinstance(batch, dict):
        # NOTE: mutates the input dict in place.
        for k, v in batch.items():
            if isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v):
                batch[k] = concat_tensors_to_device(v, device=device)
            elif isinstance(v, torch.Tensor):
                # non_blocking is only effective (and safe) for pinned host
                # memory; it degrades to a blocking copy otherwise.
                batch[k] = v.to(device=device, non_blocking=v.is_pinned())
    elif isinstance(batch, list) and all(isinstance(t, torch.Tensor) for t in batch):
        batch = concat_tensors_to_device(batch, device=device)
    else:
        assert isinstance(batch, torch.Tensor), "Batch must be a Tensor"
        batch = batch.to(device=device, non_blocking=batch.is_pinned())

    return batch
61 changes: 61 additions & 0 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,63 @@ def concat_and_sort(
return take_table(ret, indices)


def table_to_numpy_dict_combined(
    table: "pyarrow.Table",
    *,
    zero_copy_only: bool = False,
) -> Dict[str, np.ndarray]:
    """Convert a PyArrow table to a dictionary of numpy arrays.

    Chunked columns are first combined into a single contiguous array via
    ``combine_chunked_array`` before conversion.

    Args:
        table: The PyArrow table to convert.
        zero_copy_only: Whether to only use zero-copy transfers.

    Returns:
        A dictionary of numpy arrays.
    """
    result: Dict[str, np.ndarray] = {}
    for name in table.column_names:
        column = table[name]
        # Collapse multi-chunk columns into one array before converting.
        if isinstance(column, pyarrow.ChunkedArray):
            column = combine_chunked_array(column)
        result[name] = to_numpy(column, zero_copy_only=zero_copy_only)
    return result


def table_to_numpy_dict_chunked(
    table: "pyarrow.Table",
    *,
    zero_copy_only: bool = False,
) -> Dict[str, List[np.ndarray]]:
    """Convert a PyArrow table to a dictionary of lists of numpy arrays.

    Args:
        table: The PyArrow table to convert.
        zero_copy_only: Whether to only use zero-copy transfers.

    Returns:
        A dictionary mapping each column name to a list of numpy arrays,
        one per chunk of the column. Non-chunked columns yield a
        single-element list, so every value matches the annotated return
        type and callers can uniformly iterate per chunk.
    """

    numpy_batch = {}
    for col_name in table.column_names:
        col = table[col_name]
        if isinstance(col, pyarrow.ChunkedArray):
            numpy_batch[col_name] = [
                to_numpy(chunk, zero_copy_only=zero_copy_only) for chunk in col.chunks
            ]
        else:
            # Wrap in a single-element list so the value shape matches the
            # chunked branch and the declared return type. Previously a bare
            # ndarray leaked through here, which downstream per-chunk
            # consumers would have iterated row-by-row.
            numpy_batch[col_name] = [to_numpy(col, zero_copy_only=zero_copy_only)]
    return numpy_batch


def to_numpy(
array: Union["pyarrow.Array", "pyarrow.ChunkedArray"],
*,
Expand All @@ -594,8 +651,12 @@ def to_numpy(
import pyarrow as pa

if isinstance(array, pa.Array):
if pa.types.is_null(array.type):
return np.full(len(array), np.nan, dtype=np.float32)
return array.to_numpy(zero_copy_only=zero_copy_only)
elif isinstance(array, pa.ChunkedArray):
if pa.types.is_null(array.type):
return np.full(array.length(), np.nan, dtype=np.float32)
if PYARROW_VERSION >= MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY:
return array.to_numpy(zero_copy_only=zero_copy_only)
else:
Expand Down
Loading