Skip to content

Commit 3032d23

Browse files
committed
cleanup
1 parent 5eba83d commit 3032d23

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

csrc/manager.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,32 @@ BenchmarkManager::BenchmarkManager(int result_fd, int signature_fd, std::uint64_
5959
CUDA_CHECK(cudaMalloc(&mDeviceDummyMemory, 2 * mL2CacheSize));
6060
// allocate a large arena (2MiB) to place the error counter in
6161
CUDA_CHECK(cudaMalloc(&mDeviceErrorBase, ArenaSize));
62-
mOutputFile = fdopen(result_fd, "w");
62+
mOutputPipe = fdopen(result_fd, "w");
63+
if (!mOutputPipe) {
64+
throw std::runtime_error("Could not open output pipe");
65+
}
66+
6367
mNVTXEnabled = nvtx;
6468
mDiscardCache = discard;
6569
mSeed = seed;
6670
char sig_buf[256];
6771
FILE* sig_file = fdopen(signature_fd, "r");
68-
fgets(sig_buf, sizeof(sig_buf), sig_file);
72+
if (!sig_file) {
73+
throw std::runtime_error("Could not open signature pipe");
74+
}
75+
if (!fgets(sig_buf, sizeof(sig_buf), sig_file)) {
76+
fclose(sig_file);
77+
throw std::runtime_error("Could not read signature");
78+
}
6979
fclose(sig_file);
7080
mSignature = std::string(sig_buf);
7181
}
7282

7383
BenchmarkManager::~BenchmarkManager() {
74-
fclose(mOutputFile);
84+
if (mOutputPipe) {
85+
fclose(mOutputPipe);
86+
mOutputPipe = nullptr;
87+
}
7588
cudaFree(mDeviceDummyMemory);
7689
cudaFree(mDeviceErrorBase);
7790
for (auto& event : mStartEvents) cudaEventDestroy(event);
@@ -315,7 +328,7 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
315328
}
316329
std::sort(empty_event_times.begin(), empty_event_times.end());
317330
float median = empty_event_times.at(empty_event_times.size() / 2);
318-
fprintf(mOutputFile, "event-overhead\t%f µs\n", median * 1000);
331+
fprintf(mOutputPipe, "event-overhead\t%f µs\n", median * 1000);
319332

320333
// create a randomized order for running the tests
321334
std::vector<int> test_order(actual_calls);
@@ -368,16 +381,16 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
368381
error_count -= mErrorCountShift;
369382

370383
if (error_count > 0) {
371-
fprintf(mOutputFile, "error-count\t%u\n", error_count);
384+
fprintf(mOutputPipe, "error-count\t%u\n", error_count);
372385
}
373386

374387
for (int i = 0; i < actual_calls; i++) {
375388
float duration;
376389
CUDA_CHECK(cudaEventElapsedTime(&duration, mStartEvents.at(i), mEndEvents.at(i)));
377-
fprintf(mOutputFile, "%d\t%f\n", test_order.at(i) - 1, duration * 1000);
390+
fprintf(mOutputPipe, "%d\t%f\n", test_order.at(i) - 1, duration * 1000);
378391
}
379-
fprintf(mOutputFile, "signature\t%s", mSignature.c_str());
380-
fflush(mOutputFile);
392+
fprintf(mOutputPipe, "signature\t%s\n", mSignature.c_str());
393+
fflush(mOutputPipe);
381394

382395
// cleanup events
383396
for (auto& event : mStartEvents) CUDA_CHECK(cudaEventDestroy(event));

