Support loading straight to device for JAX / FLAX #636

@adonath

Feature request

It seems loading straight to a device is currently only supported for PyTorch. It would be great to have support for JAX as well.

Motivation

Loading arrays into JAX currently goes via NumPy, which forces the arrays into CPU memory and then creates a copy to load into device memory. This introduces additional overhead.

Your contribution

As a temporary workaround, I think one can load the data using PyTorch and then donate the buffers to a JAX array using the DLPack protocol (which enables buffer donation on GPUs). This requires PyTorch to be installed.
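A minimal sketch of that workaround, not a tested recipe. It assumes the loader in question is `safetensors.torch.load_file` with its `device` argument, and a JAX version whose `jax.dlpack.from_dlpack` accepts objects implementing `__dlpack__` (older releases needed an explicit capsule from `torch.utils.dlpack.to_dlpack`). The helper name `load_to_jax_via_torch` is hypothetical.

```python
def load_to_jax_via_torch(path, device="cpu"):
    """Load a safetensors file with PyTorch and hand each tensor to JAX via DLPack.

    Hypothetical helper: `path` points to a .safetensors file and `device`
    is a PyTorch device string such as "cpu" or "cuda:0".
    """
    # Imported lazily because PyTorch is only needed for this workaround.
    import jax.dlpack
    from safetensors.torch import load_file

    # The PyTorch loader places the tensors on `device` when it reads them.
    torch_tensors = load_file(path, device=device)

    # DLPack export expects a contiguous buffer. from_dlpack then wraps that
    # buffer as a JAX array, normally without an extra copy.
    return {
        name: jax.dlpack.from_dlpack(tensor.contiguous())
        for name, tensor in torch_tensors.items()
    }
```

With `device="cuda:0"` this needs CUDA builds of both PyTorch and JAX that see the same GPU. Whether the hand-off is truly zero-copy depends on the backend and the installed versions, so it is worth checking memory usage before relying on it.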