forked from dmlc/dlpack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdlpack.h
More file actions
627 lines (601 loc) · 22.3 KB
/
dlpack.h
File metadata and controls
627 lines (601 loc) · 22.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
/*!
* Copyright (c) 2017 - by Contributors
* \file dlpack.h
* \brief The common header of DLPack.
*/
#ifndef DLPACK_DLPACK_H_
#define DLPACK_DLPACK_H_
/*!
 * \brief Linkage specifier for declarations shared between C and C++.
 *
 * Expands to `extern "C"` when compiled as C++ and to nothing when
 * compiled as C.
 */
#if defined(__cplusplus)
#define DLPACK_EXTERN_C extern "C"
#else
#define DLPACK_EXTERN_C
#endif

/*! \brief Major version of DLPack described by this header. */
#define DLPACK_MAJOR_VERSION 1

/*! \brief Minor version of DLPack described by this header. */
#define DLPACK_MINOR_VERSION 1

/*!
 * \brief Symbol import/export prefix.
 *
 * On Windows this is `__declspec(dllexport)` when DLPACK_EXPORTS is defined
 * and `__declspec(dllimport)` otherwise. On every other platform it is empty.
 */
#if !defined(_WIN32)
#define DLPACK_DLL
#elif defined(DLPACK_EXPORTS)
#define DLPACK_DLL __declspec(dllexport)
#else
#define DLPACK_DLL __declspec(dllimport)
#endif
#include <stdint.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
#endif
/*!
 * \brief Version stamp carried by DLPack exchange structures.
 *
 * The major version changes when the data layout of the ABI
 * (DLManagedTensorVersioned) changes.
 *
 * The minor version changes when new code is added, such as a new device
 * type, while the ABI stays the same.
 *
 * If a received DLPack tensor reports a major version different from the one
 * in this header (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call
 * the deleter (and it is safe to do so). No other field may be accessed,
 * because the memory layout will have changed.
 *
 * With only a minor version mismatch the tensor can be used safely as long
 * as the consumer knows how to interpret all fields. Minor version updates
 * indicate the addition of enumeration values.
 */
typedef struct {
  uint32_t major; /*!< DLPack major version. */
  uint32_t minor; /*!< DLPack minor version. */
} DLPackVersion;
/*!
 * \brief Device kinds that a DLDevice can name.
 *
 * When compiled as C++ the underlying type is fixed to int32_t; when
 * compiled as C the enum is declared without an explicit underlying type.
 * Values 5 and 6 are not assigned in this header.
 */
