-
Notifications
You must be signed in to change notification settings - Fork 234
/
Copy pathsimple_http_cudashm_client.cc
324 lines (278 loc) · 11 KB
/
simple_http_cudashm_client.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime_api.h>
#include <unistd.h>
#include <iostream>
#include <string>
#include "http_client.h"
#include "shm_utils.h"
namespace tc = triton::client;
// Evaluate X (an expression yielding tc::Error). If the error is not OK,
// print "error: <MSG>: <error>" to stderr and exit the process with status 1.
//
// The do { ... } while (false) wrapper makes `FAIL_IF_ERR(...);` a single
// statement, so it is safe inside an unbraced if/else. The local has an
// unusual name so it cannot collide with a variable referenced by X.
#define FAIL_IF_ERR(X, MSG)                                          \
  do {                                                               \
    tc::Error fail_if_err_status_ = (X);                             \
    if (!fail_if_err_status_.IsOk()) {                               \
      std::cerr << "error: " << (MSG) << ": " << fail_if_err_status_ \
                << std::endl;                                        \
      exit(1);                                                       \
    }                                                                \
  } while (false)
namespace {
void
ValidateShapeAndDatatype(
const std::string& name, std::shared_ptr<tc::InferResult> result)
{
std::vector<int64_t> shape;
FAIL_IF_ERR(
result->Shape(name, &shape), "unable to get shape for '" + name + "'");
// Validate shape
if ((shape.size() != 2) || (shape[0] != 1) || (shape[1] != 16)) {
std::cerr << "error: received incorrect shapes for '" << name << "'"
<< std::endl;
exit(1);
}
std::string datatype;
FAIL_IF_ERR(
result->Datatype(name, &datatype),
"unable to get datatype for '" + name + "'");
// Validate datatype
if (datatype.compare("INT32") != 0) {
std::cerr << "error: received incorrect datatype for '" << name
<< "': " << datatype << std::endl;
exit(1);
}
}
void
Usage(char** argv, const std::string& msg = std::string())
{
if (!msg.empty()) {
std::cerr << "error: " << msg << std::endl;
}
std::cerr << "Usage: " << argv[0] << " [options]" << std::endl;
std::cerr << "\t-v" << std::endl;
std::cerr << "\t-u <URL for inference service>" << std::endl;
std::cerr << "\t-H <HTTP header>" << std::endl;
std::cerr << std::endl;
std::cerr
<< "For -H, header must be 'Header:Value'. May be given multiple times."
<< std::endl;
exit(1);
}
} // namespace
// Evaluate FUNC (an expression yielding cudaError_t). If it is not
// cudaSuccess, print the CUDA error name and description to stderr along
// with __LINE__ at the point of use, then exit the process with status 1.
//
// The do { ... } while (false) wrapper makes `FAIL_IF_CUDA_ERR(...);` a
// single statement. The unusual local name prevents FUNC from accidentally
// referring to (and self-initializing from) the macro's own temporary, which
// happened with the previous name `result` whenever a caller variable
// `result` appeared in FUNC.
#define FAIL_IF_CUDA_ERR(FUNC)                                         \
  do {                                                                 \
    const cudaError_t fail_if_cuda_err_status_ = (FUNC);               \
    if (fail_if_cuda_err_status_ != cudaSuccess) {                     \
      std::cerr << "CUDA exception (line " << __LINE__ << "): "        \
                << cudaGetErrorName(fail_if_cuda_err_status_) << " ("  \
                << cudaGetErrorString(fail_if_cuda_err_status_) << ")" \
                << std::endl;                                          \
      exit(1);                                                         \
    }                                                                  \
  } while (false)
