From 4350a3bf64552bbfa68921c333bac3972de6044b Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 4 Jun 2026 02:37:47 -0400 Subject: [PATCH] Keep deferred copy count storage alive --- manual_server.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/manual_server.cpp b/manual_server.cpp index ed8eb2e..22e4ce3 100644 --- a/manual_server.cpp +++ b/manual_server.cpp @@ -607,14 +607,13 @@ static uint32_t lupine_count_pending_dtoh_copies(conn_t *conn, CUstream stream, return count; } -static int lupine_write_pending_dtoh_copies(conn_t *conn, CUstream stream, - bool all_streams, - bool write_count = true) { +static int lupine_write_pending_dtoh_copies(uint32_t *copy_count, conn_t *conn, + CUstream stream, + bool all_streams) { auto &pending = lupine_pending_dtoh_copies()[conn]; - if (write_count) { - uint32_t count = - lupine_count_pending_dtoh_copies(conn, stream, all_streams); - if (rpc_write(conn, &count, sizeof(count)) < 0) { + if (copy_count != nullptr) { + *copy_count = lupine_count_pending_dtoh_copies(conn, stream, all_streams); + if (rpc_write(conn, copy_count, sizeof(*copy_count)) < 0) { return -1; } } @@ -1890,8 +1889,9 @@ static void CUDA_CB lupine_stream_callback(CUstream stream, CUresult status, void *fn = reinterpret_cast(callback->callback); void *client_user_data = callback->userData; void *response = nullptr; + uint32_t copy_count = 0; if (rpc_write_start_request(conn, 2) >= 0 && - lupine_write_pending_dtoh_copies(conn, stream, false) >= 0 && + lupine_write_pending_dtoh_copies(©_count, conn, stream, false) >= 0 && rpc_write(conn, &stream, sizeof(stream)) >= 0 && rpc_write(conn, &status, sizeof(status)) >= 0 && rpc_write(conn, &fn, sizeof(fn)) >= 0 && @@ -2417,12 +2417,13 @@ int handle_manual_cuEventQuery(conn_t *conn) { if (rpc_write_start_response(conn, request_id) < 0) { return -1; } + uint32_t copy_count = 0; if (result == CUDA_SUCCESS) { - if (lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0) { + if (lupine_write_pending_dtoh_copies(©_count, conn, nullptr, true) < + 0) { return -1; } } else { - uint32_t copy_count = 0; if (rpc_write(conn, ©_count, sizeof(copy_count)) < 0) { return -1; } @@ -3046,9 +3047,10 @@ int handle_manual_cuCtxSynchronize(conn_t *conn) { lupine_start_stdout_capture(&capture); CUresult result = cuCtxSynchronize(); lupine_finish_stdout_capture(&capture); + uint32_t copy_count = 0; uint64_t stdout_size = 0; if (rpc_write_start_response(conn, request_id) < 0 || - lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0 || + lupine_write_pending_dtoh_copies(©_count, conn, nullptr, true) < 0 || lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 || rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) { return -1; @@ -3096,7 +3098,7 @@ int handle_manual_cuStreamSynchronize(conn_t *conn) { (copy.bytes != 0 && rpc_write(conn, copy.server_src, copy.bytes) < 0); })) || - lupine_write_pending_dtoh_copies(conn, nullptr, true, false) < 0 || + lupine_write_pending_dtoh_copies(nullptr, conn, nullptr, true) < 0 || lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 || rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) { return -1; @@ -3142,9 +3144,10 @@ int handle_manual_cuEventSynchronize(conn_t *conn) { lupine_start_stdout_capture(&capture); CUresult result = cuEventSynchronize(event); lupine_finish_stdout_capture(&capture); + uint32_t copy_count = 0; uint64_t stdout_size = 0; if (rpc_write_start_response(conn, request_id) < 0 || - lupine_write_pending_dtoh_copies(conn, nullptr, true) < 0 || + lupine_write_pending_dtoh_copies(©_count, conn, nullptr, true) < 0 || lupine_write_captured_stdout(conn, capture, &stdout_size) < 0 || rpc_write(conn, &result, sizeof(result)) < 0 || rpc_write_end(conn) < 0) { return -1;