Skip to content

Commit a97a468

Browse files
committed
grpc implementation additional coderabbit review fixes
1 parent eebd7d8 commit a97a468

File tree

10 files changed

+53
-21
lines changed

10 files changed

+53
-21
lines changed

GRPC_QUICK_START.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ With solver options:
194194
cuopt_cli model.mps --time-limit 30 --relaxation
195195
```
196196

197-
### C API
197+
### C++ API
198198

199-
```c
199+
```cpp
200200
#include <cuopt/linear_programming/solve.hpp>
201201
#include <cuopt/linear_programming/cpu_optimization_problem.hpp>
202202

@@ -212,11 +212,10 @@ gRPC server when they are set.
212212

213213
| Symptom | Check |
214214
|---------|-------|
215-
| `CUOPT_REMOTE_HOST and/or CUOPT_REMOTE_PORT not set` | Both env vars must be set for remote execution. |
216215
| Connection refused | Verify the server is running and the host/port are correct. |
217216
| TLS handshake failure | Ensure `CUOPT_TLS_ENABLED=1` is set and certificate paths are correct. |
218-
| Timeout on large problems | Increase `CUOPT_CHUNK_SIZE` or `CUOPT_MAX_MESSAGE_BYTES`. |
219217
| `Cannot open TLS file: ...` | The path in the TLS env var does not exist or is not readable. |
218+
| Timeout on large problems | Increase the solver `time_limit` or the client `timeout_seconds`. |
220219

221220
## Further Reading
222221

ci/utils/install_protobuf_grpc.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ while [[ $# -gt 0 ]]; do
6969
esac
7070
done
7171

72+
if [[ -z "$PREFIX" || "$PREFIX" == "/" ]]; then
73+
echo "ERROR: Invalid PREFIX: '$PREFIX'" >&2
74+
exit 1
75+
fi
76+
if [[ -z "$BUILD_DIR" || "$BUILD_DIR" == "/" ]]; then
77+
echo "ERROR: Invalid BUILD_DIR: '$BUILD_DIR'" >&2
78+
exit 1
79+
fi
80+
7281
echo "=============================================="
7382
echo "Installing gRPC ${GRPC_VERSION} from source"
7483
echo " Prefix: ${PREFIX}"

cpp/src/grpc/client/grpc_client.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,28 +217,20 @@ void grpc_client_t::start_log_streaming(const std::string& job_id)
217217

218218
void grpc_client_t::stop_log_streaming()
219219
{
220-
constexpr auto kLogJoinTimeout = std::chrono::seconds(5);
221-
222220
stop_logs_.store(true);
223-
// Cancel the in-flight streaming RPC from this thread. reader->Read()
224-
// blocks until the server sends a message, so the stop_logs_ flag alone
225-
// is not enough — TryCancel makes Read() return false immediately.
221+
// Cancel the in-flight streaming RPC so reader->Read() returns false
222+
// immediately instead of blocking until the server sends a message.
226223
{
227224
std::lock_guard<std::mutex> lk(log_context_mutex_);
228225
if (active_log_context_) {
229226
static_cast<grpc::ClientContext*>(active_log_context_)->TryCancel();
230227
}
231228
}
232-
if (log_thread_ && log_thread_->joinable()) {
233-
auto future = std::async(std::launch::async, [this]() { log_thread_->join(); });
234-
if (future.wait_for(kLogJoinTimeout) == std::future_status::timeout) {
235-
GRPC_CLIENT_DEBUG_LOG(config_,
236-
"[grpc_client] WARNING: log streaming thread did not exit within "
237-
<< kLogJoinTimeout.count() << "s; detaching");
238-
log_thread_->detach();
239-
}
240-
}
241-
log_thread_.reset();
229+
// Move to local so we can join without racing against other callers.
230+
// TryCancel above guarantees the thread will unblock promptly.
231+
std::unique_ptr<std::thread> t;
232+
std::swap(t, log_thread_);
233+
if (t && t->joinable()) { t->join(); }
242234
}
243235

244236
// =============================================================================
@@ -1082,6 +1074,11 @@ bool grpc_client_t::download_chunked_result(const std::string& job_id,
10821074
int64_t elems_received = chunk_resp.elements_in_chunk();
10831075
const auto& data = chunk_resp.data();
10841076

1077+
if (elems_received < 0 || elems_received > elems_wanted ||
1078+
elems_received > total_elems - elem_offset) {
1079+
last_error_ = "GetResultChunk: invalid element count";
1080+
return false;
1081+
}
10851082
if (static_cast<int64_t>(data.size()) != elems_received * elem_size) {
10861083
last_error_ = "GetResultChunk: data size mismatch";
10871084
return false;

cpp/src/grpc/grpc_problem_mapper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ void map_chunked_arrays_to_problem(const cuopt::remote::ChunkedProblemHeader& he
447447
const char* nul = static_cast<const char*>(std::memchr(s, '\0', s_end - s));
448448
if (!nul) nul = s_end;
449449
names.emplace_back(s, nul);
450+
if (nul == s_end) break;
450451
s = nul + 1;
451452
}
452453
return names;

cpp/src/grpc/server/grpc_job_management.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ bool send_incumbent_pipe(int fd, const std::vector<uint8_t>& data)
4949

5050
bool recv_incumbent_pipe(int fd, std::vector<uint8_t>& data)
5151
{
52+
static constexpr uint64_t kMaxIncumbentBytes = 256ULL * 1024 * 1024;
5253
uint64_t size;
5354
if (!read_from_pipe(fd, &size, sizeof(size))) return false;
55+
if (size > kMaxIncumbentBytes) return false;
5456
data.resize(size);
5557
if (size > 0 && !read_from_pipe(fd, data.data(), size)) return false;
5658
return true;
@@ -88,6 +90,7 @@ std::pair<bool, std::string> submit_job_async(std::vector<uint8_t>&& request_dat
8890
job_queue[slot].worker_index.store(-1);
8991
job_queue[slot].data_sent.store(false);
9092
job_queue[slot].is_chunked = false;
93+
job_queue[slot].worker_pid = 0;
9194

9295
{
9396
std::lock_guard<std::mutex> lock(pending_data_mutex);
@@ -135,6 +138,7 @@ std::pair<bool, std::string> submit_chunked_job_async(PendingChunkedUpload&& chu
135138
job_queue[slot].worker_index.store(-1);
136139
job_queue[slot].data_sent.store(false);
137140
job_queue[slot].is_chunked = true;
141+
job_queue[slot].worker_pid = 0;
138142

139143
{
140144
std::lock_guard<std::mutex> lock(pending_data_mutex);

cpp/src/grpc/server/grpc_pipe_serialization.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
// may silently cap this to /proc/sys/fs/pipe-max-size.
2323
static constexpr int kPipeBufferSize = 1024 * 1024;
2424

25+
static constexpr uint64_t kMaxPipeArrayBytes = 4ULL * 1024 * 1024 * 1024;
26+
static constexpr uint32_t kMaxPipeArrayFields = 10000;
27+
2528
// Pipe I/O primitives defined in grpc_job_management.cpp.
2629
bool write_to_pipe(int fd, const void* data, size_t size);
2730
bool read_from_pipe(int fd, void* data, size_t size, int timeout_ms = 120000);
@@ -135,13 +138,15 @@ inline bool read_chunked_request_from_pipe(int fd,
135138

136139
uint32_t num_arrays;
137140
if (!read_from_pipe(fd, &num_arrays, sizeof(num_arrays))) return false;
141+
if (num_arrays > kMaxPipeArrayFields) return false;
138142

139143
// Read each field's raw bytes directly into the output map, keyed by field_id.
140144
for (uint32_t i = 0; i < num_arrays; ++i) {
141145
int32_t field_id;
142146
uint64_t total_bytes;
143147
if (!read_from_pipe(fd, &field_id, sizeof(field_id))) return false;
144148
if (!read_from_pipe(fd, &total_bytes, sizeof(total_bytes))) return false;
149+
if (total_bytes > kMaxPipeArrayBytes) return false;
145150
auto& dest = arrays_out[field_id];
146151
dest.resize(static_cast<size_t>(total_bytes));
147152
if (total_bytes > 0 && !read_from_pipe(fd, dest.data(), static_cast<size_t>(total_bytes)))
@@ -188,12 +193,14 @@ inline bool read_result_from_pipe(int fd,
188193

189194
uint32_t num_arrays;
190195
if (!read_from_pipe(fd, &num_arrays, sizeof(num_arrays))) return false;
196+
if (num_arrays > kMaxPipeArrayFields) return false;
191197

192198
for (uint32_t i = 0; i < num_arrays; ++i) {
193199
int32_t field_id;
194200
uint64_t total_bytes;
195201
if (!read_from_pipe(fd, &field_id, sizeof(field_id))) return false;
196202
if (!read_from_pipe(fd, &total_bytes, sizeof(total_bytes))) return false;
203+
if (total_bytes > kMaxPipeArrayBytes) return false;
197204
auto& dest = arrays_out[field_id];
198205
dest.resize(static_cast<size_t>(total_bytes));
199206
if (total_bytes > 0 && !read_from_pipe(fd, dest.data(), static_cast<size_t>(total_bytes)))

cpp/src/grpc/server/grpc_server_main.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ int main(int argc, char** argv)
156156
std::cerr << "[Server] Failed to mmap control: " << strerror(errno) << "\n";
157157
return 1;
158158
}
159+
new (shm_ctrl) SharedMemoryControl{};
159160

160161
for (size_t i = 0; i < MAX_JOBS; ++i) {
161162
new (&job_queue[i]) JobQueueEntry{};
@@ -222,8 +223,15 @@ int main(int argc, char** argv)
222223
creds = grpc::InsecureServerCredentials();
223224
}
224225

226+
signal(SIGPIPE, SIG_IGN);
225227
spawn_workers();
226228

229+
if (worker_pids.empty()) {
230+
std::cerr << "[Server] No workers started; exiting\n";
231+
cleanup_shared_memory();
232+
return 1;
233+
}
234+
227235
std::thread result_thread(result_retrieval_thread);
228236
std::thread incumbent_thread(incumbent_retrieval_thread);
229237
std::thread monitor_thread(worker_monitor_thread);

cpp/src/grpc/server/grpc_worker_infra.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ void spawn_workers()
147147
void wait_for_workers()
148148
{
149149
for (pid_t pid : worker_pids) {
150+
if (pid <= 0) continue;
150151
int status;
151152
while (waitpid(pid, &status, 0) < 0 && errno == EINTR) {}
152153
}

cpp/tests/linear_programming/grpc/grpc_test_log_capture.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ class GrpcTestLogCapture {
370370
std::vector<LogEntry> client_logs_;
371371
std::string server_log_path_;
372372
std::streampos server_log_start_pos_ = 0; // Position in server log file when test started
373-
bool test_start_marked_ = false;
373+
std::atomic<bool> test_start_marked_{false};
374374
};
375375

376376
} // namespace cuopt::linear_programming::testing

python/cuopt/cuopt/tests/linear_programming/test_cpu_only_execution.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def _wait_for_port(port, timeout=15):
7777
def _cpu_only_env(port):
7878
"""Return an env dict that hides all GPUs and enables remote mode."""
7979
env = os.environ.copy()
80+
for key in [k for k in env if k.startswith("CUOPT_TLS_")]:
81+
env.pop(key)
8082
env["CUDA_VISIBLE_DEVICES"] = ""
8183
env["CUOPT_REMOTE_HOST"] = "localhost"
8284
env["CUOPT_REMOTE_PORT"] = str(port)
@@ -110,14 +112,18 @@ def _run(cmd):
110112
server_key = os.path.join(cert_dir, "server.key")
111113
server_csr = os.path.join(cert_dir, "server.csr")
112114
server_crt = os.path.join(cert_dir, "server.crt")
115+
server_ext = os.path.join(cert_dir, "server.ext")
113116
if not _run(
114117
f"openssl req -newkey rsa:2048 -keyout {server_key} -out {server_csr} "
115118
f"-nodes -subj '/CN=localhost' 2>/dev/null"
116119
):
117120
return False
121+
with open(server_ext, "w") as f:
122+
f.write("subjectAltName=DNS:localhost,IP:127.0.0.1\n")
118123
if not _run(
119124
f"openssl x509 -req -in {server_csr} -CA {ca_crt} -CAkey {ca_key} "
120-
f"-CAcreateserial -out {server_crt} -days 1 2>/dev/null"
125+
f"-CAcreateserial -out {server_crt} -days 1 "
126+
f"-extfile {server_ext} 2>/dev/null"
121127
):
122128
return False
123129

0 commit comments

Comments
 (0)