@@ -240,7 +240,7 @@ BenchmarkManager::ShadowArgument& BenchmarkManager::ShadowArgument::operator=(Sh
240240 return *this ;
241241}
242242
243- void BenchmarkManager::do_bench_py (const std::string& kernel_qualname, const std::vector<nb::tuple>& args, const std::vector<nb::tuple>& expected, cudaStream_t stream) {
243+ void BenchmarkManager::do_bench_py (const std::string& kernel_qualname, const std::vector<nb::tuple>& args, std::vector<nb::tuple> expected, cudaStream_t stream) {
244244 if (args.size () < 5 ) {
245245 throw std::runtime_error (" Not enough test cases to run benchmark" );
246246 }
@@ -284,6 +284,11 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
284284 }
285285 }
286286
287+ // The benchmark loop only needs the unmanaged output copies after this point.
288+ // Release Python-held expected tuples before importing untrusted code.
289+ expected.clear ();
290+ expected.shrink_to_fit ();
291+
287292 // clean up as much python state as we can
288293 trigger_gc ();
289294
@@ -295,10 +300,14 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
295300 // after this, we cannot trust python anymore
296301 nb::callable kernel = kernel_from_qualname (kernel_qualname);
297302
303+ std::random_device warmup_rd;
304+ std::mt19937 warmup_rng (warmup_rd ());
305+ std::uniform_int_distribution<int > warmup_dist (0 , static_cast <int >(args.size ()) - 1 );
306+
298307 // ok, first run for compilations etc
299308 nvtx_push (" warmup" );
300309 CUDA_CHECK (cudaDeviceSynchronize ());
301- kernel (*args.at (0 ));
310+ kernel (*args.at (warmup_dist (warmup_rng) ));
302311 CUDA_CHECK (cudaDeviceSynchronize ());
303312 nvtx_pop ();
304313
@@ -312,7 +321,7 @@ void BenchmarkManager::do_bench_py(const std::string& kernel_qualname, const std
312321 // this is only potentially problematic for in-place kernels;
313322 CUDA_CHECK (cudaDeviceSynchronize ());
314323 clear_cache (stream);
315- kernel (*args.at (0 ));
324+ kernel (*args.at (warmup_dist (warmup_rng) ));
316325 CUDA_CHECK (cudaDeviceSynchronize ());
317326 std::chrono::high_resolution_clock::time_point cpu_end = std::chrono::high_resolution_clock::now ();
318327 std::chrono::duration<double > elapsed_seconds = cpu_end - cpu_start;
0 commit comments