Skip to content

Commit 5eba83d

Browse files
committed
and break the exploit again
1 parent 39bc324 commit 5eba83d

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

csrc/binding.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111
namespace nb = nanobind;
1212

1313

14-
void do_bench(int result_fd, const std::string& kernel_qualname, const nb::object& test_generator, const nb::dict& test_kwargs, int repeats, std::uint64_t seed, std::uintptr_t stream, bool discard, bool nvtx) {
15-
BenchmarkManager mgr(result_fd, seed, discard, nvtx);
14+
void do_bench(int result_fd, int signature_fd, const std::string& kernel_qualname, const nb::object& test_generator, const nb::dict& test_kwargs, int repeats, std::uint64_t seed, std::uintptr_t stream, bool discard, bool nvtx) {
15+
BenchmarkManager mgr(result_fd, signature_fd, seed, discard, nvtx);
1616
auto [args, expected] = mgr.setup_benchmark(nb::cast<nb::callable>(test_generator), test_kwargs, repeats);
1717
mgr.do_bench_py(kernel_qualname, args, expected, reinterpret_cast<cudaStream_t>(stream));
1818
}

csrc/manager.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ static nb::callable kernel_from_qualname(const std::string& qualname) {
5252
return nb::cast<nb::callable>(mod.attr(attr.c_str()));
5353
}
5454

55-
BenchmarkManager::BenchmarkManager(int result_fd, std::uint64_t seed, bool discard, bool nvtx) {
55+
BenchmarkManager::BenchmarkManager(int result_fd, int signature_fd, std::uint64_t seed, bool discard, bool nvtx) {
5656
int device;
5757
CUDA_CHECK(cudaGetDevice(&device));
5858
CUDA_CHECK(cudaDeviceGetAttribute(&mL2CacheSize, cudaDevAttrL2CacheSize, device));
@@ -63,6 +63,11 @@ BenchmarkManager::BenchmarkManager(int result_fd, std::uint64_t seed, bool disca
6363
mNVTXEnabled = nvtx;
6464
mDiscardCache = discard;
6565
mSeed = seed;
66+
char sig_buf[256];
67+
FILE* sig_file = fdopen(signature_fd, "r");
68+
fgets(sig_buf, sizeof(sig_buf), sig_file);
69+
fclose(sig_file);
70+
mSignature = std::string(sig_buf);
6671
}
6772

6873
BenchmarkManager::~BenchmarkManager() {
@@ -371,6 +376,7 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
371376
CUDA_CHECK(cudaEventElapsedTime(&duration, mStartEvents.at(i), mEndEvents.at(i)));
372377
fprintf(mOutputFile, "%d\t%f\n", test_order.at(i) - 1, duration * 1000);
373378
}
379+
fprintf(mOutputFile, "signature\t%s", mSignature.c_str());
374380
fflush(mOutputFile);
375381

376382
// cleanup events

csrc/manager.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ using nb_cuda_array = nb::ndarray<nb::c_contig, nb::device::cuda>;
1919

2020
class BenchmarkManager {
2121
public:
22-
BenchmarkManager(int result_fd, std::uint64_t seed, bool discard, bool nvtx);
22+
BenchmarkManager(int result_fd, int signature_fd, std::uint64_t seed, bool discard, bool nvtx);
2323
~BenchmarkManager();
2424
std::pair<std::vector<nb::tuple>, std::vector<nb::tuple>> setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats);
2525
void do_bench_py(const std::string& kernel_qualname, const std::vector<nb::tuple>& args, const std::vector<nb::tuple>& expected, cudaStream_t stream);
@@ -67,6 +67,7 @@ class BenchmarkManager {
6767
std::vector<Expected> mExpectedOutputs;
6868

6969
FILE* mOutputFile;
70+
std::string mSignature;
7071

7172
static ShadowArgumentList make_shadow_args(const nb::tuple& args, cudaStream_t stream);
7273

python/pygpubench/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import multiprocessing as mp
55
import os
66
import traceback
7+
import secrets
78

89
from typing import Optional
910

@@ -24,12 +25,13 @@
2425
]
2526

2627

27-
def do_bench_impl(out_fd: "multiprocessing.Pipe", qualname: str, test_generator: TestGeneratorInterface,
28+
def do_bench_impl(out_fd: "multiprocessing.Pipe", signature: "multiprocessing.Pipe", qualname: str, test_generator: TestGeneratorInterface,
2829
test_args: dict, repeats: int, seed: int, stream: int = None, discard: bool = True,
29-
nvtx: bool = False, tb_conn=None):
30+
nvtx: bool = False, tb_conn: "multiprocessing.Pipe" = None):
3031
"""
3132
Benchmarks the kernel referred to by `qualname` against the test case returned by `test_generator`.
3233
:param out_fd: Writable file descriptor to which benchmark results are written.
34+
:param signature: Authentication token read by the C++ layer before untrusted code runs.
3335
:param qualname: Fully qualified name of the kernel object, e.g. ``my_package.my_module.kernel``.
3436
:param test_generator: A function that takes the test arguments (including a seed) and returns a test case; i.e., a tuple of (input, expected)
3537
:param test_args: keyword arguments to be passed to `test_generator`. Seed will be generated automatically.
@@ -48,6 +50,7 @@ def do_bench_impl(out_fd: "multiprocessing.Pipe", qualname: str, test_generator:
4850
with DeterministicContext():
4951
_pygpubench.do_bench(
5052
out_fd.fileno(),
53+
signature.fileno(),
5154
qualname,
5255
test_generator,
5356
test_args,
@@ -141,6 +144,11 @@ def do_bench_isolated(
141144
read_fd = result_parent.fileno()
142145
write_fd = result_child.fileno()
143146

147+
sig_r, sig_w = ctx.Pipe(duplex=False)
148+
signature = secrets.token_hex(16)
149+
os.write(sig_w.fileno(), signature.encode())
150+
sig_w.close()
151+
144152
try:
145153
import fcntl
146154
# F_SETPIPE_SZ is Linux-specific (1032); fall back silently on other OSes.
@@ -159,6 +167,7 @@ def do_bench_isolated(
159167
target=do_bench_impl,
160168
args=(
161169
result_child,
170+
sig_r,
162171
qualname,
163172
test_generator,
164173
test_args,
@@ -204,6 +213,7 @@ def do_bench_isolated(
204213
parent_tb_conn.close()
205214

206215
results = BenchmarkResult(None, [-1] * repeats, None, False)
216+
has_signature = False
207217
for line in raw.decode().splitlines():
208218
parts = line.strip().split('\t')
209219
if len(parts) == 2 and parts[0].isdigit():
@@ -214,5 +224,12 @@ def do_bench_isolated(
214224
results.event_overhead_us = float(parts[1].split()[0])
215225
elif parts[0] == "error-count":
216226
results.errors = int(parts[1])
227+
elif parts[0] == "signature":
228+
if signature != parts[1]:
229+
raise AssertionError(f"Invalid signature")
230+
has_signature = True
231+
if not has_signature:
232+
raise RuntimeError(f"No signature found in output")
233+
217234
results.full = all((t > 0 for t in results.time_us))
218235
return results

0 commit comments

Comments
 (0)