Skip to content

Commit 1bbce15

Browse files
Kathryn-catroot
authored and committed
[DLPack] C Functions for DLPack Speed Exchange and Stream Handling (pytorch#165483)
## Addressed Issue Issue pytorch#162845 ## Summary of Changes This PR introduces a unified `DLPackExchangeAPI` struct as described in proposal [175](dmlc/dlpack#175). This new convention replaces the previous mechanism of separate function pointers, and aligns with the latest DLPack standard as shown in PR [174](dmlc/dlpack#174). Specifically, the new `DLPackExchangeAPI` struct is exposed as `torch.Tensor.__c_dlpack_exchange_api__`, which stores and exposes the following function pointers: * `managed_tensor_allocator` * `managed_tensor_from_py_object_no_sync` * `managed_tensor_to_py_object_no_sync` * `dltensor_from_py_object_no_sync` * `current_work_stream` Within the new `DLPackExchangeAPI` struct, the new `current_work_stream` function pointer allows more robust and integrated querying of the current device stream (e.g., CUDA stream) during DLPack tensor exchanges. All the conversion from/to DLPack has been updated to `_no_sync`, meaning you should use `current_work_stream` to explicitly handle stream synchronization. It also includes a non-owning DLTensor conversion `dltensor_from_py_object_no_sync` to avoid unnecessary reference counting. Following this change, the `dlpack.h` has been updated to the latest DLPack. Unit tests are added using `torch.utils.cpp_extension.load_inline` to avoid GIL release issues when calling `THPVariable_Wrap`. Pull Request resolved: pytorch#165483 Approved by: https://github.com/tqchen, https://github.com/albanD
1 parent 3089c80 commit 1bbce15

File tree

7 files changed

+666
-12
lines changed

7 files changed

+666
-12
lines changed

aten/src/ATen/DLConvertor.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ DLDevice torchDeviceToDLDevice(at::Device device) {
152152
return ctx;
153153
}
154154

155-
static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) {
155+
Device dlDeviceToTorchDevice(
156+
DLDeviceType type,
157+
c10::DeviceIndex index,
158+
void* data) {
156159
switch (type) {
157160
case DLDeviceType::kDLCPU:
158161
return at::Device(DeviceType::CPU);
@@ -437,7 +440,8 @@ at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
437440
}
438441

439442
DLTensor& dl_tensor = src->dl_tensor;
440-
Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
443+
Device device = dlDeviceToTorchDevice(
444+
dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
441445
ScalarType stype = toScalarType(dl_tensor.dtype);
442446

443447
if (!dl_tensor.strides) {
@@ -465,6 +469,21 @@ template at::Tensor fromDLPackImpl<DLManagedTensorVersioned>(DLManagedTensorVers
465469

466470
} // namespace
467471

472+
// Fills a caller-provided DLTensor with a non-owning view of `src`.
// No reference count is taken: the caller owns the tensor and must keep
// it alive (and its shape/strides unchanged) for as long as `out` is used.
void toDLPackNonOwning(const Tensor& src, DLTensor* out) {
  out->data = src.data_ptr();
  out->device = torchDeviceToDLDevice(src.device());
  out->dtype = getDLDataType(src);
  out->ndim = static_cast<int32_t>(src.dim());
  // sizes()/strides() return pointers into the TensorImpl's stable
  // storage, which remains valid while the tensor is alive, so they can
  // be exposed directly without copying.
  auto* shape_ptr = const_cast<int64_t*>(src.sizes().data());
  auto* stride_ptr = const_cast<int64_t*>(src.strides().data());
  out->shape = shape_ptr;
  out->strides = stride_ptr;
  out->byte_offset = 0;
}
486+
468487
// Converts `src` into a DLPack managed tensor using the unversioned
// `DLManagedTensor` ABI, delegating to the shared templated implementation.
DLManagedTensor* toDLPack(const Tensor& src) {
  auto* managed = toDLPackImpl<DLManagedTensor>(src);
  return managed;
}
@@ -489,7 +508,7 @@ Tensor maybeCopyTensor(
489508
bool force_move = copy.has_value() && !*copy;
490509

491510
if (optional_dl_device.has_value()) {
492-
auto device = at::getATenDevice(
511+
auto device = at::dlDeviceToTorchDevice(
493512
optional_dl_device->device_type,
494513
static_cast<c10::DeviceIndex>(optional_dl_device->device_id));
495514

aten/src/ATen/DLConvertor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ namespace at {
1313
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
1414
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
1515
TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
16+
TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out);
1617
TORCH_API Tensor
1718
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
1819
TORCH_API Tensor fromDLPackVersioned(
@@ -31,6 +32,12 @@ TORCH_API Tensor maybeCopyTensor(
3132
// Converts the given at::Device into a DLDevice.
3233
TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
3334

35+
// Converts the DLDevice to an ATen device.
36+
TORCH_API Device dlDeviceToTorchDevice(
37+
DLDeviceType type,
38+
c10::DeviceIndex index,
39+
void* data = nullptr);
40+
3441
// This trait class is used for retrieving different attributes, such as the
3542
// PyCapsule names and conversion functions for both DLPack tensor classes:
3643
// `DLManagedTensor` and `DLManagedTensorVersioned`.

0 commit comments

Comments
 (0)