11import dataclasses
22import math
3- import multiprocessing
43import multiprocessing as mp
54import os
65import traceback
76import secrets
87
9- from typing import Optional
8+ from typing import Optional , TYPE_CHECKING
109
1110from . import _pygpubench
1211from ._types import *
1312from .utils import DeterministicContext
1413
14+ if TYPE_CHECKING :
15+ import multiprocessing .connection
16+
17+
1518__all__ = [
1619 "do_bench_impl" ,
1720 "do_bench_isolated" ,
2528]
2629
2730
28- def do_bench_impl (out_fd : "multiprocessing.Pipe " , signature : "multiprocessing.Pipe " , qualname : str , test_generator : TestGeneratorInterface ,
31+ def do_bench_impl (out_fd : "multiprocessing.connection.Connection " , signature : "multiprocessing.connection.Connection " , qualname : str , test_generator : TestGeneratorInterface ,
2932 test_args : dict , repeats : int , seed : int , stream : int = None , discard : bool = True ,
30- nvtx : bool = False , tb_conn : "multiprocessing.Pipe " = None ):
33+ nvtx : bool = False , tb_conn : "multiprocessing.connection.Connection " = None ):
3134 """
3235 Benchmarks the kernel referred to by `qualname` against the test case returned by `test_generator`.
3336 :param out_fd: Writable file descriptor to which benchmark results are written.
@@ -119,6 +122,13 @@ def basic_stats(time_us: list[float]) -> BenchmarkStats:
119122 return BenchmarkStats (runs , len (time_us ), fastest , slowest , median , mean , std , err )
120123
121124
125+ def read_all (fd : int ) -> str :
126+ chunks = []
127+ while chunk := os .read (fd , 65536 ):
128+ chunks .append (chunk )
129+ return (b"" .join (chunks )).decode ()
130+
131+
122132def do_bench_isolated (
123133 qualname : str ,
124134 test_generator : TestGeneratorInterface ,
@@ -189,6 +199,7 @@ def do_bench_isolated(
189199 if process .is_alive ():
190200 process .kill ()
191201 process .join ()
202+ parent_tb_conn .close ()
192203 result_parent .close ()
193204 raise RuntimeError (
194205 f"Benchmark subprocess timed out after { timeout } s -- "
@@ -208,25 +219,37 @@ def do_bench_isolated(
208219 raise RuntimeError (msg )
209220
210221 # Child has exited and closed its write-end, so this read is bounded.
211- raw = os . read (read_fd , _PIPE_CAPACITY )
222+ response = read_all (read_fd )
212223 result_parent .close ()
213224 parent_tb_conn .close ()
214225
215226 results = BenchmarkResult (None , [- 1 ] * repeats , None , False )
216227 has_signature = False
217- for line in raw .decode ().splitlines ():
218- parts = line .strip ().split ('\t ' )
219- if len (parts ) == 2 and parts [0 ].isdigit ():
228+ for line in response .splitlines ():
229+ line = line .strip ()
230+ if len (line ) == 0 :
231+ continue
232+ parts = line .split ('\t ' )
233+ if len (parts ) != 2 :
234+ raise RuntimeError (f"Invalid benchmark output: { line } " )
235+ if has_signature :
236+ raise RuntimeError (f"Unexpected output after signature: { line } " )
237+
238+ if parts [0 ].isdigit ():
220239 iteration = int (parts [0 ])
221240 time_us = float (parts [1 ])
241+ if results .time_us [iteration ] != - 1 :
242+ raise RuntimeError (f"Duplicate iteration { iteration } in benchmark output" )
222243 results .time_us [iteration ] = time_us
223244 elif parts [0 ] == "event-overhead" :
224245 results .event_overhead_us = float (parts [1 ].split ()[0 ])
225246 elif parts [0 ] == "error-count" :
247+ if results .errors is not None :
248+ raise RuntimeError (f"Duplicate error count in benchmark output" )
226249 results .errors = int (parts [1 ])
227250 elif parts [0 ] == "signature" :
228251 if signature != parts [1 ]:
229- raise AssertionError ( f"Invalid signature" )
252+ raise RuntimeError ( "Benchmark subprocess output failed authentication: invalid signature" )
230253 has_signature = True
231254 if not has_signature :
232255 raise RuntimeError (f"No signature found in output" )
0 commit comments