#if defined(__cplusplus)
typedef enum : int32_t {
#else
typedef enum {
#endif
  kDLCPU = 1,      /*!< CPU device. */
  kDLCUDA = 2,     /*!< CUDA GPU device. */
  kDLCUDAHost = 3, /*!< Pinned CUDA CPU memory by cudaMallocHost. */
  kDLOpenCL = 4,   /*!< OpenCL devices. */
  kDLVulkan = 7,   /*!< Vulkan buffer for next generation graphics. */
  kDLMetal = 8,    /*!< Metal for Apple GPU. */
  kDLVPI = 9,      /*!< Verilog simulator buffer. */
  kDLROCM = 10,    /*!< ROCm GPUs for AMD GPUs. */
  kDLROCMHost = 11, /*!< Pinned ROCm CPU memory allocated by hipMallocHost. */
  /*!
   * \brief Reserved extension device type, used to quickly test an
   *        extension device. The semantics can differ depending on the
   *        implementation.
   */
  kDLExtDev = 12,
  /*! \brief CUDA managed/unified memory allocated by cudaMallocManaged. */
  kDLCUDAManaged = 13,
  /*!
   * \brief Unified shared memory allocated on a oneAPI non-partitioned
   *        device. A call to the oneAPI runtime is required to determine
   *        the device type, the USM allocation type and the sycl context it
   *        is bound to.
   */
  kDLOneAPI = 14,
  kDLWebGPU = 15,  /*!< GPU support for next generation WebGPU standard. */
  kDLHexagon = 16, /*!< Qualcomm Hexagon DSP. */
  kDLMAIA = 17,    /*!< Microsoft MAIA devices. */
  kDLTrn = 18,     /*!< AWS Trainium. */
} DLDeviceType;
/*!
 * \brief Identifies the device a tensor or operator lives on.
 */
typedef struct {
  DLDeviceType device_type; /*!< The device type used in the device. */
  /*!
   * \brief The device index.
   *
   * Set to 0 for vanilla CPU memory, pinned memory, or managed memory.
   */
  int32_t device_id;
} DLDevice;
/*!
 * \brief Type codes that may appear in DLDataType::code.
 */
typedef enum {
  kDLInt = 0U,   /*!< Signed integer. */
  kDLUInt = 1U,  /*!< Unsigned integer. */
  kDLFloat = 2U, /*!< IEEE floating point. */
  /*!
   * \brief Opaque handle type, reserved for testing purposes.
   *
   * Frameworks need to agree on the handle data type for the exchange to be
   * well-defined.
   */
  kDLOpaqueHandle = 3U,
  kDLBfloat = 4U, /*!< bfloat16. */
  /*!
   * \brief Complex number
   *        (C/C++/Python layout: compact struct per complex number).
   */
  kDLComplex = 5U,
  kDLBool = 6U, /*!< Boolean. */
  /*! \brief FP8 data types (this and the seven enumerators that follow). */
  kDLFloat8_e3m4 = 7U,
  kDLFloat8_e4m3 = 8U,
  kDLFloat8_e4m3b11fnuz = 9U,
  kDLFloat8_e4m3fn = 10U,
  kDLFloat8_e4m3fnuz = 11U,
  kDLFloat8_e5m2 = 12U,
  kDLFloat8_e5m2fnuz = 13U,
  kDLFloat8_e8m0fnu = 14U,
  /*!
   * \brief FP6 data types (this and the next enumerator).
   *
   * Setting bits != 6 is currently unspecified: the producer must ensure
   * bits is set to 6, and the consumer must stop importing if the value is
   * unexpected.
   */
  kDLFloat6_e2m3fn = 15U,
  kDLFloat6_e3m2fn = 16U,
  /*!
   * \brief FP4 data types.
   *
   * Setting bits != 4 is currently unspecified: the producer must ensure
   * bits is set to 4, and the consumer must stop importing if the value is
   * unexpected.
   */
  kDLFloat4_e2m1fn = 17U,
} DLDataTypeCode;
/*!
 * \brief The data type the tensor can hold. The data type is assumed to follow the
 * native endian-ness. An explicit error message should be raised when attempting to
 * export an array with non-native endianness.
 *
 * Examples
 *  - float: type_code = 2, bits = 32, lanes = 1
 *  - float4 (vectorized 4 float): type_code = 2, bits = 32, lanes = 4
 *  - int8: type_code = 0, bits = 8, lanes = 1
 *  - std::complex<float>: type_code = 5, bits = 64, lanes = 1
 *  - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library
 *    convention, the underlying storage size of bool is 8 bits)
 *  - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
 *  - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
 *  - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
 *
 * When a sub-byte type is packed, DLPack requires the data to be in little
 * bit-endian, i.e., for a packed data set D, the i-th element is stored in
 * ((D >> (i * bits)) & bit_mask). Note the bitwise AND: the mask selects the
 * element's bits after the shift.
 */
typedef struct {
  /*!
   * \brief Type code of base types.
   *
   * Kept as uint8_t instead of DLDataTypeCode for minimal memory footprint,
   * but the value should be one of the DLDataTypeCode enum values.
   */
  uint8_t code;
  /*! \brief Number of bits, common choices are 8, 16, 32. */
  uint8_t bits;
  /*! \brief Number of lanes in the type, used for vector types. */
  uint16_t lanes;
} DLDataType;
/*!
 * \brief Plain C tensor descriptor; it does not manage memory.
 */
typedef struct {
  /*!
   * \brief The data pointer points to the allocated data. This will be a CUDA
   * device pointer or a cl_mem handle in OpenCL. It may be opaque on some
   * device types. This pointer is always aligned to 256 bytes as in CUDA. The
   * `byte_offset` field should be used to point to the beginning of the data.
   *
   * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
   * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
   * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed
   * (after which this note will be updated); at the moment it is recommended
   * to not rely on the data pointer being correctly aligned.
   *
   * For a given DLTensor, the size of memory required to store the contents
   * of data is calculated as follows:
   *
   * \code{.c}
   * static inline size_t GetDataSize(const DLTensor* t) {
   *   size_t size = 1;
   *   for (int32_t i = 0; i < t->ndim; ++i) {
   *     size *= t->shape[i];
   *   }
   *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
   *   return size;
   * }
   * \endcode
   *
   * Note that if the tensor is of size zero, then the data pointer should be
   * set to `NULL`.
   */
  void *data;
  DLDevice device;  /*!< The device of the tensor. */
  int32_t ndim;     /*!< Number of dimensions. */
  DLDataType dtype; /*!< The data type of the pointer. */
  int64_t *shape;   /*!< The shape of the tensor. */
  /*!
   * \brief Strides of the tensor (in number of elements, not bytes).
   *
   * Can be NULL, indicating the tensor is compact and row-majored.
   */
  int64_t *strides;
  /*! \brief The offset in bytes to the beginning pointer to data. */
  uint64_t byte_offset;
} DLTensor;
/*!
 * \brief C tensor object that manages the memory of a DLTensor.
 *
 * This data structure is intended to facilitate the borrowing of a DLTensor
 * by another framework. It is not meant to transfer the tensor. When the
 * borrowing framework no longer needs the tensor, it should call the deleter
 * to notify the host that the resource is no longer needed.
 *
 * \note This data structure is used as the legacy DLManagedTensor in DLPack
 *       exchange and is deprecated after DLPack v0.8.
 *       Use DLManagedTensorVersioned instead.
 *       This data structure may get renamed or deleted in future versions.
 *
 * \sa DLManagedTensorVersioned
 */
typedef struct DLManagedTensor {
  /*! \brief DLTensor which is being memory managed. */
  DLTensor dl_tensor;
  /*!
   * \brief The context of the original host framework in which this
   *        DLManagedTensor is used. It can also be NULL.
   */
  void *manager_ctx;
  /*!
   * \brief Destructor.
   *
   * This should be called to destruct the manager_ctx which backs the
   * DLManagedTensor. It can be NULL if there is no way for the caller to
   * provide a reasonable destructor. The destructor deletes the argument
   * self as well.
   */
  void (*deleter)(struct DLManagedTensor *self);
} DLManagedTensor;
/* Bit masks for the `flags` field of DLManagedTensorVersioned. */

/*! \brief Bit mask indicating that the tensor is read only. */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)

