forked from apache/tvm-ffi
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.pxi
More file actions
458 lines (385 loc) · 15.9 KB
/
base.pxi
File metadata and controls
458 lines (385 loc) · 15.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import ctypes
from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, int16_t
from libc.string cimport memcpy
from libcpp.vector cimport vector
from cpython.bytes cimport PyBytes_AsStringAndSize, PyBytes_FromStringAndSize, PyBytes_AsString
from cpython cimport Py_INCREF, Py_DECREF
from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release, PyObject
from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone
cdef extern from "dlpack/dlpack.h":
    # Version macros exported by dlpack.h.
    int DLPACK_MAJOR_VERSION
    int DLPACK_MINOR_VERSION

    # DLPack device-type codes. Because this is an extern block, the generated C
    # code refers to these names directly, so the C header's values are the
    # authoritative ones; the numbers here are for the reader.
    cdef enum:
        kDLCPU = 1,
        kDLCUDA = 2,
        kDLCUDAHost = 3,
        kDLOpenCL = 4,
        kDLVulkan = 7,
        kDLMetal = 8,
        kDLVPI = 9,
        kDLROCM = 10,
        kDLROCMHost = 11,
        kDLExtDev = 12,
        kDLCUDAManaged = 13,
        kDLOneAPI = 14,
        kDLWebGPU = 15,
        kDLHexagon = 16,
        kDLMAIA = 17
        kDLTrn = 18

    # Element type descriptor: type code, bit width and lane count.
    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        int16_t lanes

    # (device_type, device_id) pair identifying where tensor memory lives.
    ctypedef struct DLDevice:
        int device_type
        int device_id

    # Non-owning tensor description: data pointer, device, rank, dtype,
    # shape/strides arrays and a byte offset into `data`.
    ctypedef struct DLTensor:
        void* data
        DLDevice device
        int ndim
        DLDataType dtype
        int64_t* shape
        int64_t* strides
        uint64_t byte_offset

    # major/minor version pair carried by versioned DLPack structures.
    ctypedef struct DLPackVersion:
        uint32_t major
        uint32_t minor

    # Legacy (unversioned) managed tensor: a DLTensor plus an opaque context
    # and a deleter callback that receives the managed tensor itself.
    ctypedef struct DLManagedTensor:
        DLTensor dl_tensor
        void* manager_ctx
        void (*deleter)(DLManagedTensor* self)

    # Versioned managed tensor: adds a DLPackVersion header and a flags word
    # compared with DLManagedTensor.
    ctypedef struct DLManagedTensorVersioned:
        DLPackVersion version
        DLTensor dl_tensor
        void* manager_ctx
        void (*deleter)(DLManagedTensorVersioned* self)
        uint64_t flags

    # DLPack Exchange API function pointer types. All are declared noexcept,
    # so Cython does not check for a pending Python exception after calling
    # them; errors are reported through the int return value (and, for the
    # allocator, through the SetError callback).
    ctypedef int (*DLPackManagedTensorAllocator)(
        DLTensor* prototype,
        DLManagedTensorVersioned** out,
        void* error_ctx,
        void (*SetError)(void* error_ctx, const char* kind, const char* message)
    ) noexcept
    ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
        void* py_object,
        DLManagedTensorVersioned** out
    ) noexcept
    ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
        DLManagedTensorVersioned* tensor,
        void** out_py_object
    ) noexcept
    ctypedef int (*DLPackCurrentWorkStream)(
        int device_type,
        int32_t device_id,
        void** out_current_stream
    ) noexcept
    ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
        void* py_object,
        DLTensor* out
    ) noexcept

    # Header of an exchange-API table: its version and a pointer to another
    # header (presumably a previous API version in a chain -- confirm in dlpack.h).
    ctypedef struct DLPackExchangeAPIHeader:
        DLPackVersion version
        DLPackExchangeAPIHeader* prev_api

    # Table of exchange-API entry points published by a tensor framework.
    ctypedef struct DLPackExchangeAPI:
        DLPackExchangeAPIHeader header
        DLPackManagedTensorAllocator managed_tensor_allocator
        DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
        DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
        DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
        DLPackCurrentWorkStream current_work_stream