csrc/manager.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <functional>
99
#include <chrono>
10+
#include <cstdio>
1011
#include <fstream>
1112
#include <cuda_runtime.h>
1213
#include <optional>
@@ -66,7 +67,7 @@ class BenchmarkManager {
6667
std::uint64_t mSeed = -1;
6768
std::vector<Expected> mExpectedOutputs;
6869

69-
FILE* mOutputFile;
70+
FILE* mOutputPipe = nullptr;
7071
std::string mSignature;
7172

7273
static ShadowArgumentList make_shadow_args(const nb::tuple& args, cudaStream_t stream);

python/pygpubench/__init__.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import dataclasses
22
import math
3-
import multiprocessing
43
import multiprocessing as mp
54
import os
65
import traceback
76
import secrets
87

9-
from typing import Optional
8+
from typing import Optional, TYPE_CHECKING
109

1110
from . import _pygpubench
1211
from ._types import *
1312
from .utils import DeterministicContext
1413

14+
if TYPE_CHECKING:
15+
import multiprocessing.connection
16+
17+
1518
__all__ = [
1619
"do_bench_impl",
1720
"do_bench_isolated",
@@ -25,9 +28,9 @@
2528
]
2629

2730

28-
def do_bench_impl(out_fd: "multiprocessing.Pipe", signature: "multiprocessing.Pipe", qualname: str, test_generator: TestGeneratorInterface,
31+
def do_bench_impl(out_fd: "multiprocessing.connection.Connection", signature: "multiprocessing.connection.Connection", qualname: str, test_generator: TestGeneratorInterface,
2932
test_args: dict, repeats: int, seed: int, stream: int = None, discard: bool = True,
30-
nvtx: bool = False, tb_conn: "multiprocessing.Pipe" = None):
33+
nvtx: bool = False, tb_conn: "multiprocessing.connection.Connection" = None):
3134
"""
3235
Benchmarks the kernel referred to by `qualname` against the test case returned by `test_generator`.
3336
:param out_fd: Writable file descriptor to which benchmark results are written.
@@ -119,6 +122,13 @@ def basic_stats(time_us: list[float]) -> BenchmarkStats:
119122
return BenchmarkStats(runs, len(time_us), fastest, slowest, median, mean, std, err)
120123

121124

125+
def read_all(fd: int) -> str:
126+
chunks = []
127+
while chunk := os.read(fd, 65536):
128+
chunks.append(chunk)
129+
return (b"".join(chunks)).decode()
130+
131+
122132
def do_bench_isolated(
123133
qualname: str,
124134
test_generator: TestGeneratorInterface,
@@ -189,6 +199,7 @@ def do_bench_isolated(
189199
if process.is_alive():
190200
process.kill()
191201
process.join()
202+
parent_tb_conn.close()
192203
result_parent.close()
193204
raise RuntimeError(
194205
f"Benchmark subprocess timed out after {timeout}s -- "
@@ -208,25 +219,37 @@ def do_bench_isolated(
208219
raise RuntimeError(msg)
209220

210221
# Child has exited and closed its write-end, so this read is bounded.
211-
raw = os.read(read_fd, _PIPE_CAPACITY)
222+
response = read_all(read_fd)
212223
result_parent.close()
213224
parent_tb_conn.close()
214225

215226
results = BenchmarkResult(None, [-1] * repeats, None, False)
216227
has_signature = False
217-
for line in raw.decode().splitlines():
218-
parts = line.strip().split('\t')
219-
if len(parts) == 2 and parts[0].isdigit():
228+
for line in response.splitlines():
229+
line = line.strip()
230+
if len(line) == 0:
231+
continue
232+
parts = line.split('\t')
233+
if len(parts) != 2:
234+
raise RuntimeError(f"Invalid benchmark output: {line}")
235+
if has_signature:
236+
raise RuntimeError(f"Unexpected output after signature: {line}")
237+
238+
if parts[0].isdigit():
220239
iteration = int(parts[0])
221240
time_us = float(parts[1])
241+
if results.time_us[iteration] != -1:
242+
raise RuntimeError(f"Duplicate iteration {iteration} in benchmark output")
222243
results.time_us[iteration] = time_us
223244
elif parts[0] == "event-overhead":
224245
results.event_overhead_us = float(parts[1].split()[0])
225246
elif parts[0] == "error-count":
247+
if results.errors is not None:
248+
raise RuntimeError(f"Duplicate error count in benchmark output")
226249
results.errors = int(parts[1])
227250
elif parts[0] == "signature":
228251
if signature != parts[1]:
229-
raise AssertionError(f"Invalid signature")
252+
raise RuntimeError("Benchmark subprocess output failed authentication: invalid signature")
230253
has_signature = True
231254
if not has_signature:
232255
raise RuntimeError(f"No signature found in output")

0 commit comments

Comments
 (0)