/*!
 * \brief Bit mask indicating that the tensor is a copy made by the producer.
 *
 * If set, the tensor is considered solely owned throughout its lifetime by
 * the consumer, until the producer-provided deleter is invoked.
 */
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)

/*!
 * \brief Bit mask indicating whether a sub-byte type is packed or padded.
 *
 * Sub-byte types (ex: fp4/fp6) are assumed packed by default. The producer
 * can set this flag to signal that a tensor of sub-byte type is padded.
 */
#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
/*!
 * \brief A versioned and managed C tensor object; manages the memory of a
 *        DLTensor.
 *
 * This data structure is intended to facilitate the borrowing of a DLTensor
 * by another framework. It is not meant to transfer the tensor. When the
 * borrowing framework no longer needs the tensor, it should call the deleter
 * to notify the host that the resource is no longer needed.
 *
 * \note This is the current standard DLPack exchange data structure.
 */
typedef struct DLManagedTensorVersioned {
  /*! \brief The API and ABI version of the current managed tensor. */
  DLPackVersion version;
  /*!
   * \brief The context of the original host framework.
   *
   * Stores the context in which this DLManagedTensorVersioned is used in the
   * framework. It can also be NULL.
   */
  void *manager_ctx;
  /*!
   * \brief Destructor.
   *
   * This should be called to destruct the manager_ctx which holds the
   * DLManagedTensorVersioned. It can be NULL if there is no way for the
   * caller to provide a reasonable destructor. The destructor deletes the
   * argument self as well.
   */
  void (*deleter)(struct DLManagedTensorVersioned *self);
  /*!
   * \brief Additional bitmask flags information about the tensor.
   *
   * By default the flags should be set to 0.
   *
   * \note Future ABI changes should keep everything up to and including this
   *       field stable, to ensure that the deleter can be correctly called.
   *
   * \sa DLPACK_FLAG_BITMASK_READ_ONLY
   * \sa DLPACK_FLAG_BITMASK_IS_COPIED
   * \sa DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED
   */
  uint64_t flags;
  /*! \brief DLTensor which is being memory managed. */
  DLTensor dl_tensor;
} DLManagedTensorVersioned;
//----------------------------------------------------------------------
// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions
//----------------------------------------------------------------------
/*!
 * \brief Request a producer library to create a new tensor.
 *
 * Create a new `DLManagedTensorVersioned` within the context of the producer
 * library. The allocation is defined via the prototype DLTensor.
 *
 * This function is exposed by the framework through the DLPackExchangeAPI.
 *
 * \param prototype The prototype DLTensor. Only the dtype, ndim, shape,
 *        and device fields are used.
 * \param out Output location that receives the owning
 *        DLManagedTensorVersioned*.
 * \param error_ctx Context for `SetError`.
 * \param SetError The function to set the error.
 * \return An int status code; the newly created tensor is delivered through
 *         `out`, not through the return value. SetError is called exactly
 *         when the call fails (the implementor must ensure this).
 *         NOTE(review): the sibling exchange functions document
 *         "0 on success, -1 on failure"; presumably the same convention
 *         applies here - confirm.
 * \note - As a C function, it must not throw C++ exceptions.
 *       - Errors are propagated via SetError to avoid any direct need of the
 *         Python API. Because of this, `SetError` may have to ensure the GIL
 *         is held, since it will presumably set a Python error.
 *
 * \sa DLPackExchangeAPI
 */