# Cython binding for the TVM FFI C API (tvm/ffi/c_api.h).
cdef extern from "tvm/ffi/c_api.h":
    # Built-in type indices. kTVMFFIStaticObjectBegin marks where the static
    # object type indices start; it shares the value 64 with kTVMFFIObject.
    cdef enum TVMFFITypeIndex:
        kTVMFFIAny = -1
        kTVMFFINone = 0
        kTVMFFIInt = 1
        kTVMFFIBool = 2
        kTVMFFIFloat = 3
        kTVMFFIOpaquePtr = 4
        kTVMFFIDataType = 5
        kTVMFFIDevice = 6
        kTVMFFIDLTensorPtr = 7
        kTVMFFIRawStr = 8
        kTVMFFIByteArrayPtr = 9
        kTVMFFIObjectRValueRef = 10
        kTVMFFISmallStr = 11
        kTVMFFISmallBytes = 12
        kTVMFFIStaticObjectBegin = 64
        kTVMFFIObject = 64
        kTVMFFIStr = 65
        kTVMFFIBytes = 66
        kTVMFFIError = 67
        kTVMFFIFunction = 68
        kTVMFFIShape = 69
        kTVMFFITensor = 70
        kTVMFFIArray = 71
        kTVMFFIMap = 72
        kTVMFFIModule = 73
        kTVMFFIOpaquePyObject = 74

    # Opaque pointer used for every FFI object handle.
    ctypedef void* TVMFFIObjectHandle

    # Object header as seen from Cython: a reference-count word, the type
    # index, padding, and a deleter callback receiving the object itself.
    ctypedef struct TVMFFIObject:
        uint64_t combined_ref_count
        int32_t type_index
        uint32_t __padding
        void (*deleter)(TVMFFIObject* self)

    # Tagged value: type_index selects which v_* member is meaningful.
    # NOTE(review): in the C header the v_* members are likely a union; they are
    # declared flat here because Cython only needs the member names -- confirm.
    ctypedef struct TVMFFIAny:
        int32_t type_index
        int32_t zero_padding
        int64_t v_int64
        double v_float64
        void* v_ptr
        TVMFFIObject* v_obj
        const char* v_c_str
        DLDataType v_dtype
        DLDevice v_device

    # (data, size) pair describing a byte range.
    ctypedef struct TVMFFIByteArray:
        const char* data
        size_t size

    # Payload of an opaque object: a single raw handle.
    ctypedef struct TVMFFIOpaqueObjectCell:
        void* handle

    # Payload of a shape object: `size` int64 extents at `data`.
    ctypedef struct TVMFFIShapeCell:
        const int64_t* data
        size_t size

    # Mode argument for TVMFFIErrorCell.update_backtrace.
    ctypedef enum TVMFFIBacktraceUpdateMode:
        kTVMFFIBacktraceUpdateModeReplace = 0
        kTVMFFIBacktraceUpdateModeAppend = 1

    # Payload of an error object: kind, message and backtrace text, plus a
    # callback to update the backtrace (update_mode is a
    # TVMFFIBacktraceUpdateMode value, passed as int32_t).
    ctypedef struct TVMFFIErrorCell:
        TVMFFIByteArray kind
        TVMFFIByteArray message
        TVMFFIByteArray backtrace
        void (*update_backtrace)(
            TVMFFIObjectHandle self, const TVMFFIByteArray* backtrace, int32_t update_mode
        )

    # C calling convention for packed functions: (handle, args, num_args, result).
    ctypedef int (*TVMFFISafeCallType)(
        void* handle, const TVMFFIAny* args, int32_t num_args,
        TVMFFIAny* result) noexcept

    # Bit flags stored in TVMFFIFieldInfo.flags / TVMFFIMethodInfo.flags.
    cdef enum TVMFFIFieldFlagBitMask:
        kTVMFFIFieldFlagBitMaskWritable = 1 << 0
        kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
        kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2

    # Reflection callbacks: read/write a field at a raw address, create an object.
    ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
    ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept
    ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept

    # Reflection metadata for one field of a registered type.
    ctypedef struct TVMFFIFieldInfo:
        TVMFFIByteArray name
        TVMFFIByteArray doc
        TVMFFIByteArray metadata
        int64_t flags
        int64_t size
        int64_t alignment
        int64_t offset
        TVMFFIFieldGetter getter
        TVMFFIFieldSetter setter
        TVMFFIAny default_value
        int32_t field_static_type_index

    # Reflection metadata for one method; `method` holds the callable as a TVMFFIAny.
    ctypedef struct TVMFFIMethodInfo:
        TVMFFIByteArray name
        TVMFFIByteArray doc
        TVMFFIByteArray metadata
        int64_t flags
        TVMFFIAny method

    # Extra per-type metadata: docstring, creator callback and total size.
    ctypedef struct TVMFFITypeMetadata:
        TVMFFIByteArray doc
        TVMFFIObjectCreator creator
        int64_t total_size

    # Runtime type descriptor returned by TVMFFIGetTypeInfo.
    ctypedef struct TVMFFITypeInfo:
        int32_t type_index
        int32_t type_depth
        TVMFFIByteArray type_key
        const TVMFFITypeInfo** type_ancestors
        uint64_t type_key_hash
        int32_t num_fields
        int32_t num_methods
        const TVMFFIFieldInfo* fields
        const TVMFFIMethodInfo* methods
        const TVMFFITypeMetadata* metadata

    # All C API entry points below are declared nogil, so they may be called
    # while the GIL is released. The int-returning ones presumably return 0 on
    # success and are checked by the caller (e.g. via CHECK_CALL) -- confirm.

    # Object reference counting, creation and type queries.
    int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
    int TVMFFIObjectIncRef(TVMFFIObjectHandle obj) nogil
    int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index,
                                 void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
    int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil

    # Function objects: call, create from a safe-call pointer, global registry.
    int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args,
                           TVMFFIAny* result) nogil
    int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call,
                             void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
    int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) nogil
    int TVMFFIFunctionSetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) nogil
    int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil

    # Error propagation between the C side and the caller.
    void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil
    void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil
    int TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message,
                          TVMFFIByteArray* backtrace, TVMFFIObjectHandle* out) nogil

    # Type-key lookup and string / bytes / dtype conversions.
    int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil
    int TVMFFIStringFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil
    int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil
    int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
    int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil
    const TVMFFIByteArray* TVMFFIBacktrace(const char* filename, int lineno,
                                           const char* func, int cross_ffi_boundary) nogil

    # DLPack interop for tensor objects (legacy and versioned structs).
    int TVMFFITensorFromDLPack(DLManagedTensor* src, int32_t require_alignment,
                               int32_t require_contiguous, TVMFFIObjectHandle* out) nogil
    int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* src,
                                        int32_t require_alignment,
                                        int32_t require_contiguous,
                                        TVMFFIObjectHandle* out) nogil
    int TVMFFITensorToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil
    int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle src,
                                      DLManagedTensorVersioned** out) nogil

    # Accessors returning type info or pointers to an object's payload cell.
    const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
    TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil
    TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
    TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
    TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) nogil
    TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
    DLTensor* TVMFFITensorGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
    DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil
