44import multiprocessing as mp
55import os
66import traceback
7+ import secrets
78
89from typing import Optional
910
2425]
2526
2627
27- def do_bench_impl (out_fd : "multiprocessing.Pipe" , qualname : str , test_generator : TestGeneratorInterface ,
28+ def do_bench_impl (out_fd : "multiprocessing.Pipe" , signature : "multiprocessing.Pipe" , qualname : str , test_generator : TestGeneratorInterface ,
2829 test_args : dict , repeats : int , seed : int , stream : int = None , discard : bool = True ,
29- nvtx : bool = False , tb_conn = None ):
30+ nvtx : bool = False , tb_conn : "multiprocessing.Pipe" = None ):
3031 """
3132 Benchmarks the kernel referred to by `qualname` against the test case returned by `test_generator`.
3233 :param out_fd: Writable file descriptor to which benchmark results are written.
34+ :param signature: Authentication token read by the C++ layer before untrusted code runs.
3335 :param qualname: Fully qualified name of the kernel object, e.g. ``my_package.my_module.kernel``.
3436 :param test_generator: A function that takes the test arguments (including a seed) and returns a test case; i.e., a tuple of (input, expected)
3537 :param test_args: keyword arguments to be passed to `test_generator`. Seed will be generated automatically.
@@ -48,6 +50,7 @@ def do_bench_impl(out_fd: "multiprocessing.Pipe", qualname: str, test_generator:
4850 with DeterministicContext ():
4951 _pygpubench .do_bench (
5052 out_fd .fileno (),
53+ signature .fileno (),
5154 qualname ,
5255 test_generator ,
5356 test_args ,
@@ -141,6 +144,11 @@ def do_bench_isolated(
141144 read_fd = result_parent .fileno ()
142145 write_fd = result_child .fileno ()
143146
147+ sig_r , sig_w = ctx .Pipe (duplex = False )
148+ signature = secrets .token_hex (16 )
149+ os .write (sig_w .fileno (), signature .encode ())
150+ sig_w .close ()
151+
144152 try :
145153 import fcntl
146154 # F_SETPIPE_SZ is Linux-specific (1032); fall back silently on other OSes.
@@ -159,6 +167,7 @@ def do_bench_isolated(
159167 target = do_bench_impl ,
160168 args = (
161169 result_child ,
170+ sig_r ,
162171 qualname ,
163172 test_generator ,
164173 test_args ,
@@ -204,6 +213,7 @@ def do_bench_isolated(
204213 parent_tb_conn .close ()
205214
206215 results = BenchmarkResult (None , [- 1 ] * repeats , None , False )
216+ has_signature = False
207217 for line in raw .decode ().splitlines ():
208218 parts = line .strip ().split ('\t ' )
209219 if len (parts ) == 2 and parts [0 ].isdigit ():
@@ -214,5 +224,12 @@ def do_bench_isolated(
214224 results .event_overhead_us = float (parts [1 ].split ()[0 ])
215225 elif parts [0 ] == "error-count" :
216226 results .errors = int (parts [1 ])
227+ elif parts [0 ] == "signature" :
228+ if signature != parts [1 ]:
229+ raise AssertionError (f"Invalid signature" )
230+ has_signature = True
231+ if not has_signature :
232+ raise RuntimeError (f"No signature found in output" )
233+
217234 results .full = all ((t > 0 for t in results .time_us ))
218235 return results
0 commit comments