typedef int (*DLPackManagedTensorAllocator)(                                //
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,   //
    void (*SetError)(void* error_ctx, const char* kind, const char* message)  //
);
/*!
 * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
 *
 * This function does not perform any stream synchronization. The consumer
 * should query DLPackCurrentWorkStream to get the current work stream and
 * launch kernels on it.
 *
 * This function is exposed by the framework through the DLPackExchangeAPI.
 *
 * \param py_object The Python object to convert. Must have the same type
 *        as the one the `DLPackExchangeAPI` was discovered from.
 * \param out Output location that receives the owning
 *        DLManagedTensorVersioned*.
 * \return An int status code; the exported tensor is delivered through
 *         `out`, not through the return value. On failure a Python exception
 *         is set. If the data cannot be described using DLPack this should
 *         be a BufferError if possible.
 *         NOTE(review): the sibling exchange functions document
 *         "0 on success, -1 on failure"; presumably the same convention
 *         applies here - confirm.
 * \note - As a C function, it must not throw C++ exceptions.
 *
 * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
 */
typedef int (*DLPackManagedTensorFromPyObjectNoSync)(  //
    void* py_object,                                   //
    DLManagedTensorVersioned** out                     //
);
/*!
 * \brief Exports a PyObject* Tensor/NDArray into a caller-provided DLTensor.
 *
 * This is a faster interface for temporary, non-owning exchange. The
 * producer (implementor) still owns the memory of data, strides and shape.
 * The liveness of the DLTensor and of the data it views is only guaranteed
 * until control is returned.
 *
 * This function currently assumes that the producer (implementor) can fill
 * in the DLTensor shape and strides without the need for temporary
 * allocations.
 *
 * This function does not perform any stream synchronization. The consumer
 * should query DLPackCurrentWorkStream to get the current work stream and
 * launch kernels on it.
 *
 * This function is exposed by the framework through the DLPackExchangeAPI.
 *
 * \param py_object The Python object to convert. Must have the same type
 *        as the one the `DLPackExchangeAPI` was discovered from.
 * \param out The output DLTensor, whose space is pre-allocated on the stack.
 * \return 0 on success, -1 on failure with a Python exception set.
 * \note - As a C function, it must not throw C++ exceptions.
 *
 * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
 */