cdef extern from "tvm/ffi/extra/c_env_api.h":
    # Opaque stream handle (e.g. a device stream pointer), passed as void*.
    ctypedef void* TVMFFIStreamHandle
    # Register a named C function pointer with the FFI environment.
    int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
    # Query / replace the environment's current stream for a device;
    # TVMFFIEnvSetStream can report the previous stream through its last argument.
    void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil
    int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
                           TVMFFIStreamHandle* opt_out_original_stream) nogil
def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
    """Set the environment stream for one device and return the old one.

    Parameters
    ----------
    device_type : int
        Device type code forwarded to TVMFFIEnvSetStream.
    device_id : int
        Device ordinal forwarded to TVMFFIEnvSetStream.
    stream : int
        Raw stream handle, encoded as an unsigned 64-bit integer.

    Returns
    -------
    original_stream : int
        The stream TVMFFIEnvSetStream reported through
        ``opt_out_original_stream``, as an integer (0 for NULL).
    """
    cdef TVMFFIStreamHandle original = NULL
    cdef TVMFFIStreamHandle new_stream = <void*>stream
    # A non-zero status is turned into a Python error by CHECK_CALL.
    CHECK_CALL(TVMFFIEnvSetStream(device_type, device_id, new_stream, &original))
    return <uint64_t>original
