Handle non-contiguous Tensors based GPU transfer #52548
base: master
Conversation
Signed-off-by: Srinath Krishnamachari <[email protected]>
python/ray/data/iterator.py
Outdated
batch = batch.to(device=device)
return batch

collate_fn = DefaultNumpyCollateFn(
The default should be arrow for better performance.
And we don't need the default numpy / pandas.
Yes, fixed now. Changed to Numpy to chase down the issue of null->nan conversion.
elif isinstance(collate_fn, PandasBatchCollateFn):
    batch_format = "pandas"
elif callable(collate_fn):
    batch_format = "numpy"
Let's add a warning that a raw collate_fn will be deprecated, and suggest using ArrowBatchCollateFn for the best performance.
i think we should still keep raw collate fn? (and just default to the numpy version).
yeah, we should keep it. but I'd like to emit a warning to ask users to migrate to the new API.
WDYT?
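To make the discussion above concrete, here is a rough sketch of what the batch-format dispatch with a deprecation warning for raw callables could look like. The class names follow the PR discussion (`ArrowBatchCollateFn`, `PandasBatchCollateFn`, `NumpyBatchCollateFn`), but the stub classes and `resolve_batch_format` helper below are hypothetical, not the merged implementation:

```python
import warnings


# Hypothetical stand-ins for the collate-fn base classes discussed in this PR.
class ArrowBatchCollateFn: ...
class PandasBatchCollateFn: ...
class NumpyBatchCollateFn: ...


def resolve_batch_format(collate_fn):
    """Pick the batch format from the collate_fn type (sketch only)."""
    if isinstance(collate_fn, ArrowBatchCollateFn):
        return "pyarrow"
    elif isinstance(collate_fn, PandasBatchCollateFn):
        return "pandas"
    elif isinstance(collate_fn, NumpyBatchCollateFn):
        return "numpy"
    elif callable(collate_fn):
        # Keep supporting raw callables, but nudge users toward the new API,
        # as suggested in the review thread above.
        warnings.warn(
            "Passing a raw callable as collate_fn is deprecated; "
            "subclass ArrowBatchCollateFn for the best performance.",
            DeprecationWarning,
        )
        return "numpy"
    raise TypeError(f"Unsupported collate_fn type: {type(collate_fn)}")
```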
python/ray/data/iterator.py
Outdated
@@ -301,17 +585,17 @@ def iter_torch_batches(
     Dataset is passed to Ray Train and ``collate_fn`` is not provided.
     Otherwise, defaults to CPU. You can't use this parameter with
     ``collate_fn``.
-    collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
+    collate_fn: A function to convert a PyArrow Table or Numpy batch to PyTorch tensors.
I'd like to rewrite this docstring to make it more descriptive and easier to read. E.g.:
collate_fn: A function that collates data batches before feeding them to the model. Potential use cases include ...
If not specified, `iter_torch_batches` converts the data to torch.Tensors and transfers them to the given device.
If specified, you can customize the collation logic. The input batch type can be one of ...; Arrow is recommended for the best performance. The output can be any type. If the output type is one of ..., the data will be automatically transferred to the given device. Otherwise, you need to transfer the batches yourself in your training loop.
Note that this collate_fn will be called in a multi-threaded manner.
cc @justinvyu for more suggestions
and the above example needs to be updated as well
also need to mention how to choose the batch type - subclassing one of ...
"""
collate_fn: [Alpha] A function to customize how data batches are collated before
being passed to the model. This is useful for last-mile data formatting
such as padding, masking, or packaging tensors into custom data
structures. If not provided, `iter_torch_batches` automatically converts
batches to `torch.Tensor`s and moves them to the device assigned to the
current worker. The input to `collate_fn` may be:
(1) dict of np.ndarray, where you should provide a function
that takes in a dict of Numpy arrays
(2) pd.DataFrame, where you should provide a callable class
that subclasses `PandasCollateFn`
(3) pyarrow.Table, where you should provide a callable class
that subclasses `ArrowCollateFn` (recommended for best performance)
The output can be any type. If the output is a `torch.Tensor`,
`dict[str, torch.Tensor]`, or `list/tuple[torch.Tensor]`, it will be
automatically moved to the current worker's device. For other types,
you must handle device transfer manually in your training loop.
Note: This function is called in a multi-threaded context; avoid using
thread-unsafe code.
"""
Why are these changes needed?
Handle GPU transfer of non-contiguous Tensors. This removes the overhead of combining Arrow chunked arrays during the Arrow -> Numpy -> Tensor conversion.
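A rough sketch of the host-side copy this change avoids, using NumPy arrays as stand-ins for Arrow column chunks (the real code path operates on `pyarrow.ChunkedArray`; this is an illustration of the memory behavior, not the PR's code):

```python
import numpy as np

# Arrow columns often arrive as several chunks rather than one buffer.
chunks = [np.arange(4), np.arange(4, 8)]

# Combining chunks first materializes one contiguous copy of the whole
# column on the host before any tensor is created.
combined = np.concatenate(chunks)  # extra full copy

# Handling the chunks individually keeps each one zero-copy on the host;
# the non-contiguous pieces are only stitched together during the GPU
# transfer itself.
views = [np.ascontiguousarray(c) for c in chunks]  # no host-side copy
```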
Related issue number
Checks
- I've signed every commit with `git commit -s` in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.