typedef int (*DLPackDLTensorFromPyObjectNoSync)(void* py_object, DLTensor* out);
/*!
 * \brief Obtain the current work stream of a device.
 *
 * Queries the producer framework for the current work stream of a device.
 * For example, it should map to torch.cuda.current_stream in PyTorch.
 *
 * When device_type is kDLCPU, the consumer does not have to query the stream
 * and the producer can simply return NULL when queried. The consumer does
 * not have to do anything about stream synchronization or stream setting.
 * A CPU-only framework can therefore provide a dummy implementation that
 * always sets out_current_stream[0] to NULL.
 *
 * \param device_type The device type.
 * \param device_id The device id.
 * \param out_current_stream The output current work stream.
 *
 * \return 0 on success, -1 on failure with a Python exception set.
 * \note - As a C function, it must not throw C++ exceptions.
 *
 * \sa DLPackExchangeAPI
 */
typedef int (*DLPackCurrentWorkStream)(DLDeviceType device_type,
                                       int32_t device_id,
                                       void** out_current_stream);
/*!
 * \brief Imports a DLManagedTensorVersioned into a PyObject* Tensor/NDArray.
 *
 * Converts an owning DLManagedTensorVersioned* to the Python tensor of the
 * producer (implementor) library with the correct type.
 *
 * This function does not perform any stream synchronization.
 *
 * This function is exposed by the framework through the DLPackExchangeAPI.
 *
 * \param tensor The DLManagedTensorVersioned to convert; the ownership of
 *        the tensor is stolen.
 * \param out_py_object The output Python object.
 * \return 0 on success, -1 on failure with a Python exception set.
 *
 * \sa DLPackExchangeAPI
 */
typedef int (*DLPackManagedTensorToPyObjectNoSync)(DLManagedTensorVersioned* tensor,
                                                   void** out_py_object);
/*!
 * \brief Stable header placed at the start of DLPackExchangeAPI.
 * \sa DLPackExchangeAPI
 */