def _env_get_current_stream(int device_type, int device_id):
    """Return the environment's current stream for a device as an integer.

    The handle reported by TVMFFIEnvGetStream for
    ``(device_type, device_id)`` is cast to an unsigned 64-bit integer;
    a NULL handle becomes 0.
    """
    return <uint64_t>TVMFFIEnvGetStream(device_type, device_id)
cdef extern from "tvm_ffi_python_helpers.h":
    # Only the members this module actually uses are declared below; Cython
    # does not need the full C layout of an extern struct.

    # DLPack conversion callbacks. `except -1` tells Cython that a return value
    # of -1 means a Python exception has been set and must be propagated.
    ctypedef int (*DLPackFromPyObject)(
        void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
    ) except -1
    ctypedef int (*DLPackToPyObject)(
        DLManagedTensorVersioned* tensor, void** py_obj_out
    ) except -1
    ctypedef int (*DLPackTensorAllocator)(
        DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
        void (*SetError)(void* error_ctx, const char* kind, const char* message)
    ) except -1

    # Per-call state: target device, stream, and an optional DLPack exchange API.
    ctypedef struct TVMFFIPyCallContext:
        int device_type
        int device_id
        TVMFFIStreamHandle stream
        const DLPackExchangeAPI* c_dlpack_exchange_api

    # Converter for one Python argument: `func` presumably writes the
    # TVMFFIAny for `py_arg` into `out` using the call context -- confirm.
    ctypedef struct TVMFFIPyArgSetter:
        int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1
        const DLPackExchangeAPI* c_dlpack_exchange_api

    # Factory that fills in a TVMFFIPyArgSetter for a given Python value.
    ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1

    # Main call entry: converts the Python argument tuple with setter_factory and
    # invokes the function handle `chandle`. The C API status code is reported
    # separately through c_api_ret_code.
    int TVMFFIPyFuncCall(
        TVMFFIPyArgSetterFactory setter_factory,
        void* chandle,
        PyObject* py_arg_tuple,
        TVMFFIAny* result,
        int* c_api_ret_code,
        int release_gil,
        DLPackToPyObject* out_dlpack_importer
    ) except -1

    # Constructor variant of the call entry, taking a parent call context.
    int TVMFFIPyConstructorCall(
        TVMFFIPyArgSetterFactory setter_factory,
        void* chandle,
        PyObject* py_arg_tuple,
        TVMFFIAny* result,
        int* c_api_ret_code,
        TVMFFIPyCallContext* parent_ctx
    ) except -1

    # Converts one Python value and hands it to a reflection field setter.
    int TVMFFIPyCallFieldSetter(
        TVMFFIPyArgSetterFactory setter_factory,
        TVMFFIFieldSetter field_setter,
        void* field_ptr,
        PyObject* py_arg,
        int* c_api_ret_code
    ) except -1

    # Converts one Python value into a TVMFFIAny.
    int TVMFFIPyPyObjectToFFIAny(
        TVMFFIPyArgSetterFactory setter_factory,
        PyObject* py_arg,
        TVMFFIAny* out,
        int* c_api_ret_code
    ) except -1

    # noexcept helpers: Cython does not check for Python errors after these.
    size_t TVMFFIPyGetDispatchMapSize() noexcept
    # Attach a temporary FFI / Python object to the call context
    # (presumably released when the call finishes -- confirm in the header).
    void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept
    void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept

    # Predefined argument setters for common POD Python types.
    int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
    int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
    int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
    int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
