forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_tvm_ffi.py
More file actions
156 lines (136 loc) · 6.54 KB
/
test_tvm_ffi.py
File metadata and controls
156 lines (136 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import platform
import unittest
from typing import TYPE_CHECKING
import numpy as np
import tvm_ffi.cpp
import paddle
if TYPE_CHECKING:
from tvm_ffi import Module
class TestTVMFFIEnvStream(unittest.TestCase):
    """Tests for the ``__tvm_ffi_env_stream__`` protocol on paddle tensors."""

    def test_tvm_ffi_env_stream_for_gpu_tensor(self):
        """A CUDA tensor exposes its current raw stream as a non-zero int."""
        if not paddle.is_compiled_with_cuda():
            # skipTest (rather than a bare return) so the runner reports the
            # test as skipped instead of silently passing on CPU-only builds.
            self.skipTest("paddle is not compiled with CUDA")
        tensor = paddle.to_tensor([1.0, 2.0, 3.0]).cuda()
        current_raw_stream_ptr = tensor.__tvm_ffi_env_stream__()
        # The stream handle is surfaced as a plain integer address.
        self.assertIsInstance(current_raw_stream_ptr, int)
        self.assertNotEqual(current_raw_stream_ptr, 0)

    def test_tvm_ffi_env_stream_for_cpu_tensor(self):
        """Calling ``__tvm_ffi_env_stream__`` on a CPU tensor raises."""
        tensor = paddle.to_tensor([1.0, 2.0, 3.0]).cpu()
        with self.assertRaisesRegex(
            RuntimeError, r"the __tvm_ffi_env_stream__ method"
        ):
            tensor.__tvm_ffi_env_stream__()
class TestCDLPackExchangeAPI(unittest.TestCase):
    """Tests for exchanging paddle tensors with TVM FFI via the C DLPack API."""

    def test_c_dlpack_exchange_api_cpu(self):
        """A JIT-compiled C++ function can read/write paddle CPU tensors."""
        cpp_source = r"""
        void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
          for (int i = 0; i < x->shape[0]; ++i) {
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
        }
        """
        # Use a list for `functions` to match the other tests in this class.
        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
        )
        x = paddle.full((3,), 1.0, dtype='float32').cpu()
        y = paddle.zeros((3,), dtype='float32').cpu()
        mod.add_one_cpu(x, y)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])

    def test_c_dlpack_exchange_api_gpu(self):
        """A JIT-compiled CUDA kernel launches on the tensor's env stream."""
        # skipTest (rather than a bare return) so unmet preconditions are
        # reported as skips instead of silently passing.
        if not paddle.is_compiled_with_cuda():
            self.skipTest("paddle is not compiled with CUDA")
        if paddle.is_compiled_with_rocm():
            self.skipTest(
                "skipped on DCU because CUDA_HOME is not available"
            )
        if platform.system() == "Windows":
            self.skipTest(
                "temporarily skipped on Windows due to a TVM FFI compile bug"
            )
        cpp_sources = r"""
        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
        """
        cuda_sources = r"""
        __global__ void AddOneKernel(float* x, float* y, int n) {
          int idx = blockIdx.x * blockDim.x + threadIdx.x;
          if (idx < n) {
            y[idx] = x[idx] + 1;
          }
        }

        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
          int64_t n = x->shape[0];
          int64_t nthread_per_block = 256;
          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
          // Obtain the current stream from the environment by calling TVMFFIEnvGetStream
          cudaStream_t stream = static_cast<cudaStream_t>(
              TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
          // launch the kernel
          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                                 static_cast<float*>(y->data), n);
        }
        """
        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod',
            cpp_sources=cpp_sources,
            cuda_sources=cuda_sources,
            functions=['add_one_cuda'],
        )
        x = paddle.full((3,), 1.0, dtype='float32').cuda()
        y = paddle.zeros((3,), dtype='float32').cuda()
        mod.add_one_cuda(x, y)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])

    def test_c_dlpack_exchange_api_alloc_tensor(self):
        """C++ code can allocate a result tensor via TVMFFIEnvGetTensorAllocator."""
        if platform.system() == "Windows":
            # Returning an owned tensor created by TVMFFIEnvGetTensorAllocator
            # causes a double-free on Windows; report as a skip, not a pass.
            self.skipTest(
                "temporarily skipped on Windows: owned tensor from "
                "TVMFFIEnvGetTensorAllocator causes a double free"
            )
        cpp_source = r"""
        inline tvm::ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) {
          return tvm::ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device);
        }

        tvm::ffi::Tensor add_one_cpu(tvm::ffi::TensorView x) {
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          tvm::ffi::Shape x_shape(x->shape, x->shape + x->ndim);
          tvm::ffi::Tensor y = alloc_tensor(x_shape, f32_dtype, x->device);
          for (int i = 0; i < x->shape[0]; ++i) {
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
          return y;
        }
        """
        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
        )
        x = paddle.full((3,), 1.0, dtype='float32').cpu()
        y = mod.add_one_cpu(x)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
# Allow running this test file directly: `python test_tvm_ffi.py`.
if __name__ == '__main__':
    unittest.main()