Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions manual_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,13 @@ static uint32_t lupine_count_pending_dtoh_copies(conn_t *conn, CUstream stream,
return count;
}

static int lupine_write_pending_dtoh_copies(conn_t *conn, CUstream stream,
bool all_streams,
bool write_count = true) {
static int lupine_write_pending_dtoh_copies(uint32_t *copy_count, conn_t *conn,
CUstream stream,
bool all_streams) {
auto &pending = lupine_pending_dtoh_copies()[conn];
if (write_count) {
uint32_t count =
lupine_count_pending_dtoh_copies(conn, stream, all_streams);
if (rpc_write(conn, &count, sizeof(count)) < 0) {
if (copy_count != nullptr) {
*copy_count = lupine_count_pending_dtoh_copies(conn, stream, all_streams);
if (rpc_write(conn, copy_count, sizeof(*copy_count)) < 0) {
return -1;
}
}
Expand Down Expand Up @@ -1890,8 +1889,9 @@ static void CUDA_CB lupine_stream_callback(CUstream stream, CUresult status,
void *fn = reinterpret_cast<void *>(callback->callback);
void *client_user_data = callback->userData;
void *response = nullptr;
uint32_t copy_count = 0;
if (rpc_write_start_request(conn, 2) >= 0 &&
lupine_write_pending_dtoh_copies(conn, stream, false) >= 0 &&
lupine_write_pending_dtoh_copies(&copy_count, conn, stream, false) >= 0 &&
rpc_write(conn, &stream, sizeof(stream)) >= 0 &&
rpc_write(conn, &status, sizeof(status)) >= 0 &&
rpc_write(conn, &fn, sizeof(fn)) >= 0 &&
Expand Down Expand Up @@ -2417,12 +2417,13 @@ int handle_manual_cuEventQuery(conn_t *conn) {
if (rpc_write_start_response(conn, request_id) < 0) {
return -1;
}
uint32_t copy_count = 0;
if (result == CUDA_SUCCESS) {
if (lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0) {
if (lupine_write_pending_dtoh_copies(&copy_count, conn, nullptr, true) <
0) {
return -1;
}
} else {
uint32_t copy_count = 0;
if (rpc_write(conn, &copy_count, sizeof(copy_count)) < 0) {
return -1;
}
Expand Down Expand Up @@ -3046,9 +3047,10 @@ int handle_manual_cuCtxSynchronize(conn_t *conn) {
lupine_start_stdout_capture(&capture);
CUresult result = cuCtxSynchronize();
lupine_finish_stdout_capture(&capture);
uint32_t copy_count = 0;
uint64_t stdout_size = 0;
if (rpc_write_start_response(conn, request_id) < 0 ||
lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0 ||
lupine_write_pending_dtoh_copies(&copy_count, conn, nullptr, true) < 0 ||
lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 ||
rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) {
return -1;
Expand Down Expand Up @@ -3096,7 +3098,7 @@ int handle_manual_cuStreamSynchronize(conn_t *conn) {
(copy.bytes != 0 &&
rpc_write(conn, copy.server_src, copy.bytes) < 0);
})) ||
lupine_write_pending_dtoh_copies(conn, nullptr, true, false) < 0 ||
lupine_write_pending_dtoh_copies(nullptr, conn, nullptr, true) < 0 ||
lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 ||
rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) {
return -1;
Expand Down Expand Up @@ -3142,9 +3144,10 @@ int handle_manual_cuEventSynchronize(conn_t *conn) {
lupine_start_stdout_capture(&capture);
CUresult result = cuEventSynchronize(event);
lupine_finish_stdout_capture(&capture);
uint32_t copy_count = 0;
uint64_t stdout_size = 0;
if (rpc_write_start_response(conn, request_id) < 0 ||
lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0 ||
lupine_write_pending_dtoh_copies(&copy_count, conn, nullptr, true) < 0 ||
lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 ||
rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) {
return -1;
Expand Down
Loading