cdef class ByteArrayArg:
    """Expose a Python ``bytes`` object as a ``TVMFFIByteArray`` view.

    ``cdata`` points into the buffer of the stored ``py_data`` object, so the
    instance keeps that object alive for as long as the view is used.
    ``bytearray`` input is first copied into ``bytes``, because
    PyBytes_AsStringAndSize only accepts ``bytes``.
    """
    cdef TVMFFIByteArray cdata
    cdef object py_data

    def __cinit__(self, py_data):
        cdef char* buf
        cdef Py_ssize_t length
        owned = bytes(py_data) if isinstance(py_data, bytearray) else py_data
        # Store the owner first: cdata will borrow its internal buffer.
        self.py_data = owned
        PyBytes_AsStringAndSize(owned, &buf, &length)
        self.cdata.data = buf
        self.cdata.size = length

    cdef inline TVMFFIByteArray* cptr(self):
        """Return a pointer to the internal byte-array view."""
        return &self.cdata
cdef inline py_str(const char* x):
    """Decode a NUL-terminated C string as UTF-8.

    Parameters
    ----------
    x : const char*
        NUL-terminated C string, e.g. one returned by the C API.

    Returns
    -------
    str
        The decoded Python string.
    """
    cdef bytes raw = x
    return raw.decode("utf-8")
cdef inline str bytearray_to_str(const TVMFFIByteArray* x):
    """Decode the ``x.size`` bytes starting at ``x.data`` as UTF-8."""
    cdef bytes raw = PyBytes_FromStringAndSize(x.data, x.size)
    return raw.decode("utf-8")
cdef inline bytes bytearray_to_bytes(const TVMFFIByteArray* x):
    """Copy the ``x.size`` bytes starting at ``x.data`` into a new ``bytes``."""
    cdef bytes copied = PyBytes_FromStringAndSize(x.data, x.size)
    return copied
cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes for passing to the C API.

    Parameters
    ----------
    pystr : str
        Python string to encode.

    Returns
    -------
    bytes
        UTF-8 encoded copy of ``pystr``. Cython coerces this object to a
        ``const char*`` argument; the pointer is only valid while the returned
        object is alive.
    """
    return pystr.encode("utf-8")
cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p``.

    The pointer is first converted to an integer address, which ctypes.cast
    turns into a ``c_void_p`` object.
    """
    cdef unsigned long long address = <unsigned long long>chandle
    return ctypes.cast(address, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
    """Extract the raw C pointer stored in a ctypes handle.

    Parameters
    ----------
    handle : ctypes.c_void_p
        Handle whose ``.value`` holds the address (e.g. one produced by
        ``ctypes_handle``).

    Returns
    -------
    void*
        The address held by ``handle``, or NULL when ``handle.value`` is None.
    """
    cdef unsigned long long v_ptr
    value = handle.value
    if value is None:
        # ctypes reports a NULL c_void_p as value None; converting None to an
        # integer would raise TypeError, so map it back to NULL explicitly.
        return NULL
    v_ptr = value
    return <void*>(v_ptr)
cdef _init_env_api():
    """Register CPython helpers with the TVM FFI environment.

    PyErr_CheckSignals is exported so signal handling can be triggered from
    the C side. It is called there with the GIL released, so PyGILState_Ensure
    and PyGILState_Release are exported as well, letting the caller re-acquire
    the GIL around it.
    """
    # Signal polling entry point.
    CHECK_CALL(TVMFFIEnvRegisterCAPI(b"PyErr_CheckSignals", <void*>PyErr_CheckSignals))

    # GIL acquire/release pair used around the signal check.
    CHECK_CALL(TVMFFIEnvRegisterCAPI(b"PyGILState_Ensure", <void*>PyGILState_Ensure))
    CHECK_CALL(TVMFFIEnvRegisterCAPI(b"PyGILState_Release", <void*>PyGILState_Release))


# Performed once at module import time.
_init_env_api()