// Fill `*cuda_handle` with a CUDA IPC memory handle for the device buffer
// `input_d_ptr`, after making GPU `device_id` (default 0) the current
// device. Any CUDA failure terminates the process via FAIL_IF_CUDA_ERR.
//
// NOTE(review): cudaSetDevice changes the current device as a side effect
// that persists after this function returns.
void
CreateCUDAIPCHandle(
    cudaIpcMemHandle_t* cuda_handle, void* input_d_ptr, int device_id = 0)
{
  // Step 1: make the requested GPU current.
  FAIL_IF_CUDA_ERR(cudaSetDevice(device_id));

  // Step 2: export the buffer as an IPC handle; main() later hands this
  // handle to RegisterCudaSharedMemory.
  FAIL_IF_CUDA_ERR(cudaIpcGetMemHandle(cuda_handle, input_d_ptr));
}
int
main(int argc, char** argv)
{
bool verbose = false;
std::string url("localhost:8000");
tc::Headers http_headers;
// Parse commandline...
int opt;
while ((opt = getopt(argc, argv, "vu:H:")) != -1) {
switch (opt) {
case 'v':
verbose = true;
break;
case 'u':
url = optarg;
break;
case 'H': {
std::string arg = optarg;
std::string header = arg.substr(0, arg.find(":"));
http_headers[header] = arg.substr(header.size() + 1);
break;
}
case '?':
Usage(argv);
break;
}
}
// We use a simple model that takes 2 input tensors of 16 integers
// each and returns 2 output tensors of 16 integers each. One output
// tensor is the element-wise sum of the inputs and one output is
// the element-wise difference.
std::string model_name = "simple";
std::string model_version = "";
// Create a InferenceServerHttpClient instance to communicate with the
// server using http protocol.
std::unique_ptr<tc::InferenceServerHttpClient> client;
FAIL_IF_ERR(
tc::InferenceServerHttpClient::Create(&client, url, verbose),
"unable to create http client");
// Unregistering all shared memory regions for a clean
// start.
FAIL_IF_ERR(
client->UnregisterSystemSharedMemory(),
"unable to unregister all system shared memory regions");
FAIL_IF_ERR(
client->UnregisterCudaSharedMemory(),
"unable to unregister all cuda shared memory regions");
std::vector<int64_t> shape{1, 16};
size_t input_byte_size = 64;
size_t output_byte_size = 64;
// Initialize the inputs with the data.
tc::InferInput* input0;
tc::InferInput* input1;
FAIL_IF_ERR(
tc::InferInput::Create(&input0, "INPUT0", shape, "INT32"),
"unable to get INPUT0");
std::shared_ptr<tc::InferInput> input0_ptr;
input0_ptr.reset(input0);
FAIL_IF_ERR(
tc::InferInput::Create(&input1, "INPUT1", shape, "INT32"),
"unable to get INPUT1");
std::shared_ptr<tc::InferInput> input1_ptr;
input1_ptr.reset(input1);
// Create Input0 and Input1 in CUDA Shared Memory. Initialize Input0 to
// unique integers and Input1 to all ones.
int input_data[32];
for (size_t i = 0; i < 16; ++i) {
input_data[i] = i;
input_data[16 + i] = 1;
}
// copy INPUT0 and INPUT1 data in GPU shared memory
int* input_d_ptr;
cudaMalloc((void**)&input_d_ptr, input_byte_size * 2);
cudaMemcpy(
(void*)input_d_ptr, (void*)input_data, input_byte_size * 2,
cudaMemcpyHostToDevice);
cudaIpcMemHandle_t input_cuda_handle;
CreateCUDAIPCHandle(&input_cuda_handle, (void*)input_d_ptr);
FAIL_IF_ERR(
client->RegisterCudaSharedMemory(
"input_data", input_cuda_handle, 0 /* device_id */,
input_byte_size * 2),
"failed to register input shared memory region");
FAIL_IF_ERR(
input0_ptr->SetSharedMemory(
"input_data", input_byte_size, 0 /* offset */),
"unable to set shared memory for INPUT0");
FAIL_IF_ERR(
input1_ptr->SetSharedMemory(
"input_data", input_byte_size, input_byte_size /* offset */),
"unable to set shared memory for INPUT1");
// Generate the outputs to be requested.
tc::InferRequestedOutput* output0;
tc::InferRequestedOutput* output1;
FAIL_IF_ERR(
tc::InferRequestedOutput::Create(&output0, "OUTPUT0"),
"unable to get 'OUTPUT0'");
std::shared_ptr<tc::InferRequestedOutput> output0_ptr;
output0_ptr.reset(output0);
FAIL_IF_ERR(
tc::InferRequestedOutput::Create(&output1, "OUTPUT1"),
"unable to get 'OUTPUT1'");
std::shared_ptr<tc::InferRequestedOutput> output1_ptr;
output1_ptr.reset(output1);
// Create Output0 and Output1 in CUDA Shared Memory
int *output0_d_ptr, *output1_d_ptr;
cudaMalloc((void**)&output0_d_ptr, output_byte_size * 2);
output1_d_ptr = (int*)output0_d_ptr + 16;
cudaIpcMemHandle_t output_cuda_handle;
CreateCUDAIPCHandle(&output_cuda_handle, (void*)output0_d_ptr);
FAIL_IF_ERR(
client->RegisterCudaSharedMemory(
"output_data", output_cuda_handle, 0 /* device_id */,
output_byte_size * 2),
"failed to register output shared memory region");
FAIL_IF_ERR(
output0_ptr->SetSharedMemory(
"output_data", output_byte_size, 0 /* offset */),
"unable to set shared memory for 'OUTPUT0'");
FAIL_IF_ERR(
output1_ptr->SetSharedMemory(
"output_data", output_byte_size, output_byte_size /* offset */),
"unable to set shared memory for 'OUTPUT1'");
// The inference settings. Will be using default for now.
tc::InferOptions options(model_name);
options.model_version_ = model_version;
std::vector<tc::InferInput*> inputs = {input0_ptr.get(), input1_ptr.get()};
std::vector<const tc::InferRequestedOutput*> outputs = {
output0_ptr.get(), output1_ptr.get()};
tc::InferResult* results;
FAIL_IF_ERR(
client->Infer(&results, options, inputs, outputs, http_headers),
"unable to run model");
std::shared_ptr<tc::InferResult> results_ptr;
results_ptr.reset(results);
// Validate the results...
ValidateShapeAndDatatype("OUTPUT0", results_ptr);
ValidateShapeAndDatatype("OUTPUT1", results_ptr);
// Copy input and output data back to the CPU
int output0_data[16], output1_data[16];
cudaMemcpy(
output0_data, output0_d_ptr, output_byte_size, cudaMemcpyDeviceToHost);
cudaMemcpy(
output1_data, output1_d_ptr, output_byte_size, cudaMemcpyDeviceToHost);
for (size_t i = 0; i < 16; ++i) {
std::cout << input_data[i] << " + " << input_data[16 + i] << " = "
<< output0_data[i] << std::endl;
std::cout << input_data[i] << " + " << input_data[16 + i] << " = "
<< output1_data[i] << std::endl;
if ((input_data[i] + input_data[16 + i]) != output0_data[i]) {
std::cerr << "error: incorrect sum" << std::endl;
exit(1);
}
if ((input_data[i] - input_data[16 + i]) != output1_data[i]) {
std::cerr << "error: incorrect difference" << std::endl;
exit(1);
}
}
// Get shared memory regions active/registered within triton
std::string shm_status;
FAIL_IF_ERR(
client->CudaSharedMemoryStatus(&shm_status),
"failed to get shared memory status");
std::cout << "Shared Memory Status:\n" << shm_status << "\n";
// Unregister shared memory
FAIL_IF_ERR(
client->UnregisterCudaSharedMemory("input_data"),
"unable to unregister shared memory input region");
FAIL_IF_ERR(
client->UnregisterCudaSharedMemory("output_data"),
"unable to unregister shared memory output region");
// Free GPU memory
FAIL_IF_CUDA_ERR(cudaFree(input_d_ptr));
FAIL_IF_CUDA_ERR(cudaFree(output0_d_ptr));
std::cout << "PASS : Cuda Shared Memory " << std::endl;
return 0;
}