typedef struct DLPackExchangeAPIHeader {
  /*!
   * \brief The provided DLPack version.
   *
   * The consumer must check major version compatibility before using this
   * struct.
   */
  DLPackVersion version;
  /*!
   * \brief Optional pointer to an older DLPackExchangeAPI in the chain.
   *
   * It must be NULL if the framework does not support older versions.
   * If the current major version is larger than the one supported by the
   * consumer, the consumer may walk this chain to find an earlier supported
   * version.
   *
   * \sa DLPackExchangeAPI
   */
  struct DLPackExchangeAPIHeader *prev_version_api;
} DLPackExchangeAPIHeader;
/*!
 * \brief Framework-specific function pointer table for DLPack exchange.
 *
 * In addition to `__dlpack__()` we define a C function table sharable by
 * Python implementations via `__c_dlpack_exchange_api__`.
 * This attribute must be set on the type as a Python integer compatible
 * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`.
 *
 * A consumer library may use a pattern such as:
 *
 * \code
 *
 *  PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__;  // as C-code
 *  MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj);
 *  if (api == NULL && PyErr_Occurred()) { goto handle_error; }
 *
 * \endcode
 *
 * Note that this must be defined on the type. The consumer should look up the
 * attribute on the type and may cache the result for each unique type.
 *
 * The precise API table is given by:
 * \code
 * struct MyDLPackExchangeAPI : public DLPackExchangeAPI {
 *   MyDLPackExchangeAPI() {
 *     header.version.major = DLPACK_MAJOR_VERSION;
 *     header.version.minor = DLPACK_MINOR_VERSION;
 *     header.prev_version_api = nullptr;
 *
 *     managed_tensor_allocator = MyDLPackManagedTensorAllocator;
 *     managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync;
 *     managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync;
 *     dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync;
 *     current_work_stream = MyDLPackCurrentWorkStream;
 *   }
 *
 *   static const DLPackExchangeAPI* Global() {
 *     static MyDLPackExchangeAPI inst;
 *     return &inst;
 *   }
 * };
 * \endcode
 *
 * Guidelines for leveraging DLPackExchangeAPI:
 *
 * There are generally two kinds of consumer needs for DLPack exchange:
 * - N0: library support, where consumer.kernel(x, y, z) would like to run a
 *   kernel with the data from x, y, z. The consumer is also expected to run
 *   the kernel with the same stream context as the producer. For example,
 *   when x, y, z are torch.Tensor, the consumer should query
 *   exchange_api->current_work_stream to get the current stream and launch
 *   the kernel with the same stream.
 *   This setup is necessary for no synchronization in kernel launch and
 *   maximum compatibility with CUDA graph capture in the producer.
 *   This is the desirable behavior for library extension support for
 *   frameworks like PyTorch.
 * - N1: data ingestion and retention.
 *
 * Note that the obj.__dlpack__() API should provide useful ways for N1.
 * The primary focus of the current DLPackExchangeAPI is to enable faster
 * exchange for N0 with the support of the function pointer
 * current_work_stream.
 *
 * Array/Tensor libraries should statically create and initialize this
 * structure, then return a pointer to DLPackExchangeAPI as an int value in
 * Tensor/Array. The DLPackExchangeAPI* must stay alive throughout the
 * lifetime of the process.
 *
 * One simple way to do so is to create a static instance of
 * DLPackExchangeAPI within the framework and return a pointer to it. The
 * `MyDLPackExchangeAPI` snippet above shows an example of doing so in C++.
 * It should also be reasonably easy to do so in other languages.
 */
typedef struct DLPackExchangeAPI {
  /*! \brief The header that remains stable across versions. */
  DLPackExchangeAPIHeader header;
  /*!
   * \brief Producer function pointer for DLPackManagedTensorAllocator.
   *        This function must not be NULL.
   * \sa DLPackManagedTensorAllocator
   */
  DLPackManagedTensorAllocator managed_tensor_allocator;
  /*!
   * \brief Producer function pointer for
   *        DLPackManagedTensorFromPyObjectNoSync.
   *        This function must not be NULL.
   * \sa DLPackManagedTensorFromPyObjectNoSync
   */
  DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync;
  /*!
   * \brief Producer function pointer for
   *        DLPackManagedTensorToPyObjectNoSync.
   *        This function must not be NULL.
   * \sa DLPackManagedTensorToPyObjectNoSync
   */
  DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
  /*!
   * \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync.
   *        This function can be NULL when the producer does not support
   *        this function.
   * \sa DLPackDLTensorFromPyObjectNoSync
   */
  DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync;
  /*!
   * \brief Producer function pointer for DLPackCurrentWorkStream.
   *        This function must not be NULL.
   * \sa DLPackCurrentWorkStream
   */
  DLPackCurrentWorkStream current_work_stream;
} DLPackExchangeAPI;
#ifdef __cplusplus
} // DLPACK_EXTERN_C
#endif
#endif // DLPACK_DLPACK_H_