Handle non-contiguous Tensors based GPU transfer #52548

Open · wants to merge 39 commits into master

Conversation

srinathk10 (Contributor)

Why are these changes needed?

Handle GPU transfer for non-contiguous tensors. This removes the overhead of combining Arrow chunked arrays during the Arrow -> Numpy -> Tensor conversion.
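
The motivation can be illustrated with a dependency-free sketch. Plain Numpy arrays stand in for Arrow chunks; the names `chunks`, `combined`, and `dest` are illustrative, not code from this PR. The eager path materializes the combined column on the host before any transfer, while the chunk-wise path copies each chunk straight into the destination buffer (on a GPU, a preallocated device tensor):

```python
import numpy as np

# Two chunks standing in for an Arrow chunked column: separate buffers,
# so the logical column is non-contiguous in memory.
chunks = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

# Eager path: combine the chunks on the host (an extra host-side copy),
# then transfer the combined buffer in one shot.
combined = np.concatenate(chunks)

# Chunk-wise path: skip the host-side combine and copy each chunk
# directly into a preallocated destination buffer (on a GPU this would
# be a device tensor; a host array stands in for it here).
dest = np.empty(sum(len(c) for c in chunks))
offset = 0
for chunk in chunks:
    dest[offset:offset + len(chunk)] = chunk
    offset += len(chunk)

assert np.array_equal(combined, dest)
```

Both paths produce the same column; the chunk-wise one avoids the intermediate host-side combine, which is the overhead this PR targets.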

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

srinathk10 and others added 10 commits April 23, 2025 01:12
Signed-off-by: Srinath Krishnamachari <[email protected]>
@justinvyu justinvyu self-assigned this Apr 23, 2025
batch = batch.to(device=device)
return batch

collate_fn = DefaultNumpyCollateFn(
Contributor:

The default should be arrow for better performance.
And we don't need the default numpy / pandas.

Contributor Author:

Yes, fixed now. Changed to Numpy to chase down the issue of null->nan conversion.

elif isinstance(collate_fn, PandasBatchCollateFn):
batch_format = "pandas"
elif callable(collate_fn):
batch_format = "numpy"
Contributor:

Let's add a warning that raw collate_fn will be deprecated, and suggest using ArrowBatchCollateFn for the best performance.

Contributor:

I think we should still keep the raw collate fn? (and just default to the numpy version).

Contributor:

Yeah, we should keep it, but I'd like to emit a warning to ask users to migrate to the new API.
WDYT?
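
A minimal sketch of the compromise discussed here: keep accepting a raw callable, default it to the numpy batch format, and emit a deprecation warning. The stub classes and the `resolve_batch_format` name are illustrative, not Ray's actual code; only `ArrowBatchCollateFn` and `PandasBatchCollateFn` appear in the diff.

```python
import warnings

# Stub stand-ins for the collate classes referenced in the diff.
class ArrowBatchCollateFn: ...
class PandasBatchCollateFn: ...

def resolve_batch_format(collate_fn):
    """Map a collate_fn to the batch format it should receive."""
    if isinstance(collate_fn, ArrowBatchCollateFn):
        return "pyarrow"
    elif isinstance(collate_fn, PandasBatchCollateFn):
        return "pandas"
    elif callable(collate_fn):
        # Raw callables keep working, but users are steered to the new API.
        warnings.warn(
            "Passing a raw callable as collate_fn is deprecated; "
            "subclass ArrowBatchCollateFn for the best performance.",
            DeprecationWarning,
        )
        return "numpy"
    raise TypeError(f"Unsupported collate_fn: {collate_fn!r}")
```

This preserves backward compatibility while nudging users toward the Arrow path that this PR makes fast.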

@@ -301,17 +585,17 @@ def iter_torch_batches(
Dataset is passed to Ray Train and ``collate_fn`` is not provided.
Otherwise, defaults to CPU. You can't use this parameter with
``collate_fn``.
collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
collate_fn: A function to convert a PyArrow Table or Numpy batch to PyTorch tensors.
Contributor:

I'd like to rewrite this doc-string and make it more descriptive and easy to read.
E.g.

collate_fn: A function that collates data batches before feeding them to the model.  Potential use cases include ...
    If not specified, `iter_torch_batches` converts the data to torch.Tensors and transfers them to the given device. 
    If specified, you can customize the collation logic. The input batch type can be one of .... Arrow is the most recommended for best perf.  And the output can be any type. If the output type is one of ..., the data will be automatically transferred to the given device. Otherwise, you need to transfer the batches in your training loop. 
	Note, this collate_fn will be called in a multi-threaded manner. 

cc @justinvyu for more suggestions

Contributor:

And the above example needs to be updated as well.

Contributor:

Also need to mention how to choose the batch type - subclassing one of ...

Contributor:

"""
            collate_fn: [Alpha] A function to customize how data batches are collated before
                being passed to the model. This is useful for last-mile data formatting
                such as padding, masking, or packaging tensors into custom data
                structures. If not provided, `iter_torch_batches` automatically converts
                batches to `torch.Tensor`s and moves them to the device assigned to the
                current worker. The input to `collate_fn` may be:
                    (1) dict of np.ndarray, where you should provide a function
                        that takes in a dict of Numpy arrays
                    (2) pd.DataFrame, where you should provide a callable class
                        that subclasses `PandasCollateFn`
                    (3) pyarrow.Table, where you should provide a callable class
                        that subclasses `ArrowCollateFn` (recommended for best performance)
                The output can be any type. If the output is a `torch.Tensor`,
                `dict[str, torch.Tensor]`, or `list/tuple[torch.Tensor]`, it will be
                automatically moved to the current worker's device. For other types,
                you must handle device transfer manually in your training loop.
                Note: This function is called in a multi-threaded context; avoid using
                thread-unsafe code.
"""

@srinathk10 srinathk10 changed the base branch from master to srinathk10-to_numpy-null-types April 24, 2025 04:16
srinathk10 and others added 2 commits April 23, 2025 21:18
Signed-off-by: Srinath Krishnamachari <[email protected]>
@srinathk10 srinathk10 added the go add ONLY when ready to merge, run all tests label Apr 24, 2025
Base automatically changed from srinathk10-to_numpy-null-types to master April 24, 2025 19:11
Signed-off-by: Srinath Krishnamachari <[email protected]>
Signed-off-by: Srinath Krishnamachari <[email protected]>
@srinathk10 srinathk10 changed the base branch from master to srinathk10-train-release-test-fixes April 25, 2025 06:45
@srinathk10 srinathk10 changed the base branch from srinathk10-train-release-test-fixes to master April 25, 2025 23:16
@srinathk10 srinathk10 changed the title WIP: Handle non-contiguous Tensors based GPU transfer Handle non-contiguous Tensors based GPU transfer Apr 28, 2025
@srinathk10 srinathk10 marked this pull request as ready for review April 28, 2025 20:27
@srinathk10 srinathk10 requested a review from a team as a code owner April 28, 2025 20:27
Labels
go add ONLY when ready to merge, run all tests

3 participants