Feature request
It seems loading straight to a device is currently only supported for PyTorch. It would be great to have support for JAX as well.
Motivation
Loading arrays into JAX currently goes through NumPy, which forces the arrays into CPU memory and then creates a copy to load into device memory. This introduces additional overhead.
Your contribution
As a temporary workaround, I think one can load the data using PyTorch and then donate the buffers to a JAX array using the DLPack protocol (which enables buffer donation on GPUs). This requires PyTorch to be installed.
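A rough sketch of what I mean, not a tested recipe. The helper name `torch_to_jax` is made up for illustration, and it assumes a recent JAX release in which `jax.dlpack.from_dlpack` accepts any object implementing `__dlpack__` (older releases needed a capsule from `torch.utils.dlpack.to_dlpack` instead). It also assumes the tensor was already loaded onto the target device by PyTorch and that JAX has a backend for that same device.

```python
def torch_to_jax(tensor):
    """Hypothetical helper: hand a PyTorch tensor's buffer to JAX via DLPack."""
    # Imported lazily: this path is only a workaround, so JAX's DLPack module
    # is pulled in only when the helper is actually called.
    import jax.dlpack

    # detach() drops autograd tracking (tensors that require grad cannot be
    # exported); contiguous() is a no-op when the layout is already dense.
    return jax.dlpack.from_dlpack(tensor.detach().contiguous())
```

Usage would be something like `params = {k: torch_to_jax(v) for k, v in state_dict.items()}`, where `state_dict` is whatever mapping of tensors PyTorch loaded onto the device. As far as I understand, the resulting JAX array may share memory with the original tensor, so the PyTorch side should not be modified in place afterwards.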