1414from .common import summarize_statistics
1515from .gpu_events import do_bench_events
1616from .power import do_bench_power
17+ from .utils import estimate_cuda_runtime_ms , resolve_warmup_and_rep
1718
1819NS_TO_MS = 1e-6
1920logger = logging .getLogger (__name__ )
@@ -219,6 +220,7 @@ def _do_bench_cudagraph_with_cache_clear(
219220 end_event .record ()
220221 torch .cuda .synchronize ()
221222 estimate_ms = start_event .elapsed_time (end_event ) / 5
223+ _ , rep = resolve_warmup_and_rep (None , rep , estimate_ms )
222224
223225 n_repeat = 1000 if estimate_ms == 0 else max (1 , int (rep / estimate_ms ))
224226
@@ -301,13 +303,11 @@ def _do_bench_profiler(
301303 else None
302304 )
303305
304- # First, estimate the runtime to calculate iterations
305- estimate_ms = triton . testing . do_bench (
306+ clear_cache_fn = cache . zero_ if not skip_cache_clearing else lambda * args : None
307+ estimate_ms = estimate_cuda_runtime_ms (
306308 fn ,
307- warmup = warmup ,
308- rep = rep ,
309309 grad_to_none = grad_to_none ,
310- return_mode = "mean" ,
310+ clear_cache_fn = clear_cache_fn ,
311311 )
312312
313313 # Calculate number of iterations based on target rep time
@@ -316,8 +316,6 @@ def _do_bench_profiler(
316316 else :
317317 n_repeat = max (1 , int (rep / estimate_ms ))
318318
319- clear_cache_fn = cache .zero_ if not skip_cache_clearing else lambda * args : None
320-
321319 # Helper function to execute one iteration
322320 def run_iteration ():
323321 if grad_to_none is not None :
@@ -432,7 +430,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
432430
433431
434432def _do_bench_cpu (
435- fn , warmup , rep = 20 , grad_to_none = None , quantiles = None , return_mode = "mean"
433+ fn , warmup , rep , grad_to_none = None , quantiles = None , return_mode = "mean"
436434):
437435 """Measure latency of a function on CPU."""
438436 assert return_mode in ["min" , "max" , "mean" , "median" , "all" ]
@@ -474,8 +472,8 @@ def _do_bench_cpu(
474472
475473def _do_bench_entropy (
476474 fn ,
477- warmup = 25 ,
478- rep = 100 ,
475+ warmup ,
476+ rep ,
479477 grad_to_none = None ,
480478 quantiles = None ,
481479 return_mode = "mean" ,
@@ -528,6 +526,7 @@ def _do_bench_entropy(
528526 precision_increase = False
529527
530528 cache = triton .runtime .driver .active .get_empty_cache_for_benchmark ()
529+ clear_cache_fn = lambda : triton .runtime .driver .active .clear_cache (cache )
531530
532531 # Adaptive warmup loop with batched synchronization
533532 while True :
@@ -545,7 +544,7 @@ def _do_bench_entropy(
545544 if grad_to_none is not None :
546545 for x in grad_to_none :
547546 x .grad = None
548- triton . runtime . driver . active . clear_cache ( cache )
547+ clear_cache_fn ( )
549548 batch_start_events [i ].record ()
550549 fn ()
551550 batch_end_events [i ].record ()
@@ -619,7 +618,7 @@ def _do_bench_entropy(
619618 if grad_to_none is not None :
620619 for x in grad_to_none :
621620 x .grad = None
622- triton . runtime . driver . active . clear_cache ( cache )
621+ clear_cache_fn ( )
623622 start_events [i ].record ()
624623 fn ()
625624 end_events [i ].record ()
@@ -661,6 +660,14 @@ def do_bench_wrapper(
661660 entropy_window_size: Size of rolling window for entropy tracking
662661 entropy_max_samples: Maximum samples before stopping warmup (safety limit)
663662 """
663+ if (warmup is None or rep is None ) and not repcnt :
664+ estimate_runtime = estimate_cuda_runtime_ms (fn , grad_to_none = grad_to_none )
665+ warmup , rep = resolve_warmup_and_rep (
666+ warmup ,
667+ rep ,
668+ estimate_runtime ,
669+ )
670+
664671 try :
665672 if device == "cpu" :
666673 return Latency (
0 commit comments