@@ -39,9 +39,8 @@ static int init = 0;
39
39
static jmp_buf catch_segfault;
40
40
static void *faulting_address = nullptr ;
41
41
42
- conn_t *rpc_client_get_connection (unsigned int index) { return &conns[index ]; }
43
-
44
42
std::map<conn_t *, std::map<void *, size_t >> unified_devices;
43
+ std::map<void *, void *> host_funcs;
45
44
46
45
static void segfault (int sig, siginfo_t *info, void *unused) {
47
46
faulting_address = info->si_addr ;
@@ -152,6 +151,98 @@ void rpc_close(conn_t *conn) {
152
151
pthread_mutex_unlock (&conn_mutex);
153
152
}
154
153
154
+ typedef void (*func_t )(void *);
155
+
156
+ void add_host_node (void *fn, void *udata) { host_funcs[fn] = udata; }
157
+
158
+ void invoke_host_func (void *fn) {
159
+ for (const auto &pair : host_funcs) {
160
+ if (pair.first == fn) {
161
+ func_t func = reinterpret_cast <func_t >(pair.first );
162
+ std::cout << " Invoking function at: " << pair.first << std::endl;
163
+ func (pair.second );
164
+ return ;
165
+ }
166
+ }
167
+ }
168
+
169
+ void *rpc_client_dispatch_thread (void *arg) {
170
+ conn_t *conn = (conn_t *)arg;
171
+ int op;
172
+
173
+ while (true ) {
174
+ op = rpc_dispatch (conn, 1 );
175
+
176
+ if (op == 1 ) {
177
+ std::cout << " Transferring memory..." << std::endl;
178
+
179
+ int found = 0 ;
180
+
181
+ rpc_read (conn, &found, sizeof (int ));
182
+
183
+ if (found > 0 ) {
184
+ void *host_data = nullptr ;
185
+ void *dst = nullptr ;
186
+ const void *src = nullptr ;
187
+ size_t count = 0 ;
188
+ cudaError_t result;
189
+ int request_id;
190
+
191
+ if (rpc_read (conn, &dst, sizeof (void *)) < 0 ||
192
+ rpc_read (conn, &count, sizeof (size_t )) < 0 ) {
193
+ std::cerr << " Failed to read transfer parameters." << std::endl;
194
+ break ;
195
+ }
196
+
197
+ host_data = malloc (count);
198
+ if (!host_data) {
199
+ std::cerr << " Memory allocation failed." << std::endl;
200
+ break ;
201
+ }
202
+
203
+ // Read the actual data from the server (sent from `src` in device
204
+ // memory)
205
+ if (rpc_read (conn, host_data, count) < 0 ) {
206
+ std::cerr << " Failed to read device data from server." << std::endl;
207
+ free (host_data);
208
+ break ;
209
+ }
210
+
211
+ // Copy received data to the destination (dst) on the host
212
+ memcpy (dst, host_data, count);
213
+ }
214
+
215
+ void *temp_mem;
216
+ if (rpc_read (conn, &temp_mem, sizeof (void *)) <= 0 ) {
217
+ std::cerr << " rpc_read failed for mem. Closing connection."
218
+ << std::endl;
219
+ break ;
220
+ }
221
+
222
+ int request_id = rpc_read_end (conn);
223
+ void *mem = temp_mem;
224
+
225
+ if (mem == nullptr ) {
226
+ std::cerr << " Invalid function pointer!" << std::endl;
227
+ continue ;
228
+ }
229
+
230
+ invoke_host_func (mem);
231
+
232
+ void *res = nullptr ;
233
+ if (rpc_write_start_response (conn, request_id) < 0 ||
234
+ rpc_write (conn, &res, sizeof (void *)) < 0 ||
235
+ rpc_write_end (conn) < 0 ) {
236
+ std::cerr << " rpc_write failed. Closing connection." << std::endl;
237
+ break ;
238
+ }
239
+ }
240
+ }
241
+
242
+ std::cerr << " Exiting dispatch thread due to an error." << std::endl;
243
+ return nullptr ;
244
+ }
245
+
155
246
int rpc_open () {
156
247
set_segfault_handlers ();
157
248
@@ -166,6 +257,8 @@ int rpc_open() {
166
257
return 0 ;
167
258
}
168
259
260
+ std::cout << " Opening connection to server" << std::endl;
261
+
169
262
char *server_ips = getenv (" SCUDA_SERVER" );
170
263
if (server_ips == NULL ) {
171
264
printf (" SCUDA_SERVER environment variable not set\n " );
@@ -214,19 +307,13 @@ int rpc_open() {
214
307
exit (1 );
215
308
}
216
309
217
- std::cout << " connected on " << sockfd << std::endl;
218
-
219
- conns[nconns] = {sockfd,
220
- 0 ,
221
- 0 ,
222
- 0 ,
223
- 0 ,
224
- 0 ,
225
- PTHREAD_MUTEX_INITIALIZER,
226
- PTHREAD_MUTEX_INITIALIZER,
227
- PTHREAD_COND_INITIALIZER};
310
+ conns[nconns] = {sockfd, 0 };
311
+ if (pthread_mutex_init (&conns[nconns].read_mutex , NULL ) < 0 ||
312
+ pthread_mutex_init (&conns[nconns].write_mutex , NULL ) < 0 ) {
313
+ return -1 ;
314
+ }
228
315
229
- pthread_create (&conns[nconns].read_thread , NULL , rpc_read_thread ,
316
+ pthread_create (&conns[nconns].read_thread , NULL , rpc_client_dispatch_thread ,
230
317
(void *)&conns[nconns]);
231
318
232
319
nconns++;
@@ -239,6 +326,12 @@ int rpc_open() {
239
326
return 0 ;
240
327
}
241
328
329
+ conn_t *rpc_client_get_connection (unsigned int index) {
330
+ if (rpc_open () < 0 )
331
+ return nullptr ;
332
+ return &conns[index ];
333
+ }
334
+
242
335
int rpc_size () { return nconns; }
243
336
244
337
void allocate_unified_mem_pointer (conn_t *conn, void *dev_ptr, size_t size) {
@@ -325,6 +418,8 @@ CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, int cudaVersion,
325
418
}
326
419
327
420
void *dlsym (void *handle, const char *name) __THROW {
421
+ std::cout << " dlsym: " << name << std::endl;
422
+
328
423
void *func = get_function_pointer (name);
329
424
330
425
/* * proc address function calls are basically dlsym; we should handle this
@@ -335,8 +430,8 @@ void *dlsym(void *handle, const char *name) __THROW {
335
430
}
336
431
337
432
if (func != nullptr ) {
338
- // std::cout << "[dlsym] Function address from cudaFunctionMap: " << func <<
339
- // " " << name << std::endl;
433
+ // std::cout << "[dlsym] Function address from cudaFunctionMap: " << func
434
+ // << " " << name << std::endl;
340
435
return func;
341
436
}
342
437
0 commit comments