Feature description
Support for consuming and exporting DLPack (`DLManagedTensor`) tensors. This would allow zero-copy tensor sharing between Burn and other frameworks.
https://dmlc.github.io/dlpack/latest/
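For reference, the `DLManagedTensor` layout is fixed by the DLPack C header (`dlpack.h`). A `#[repr(C)]` mirror in Rust would look roughly like the sketch below; the field names and enum values follow the DLPack spec, but the Rust declarations themselves are illustrative rather than taken from an existing crate.

```rust
use std::ffi::c_void;

// #[repr(C)] mirrors of the structs defined in DLPack's dlpack.h.
#[repr(C)]
pub struct DLDevice {
    /// DLDeviceType: kDLCPU = 1, kDLCUDA = 2, kDLVulkan = 7, kDLMetal = 8, kDLROCM = 10, ...
    pub device_type: i32,
    pub device_id: i32,
}

#[repr(C)]
pub struct DLDataType {
    /// DLDataTypeCode: kDLInt = 0, kDLUInt = 1, kDLFloat = 2, kDLBfloat = 4
    pub code: u8,
    pub bits: u8,
    pub lanes: u16,
}

#[repr(C)]
pub struct DLTensor {
    pub data: *mut c_void,
    pub device: DLDevice,
    pub ndim: i32,
    pub dtype: DLDataType,
    pub shape: *mut i64,
    /// May be null, which means a compact row-major layout.
    pub strides: *mut i64,
    pub byte_offset: u64,
}

#[repr(C)]
pub struct DLManagedTensor {
    pub dl_tensor: DLTensor,
    /// Opaque handle owned by the producer; passed back to `deleter`.
    pub manager_ctx: *mut c_void,
    /// Called by the consumer once it is done with the tensor.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
```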
Feature motivation
We're building libraries (https://metatensor.org/) that bridge multiple frameworks (PyTorch, JAX, and hopefully Burn via our Rust core). DLPack is the standard for zero-copy sharing.
Right now, moving a JAX tensor to Burn would require a jax -> numpy (cpu) -> burn (gpu) copy, or a jax -> torch -> dlpack -> metatensor-torch flow, which is clumsy.
Direct DLPack support in Burn would let us (and others) build backends that can operate on JAX, PyTorch, or CuPy memory in-place, which is critical for performance.
(Optional) Suggest a Solution
This will be unsafe and highly backend-specific.
- A new constructor on `burn-tensor` (or on backend-specific tensor primitives) that can be initialized from a `DLManagedTensor` pointer.
- This implementation would have to (a rough consumer sketch follows this list):
  - Read the `DLManagedTensor` struct (device type, device id, pointer, shape, strides).
  - Find the corresponding Burn `Device`.
  - The hard part: use `unsafe` functions to wrap the existing native GPU buffer.
- For `burn-wgpu`, this means using the underlying `wgpu` instance to import a native Vulkan `VkBuffer`, Metal `MTLBuffer`, etc. This seems related to the work discussed in `wgpu` for underlying API interop: Proposal for Underlying Api Interoperability gfx-rs/wgpu#4067
- An equivalent `to_dlpack()` method would be needed to export a Burn tensor, packaging its own buffer handle into a `DLManagedTensor` struct with a valid deleter (see the export sketch below).
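To make the consuming side concrete, here is a rough sketch of the metadata-reading step, assuming the `DLManagedTensor` bindings sketched in the feature description above. `DlpackView` and `inspect_dlpack` are illustrative placeholders, not proposed Burn API; the actual wrapping of the native buffer is the backend-specific, unsafe part that would happen afterwards.

```rust
use std::ffi::c_void;
use std::slice;

/// Parsed view of a `DLManagedTensor`: everything a backend-specific importer
/// needs before it can (unsafely) wrap the native buffer.
pub struct DlpackView {
    pub data: *mut c_void,
    pub shape: Vec<i64>,
    /// `None` means the producer declared a compact row-major layout.
    pub strides: Option<Vec<i64>>,
    pub device_type: i32,
    pub device_id: i32,
}

/// Safety: `ptr` must point to a valid `DLManagedTensor` whose `deleter`
/// has not been called yet.
pub unsafe fn inspect_dlpack(ptr: *mut DLManagedTensor) -> DlpackView {
    let dl = &(*ptr).dl_tensor;
    let ndim = dl.ndim as usize;

    let shape = slice::from_raw_parts(dl.shape, ndim).to_vec();
    let strides = (!dl.strides.is_null())
        .then(|| slice::from_raw_parts(dl.strides, ndim).to_vec());

    DlpackView {
        // `byte_offset` must be applied before handing the pointer to a backend.
        data: (dl.data as *mut u8).add(dl.byte_offset as usize) as *mut c_void,
        shape,
        strides,
        device_type: dl.device.device_type,
        device_id: dl.device.device_id,
    }
}
```

A `from_dlpack` constructor would then match on `device_type` (kDLCPU, kDLCUDA, kDLVulkan, ...) to pick the Burn `Device` and hand the pointer to the matching backend; the wgpu case is where the external-buffer import mentioned above comes in.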
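For the export direction, here is a minimal sketch of the deleter contract, again with illustrative names (`ExportCtx`, `to_dlpack_cpu`) and with the CPU/f32 case hard-coded. A real implementation would also stash whatever keeps the backend buffer alive inside the context so the memory is not freed while the consumer still uses it.

```rust
use std::ffi::c_void;

/// Everything the exported tensor must keep alive until the consumer calls
/// `deleter`. A real implementation would also hold a reference-counted handle
/// to the backend storage here.
struct ExportCtx {
    shape: Vec<i64>,
    strides: Vec<i64>,
}

unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
    // Reclaim what `to_dlpack_cpu` leaked: first the context, then the managed
    // tensor itself.
    drop(Box::from_raw((*managed).manager_ctx as *mut ExportCtx));
    drop(Box::from_raw(managed));
}

/// Package an f32 CPU buffer into a `DLManagedTensor` owned by the caller.
pub fn to_dlpack_cpu(data: *mut c_void, shape: Vec<i64>, strides: Vec<i64>) -> *mut DLManagedTensor {
    let mut ctx = Box::new(ExportCtx { shape, strides });
    let dl_tensor = DLTensor {
        data,
        device: DLDevice { device_type: 1, device_id: 0 }, // kDLCPU
        ndim: ctx.shape.len() as i32,
        dtype: DLDataType { code: 2, bits: 32, lanes: 1 }, // kDLFloat, 32 bits
        shape: ctx.shape.as_mut_ptr(),
        strides: ctx.strides.as_mut_ptr(),
        byte_offset: 0,
    };
    Box::into_raw(Box::new(DLManagedTensor {
        dl_tensor,
        manager_ctx: Box::into_raw(ctx) as *mut c_void,
        deleter: Some(dlpack_deleter),
    }))
}
```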