
Commit 0205245

[warmup] Change default warmup and rep time to adaptive (#1040)
1 parent ec05cc8 commit 0205245

7 files changed

Lines changed: 113 additions & 54 deletions

File tree

tritonbench/components/do_bench/run.py
tritonbench/components/do_bench/utils.py
tritonbench/components/kineto/trace.py
tritonbench/components/ncu/__init__.py
tritonbench/utils/constants.py
tritonbench/utils/parser.py
tritonbench/utils/triton_op.py

tritonbench/components/do_bench/run.py

Lines changed: 19 additions & 12 deletions
@@ -14,6 +14,7 @@
 from .common import summarize_statistics
 from .gpu_events import do_bench_events
 from .power import do_bench_power
+from .utils import estimate_cuda_runtime_ms, resolve_warmup_and_rep

 NS_TO_MS = 1e-6
 logger = logging.getLogger(__name__)
@@ -219,6 +220,7 @@ def _do_bench_cudagraph_with_cache_clear(
         end_event.record()
         torch.cuda.synchronize()
         estimate_ms = start_event.elapsed_time(end_event) / 5
+        _, rep = resolve_warmup_and_rep(None, rep, estimate_ms)

         n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
@@ -301,13 +303,11 @@ def _do_bench_profiler(
         else None
     )

-    # First, estimate the runtime to calculate iterations
-    estimate_ms = triton.testing.do_bench(
+    clear_cache_fn = cache.zero_ if not skip_cache_clearing else lambda *args: None
+    estimate_ms = estimate_cuda_runtime_ms(
         fn,
-        warmup=warmup,
-        rep=rep,
         grad_to_none=grad_to_none,
-        return_mode="mean",
+        clear_cache_fn=clear_cache_fn,
     )

     # Calculate number of iterations based on target rep time
@@ -316,8 +316,6 @@ def _do_bench_profiler(
     else:
         n_repeat = max(1, int(rep / estimate_ms))

-    clear_cache_fn = cache.zero_ if not skip_cache_clearing else lambda *args: None
-
     # Helper function to execute one iteration
     def run_iteration():
         if grad_to_none is not None:
@@ -432,7 +430,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:


 def _do_bench_cpu(
-    fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
+    fn, warmup, rep, grad_to_none=None, quantiles=None, return_mode="mean"
 ):
     """Measure latency of a function on CPU."""
     assert return_mode in ["min", "max", "mean", "median", "all"]
@@ -474,8 +472,8 @@ def _do_bench_cpu(

 def _do_bench_entropy(
     fn,
-    warmup=25,
-    rep=100,
+    warmup,
+    rep,
     grad_to_none=None,
     quantiles=None,
     return_mode="mean",
@@ -528,6 +526,7 @@ def _do_bench_entropy(
     precision_increase = False

     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+    clear_cache_fn = lambda: triton.runtime.driver.active.clear_cache(cache)

     # Adaptive warmup loop with batched synchronization
     while True:
@@ -545,7 +544,7 @@ def _do_bench_entropy(
             if grad_to_none is not None:
                 for x in grad_to_none:
                     x.grad = None
-            triton.runtime.driver.active.clear_cache(cache)
+            clear_cache_fn()
             batch_start_events[i].record()
             fn()
             batch_end_events[i].record()
@@ -619,7 +618,7 @@ def _do_bench_entropy(
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
-        triton.runtime.driver.active.clear_cache(cache)
+        clear_cache_fn()
         start_events[i].record()
         fn()
         end_events[i].record()
@@ -661,6 +660,14 @@ def do_bench_wrapper(
         entropy_window_size: Size of rolling window for entropy tracking
         entropy_max_samples: Maximum samples before stopping warmup (safety limit)
     """
+    if (warmup is None or rep is None) and not repcnt:
+        estimate_runtime = estimate_cuda_runtime_ms(fn, grad_to_none=grad_to_none)
+        warmup, rep = resolve_warmup_and_rep(
+            warmup,
+            rep,
+            estimate_runtime,
+        )
+
     try:
         if device == "cpu":
             return Latency(

tritonbench/components/do_bench/utils.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from typing import Callable, Iterable, Optional, Tuple
+
+import torch
+from tritonbench.utils.constants import DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS
+
+
+def resolve_warmup_and_rep(
+    warmup: Optional[int], rep: Optional[int], estimate_ms: float
+) -> Tuple[int, int]:
+    if estimate_ms <= 1:
+        default_warmup, default_rep = DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS["1"]
+    elif estimate_ms <= 10:
+        default_warmup, default_rep = DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS["10"]
+    else:
+        default_warmup, default_rep = DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS["100"]
+    return (
+        default_warmup if warmup is None else warmup,
+        default_rep if rep is None else rep,
+    )
+
+
+def estimate_cuda_runtime_ms(
+    fn: Callable,
+    grad_to_none: Optional[Iterable[torch.Tensor]] = None,
+    clear_cache_fn: Optional[Callable[[], None]] = None,
+    iters: int = 5,
+    prime: bool = True,
+) -> float:
+    clear_cache_fn = clear_cache_fn or (lambda: None)
+
+    def run_once() -> None:
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        clear_cache_fn()
+        fn()
+
+    if prime:
+        run_once()
+    torch.cuda.synchronize()
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(iters):
+        run_once()
+    end_event.record()
+    torch.cuda.synchronize()
+    return start_event.elapsed_time(end_event) / iters
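
For orientation, a minimal sketch of how the two new helpers compose; the matmul workload and tensor sizes below are illustrative stand-ins, not code from this commit:

import torch

from tritonbench.components.do_bench.utils import (
    estimate_cuda_runtime_ms,
    resolve_warmup_and_rep,
)

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
fn = lambda: torch.mm(a, b)

# Average a few timed calls to get an estimated per-call latency in ms.
estimate_ms = estimate_cuda_runtime_ms(fn)

# None means "pick a default for me"; explicit values pass through unchanged.
warmup_ms, rep_ms = resolve_warmup_and_rep(None, None, estimate_ms)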

tritonbench/components/kineto/trace.py

Lines changed: 8 additions & 13 deletions
@@ -8,6 +8,10 @@

 import torch
 import torch.profiler as profiler
+from tritonbench.components.do_bench.utils import (
+    estimate_cuda_runtime_ms,
+    resolve_warmup_and_rep,
+)
 from tritonbench.utils.constants import DEFAULT_N_REP, DEFAULT_N_WARMUP
 from tritonbench.utils.env_utils import has_manifold

@@ -122,8 +126,8 @@ def do_bench_kineto_cudagraph(

 def do_bench_kineto(
     fn: Callable,
-    warmup: int,
-    rep: int,
+    warmup: Optional[int],
+    rep: Optional[int],
     grad_to_none=None,
     fast_flush=True,
     profile_opts=None,
@@ -154,7 +158,6 @@ do_bench_kineto(

     fn()
     torch.cuda.synchronize()
-
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
@@ -167,16 +170,8 @@ do_bench_kineto(
     else:
         clear_cache = lambda *args: None

-    # Estimate the runtime of the function
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        clear_cache()
-        fn()
-    end_event.record()
-    torch.cuda.synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
+    estimate_ms = estimate_cuda_runtime_ms(fn, clear_cache_fn=clear_cache)
+    warmup, rep = resolve_warmup_and_rep(warmup, rep, estimate_ms)

     # Calculate number of iterations based on target rep time
     if estimate_ms == 0:

tritonbench/components/ncu/__init__.py

Lines changed: 8 additions & 12 deletions
@@ -1,6 +1,10 @@
 from typing import Callable

 import torch
+from tritonbench.components.do_bench.utils import (
+    estimate_cuda_runtime_ms,
+    resolve_warmup_and_rep,
+)


 class cuda_profiler_range:
@@ -40,20 +44,12 @@ def do_bench_in_task(

     cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

-    if warmup:
-        # Estimate the runtime of the function
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        start_event.record()
-        for _ in range(5):
-            cache.zero_()
-            fn()
-        end_event.record()
-        torch.cuda.synchronize()
-        estimate_ms = start_event.elapsed_time(end_event) / 5
+    if warmup == True:
+        estimate_ms = estimate_cuda_runtime_ms(fn, clear_cache_fn=cache.zero_)
+        warmup, _ = resolve_warmup_and_rep(warmup, None, estimate_ms)

         # compute number of warmup and repeat
-        n_warmup = max(1, int(warmup / estimate_ms))
+        n_warmup = 1 if estimate_ms == 0 else max(1, int(warmup / estimate_ms))
         # Warm-up
         for _ in range(n_warmup):
             fn()

tritonbench/utils/constants.py

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,10 @@
-DEFAULT_WARMUP = 3000
-DEFAULT_REP = 3000
+from typing import Dict, Tuple
+
+DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS: Dict[str, Tuple[int, int]] = {
+    "1": (100, 100),
+    "10": (1000, 1000),
+    "100": (3000, 3000),
+}
 DEFAULT_POWER_REPCNT = 2000
 DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
 DEFAULT_SLEEP = 0.0
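
Read alongside resolve_warmup_and_rep above, the string keys act as upper bounds on the estimated per-call kernel latency in milliseconds. A worked lookup, shown only for illustration and not part of the commit:

from tritonbench.utils.constants import DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS

# A kernel estimated at ~0.4 ms lands in the "1" bucket: 100 ms warmup, 100 ms rep.
fast_warmup_ms, fast_rep_ms = DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS["1"]

# Anything slower than 10 ms falls through to the "100" bucket, which keeps the
# previous fixed defaults of 3000 ms warmup and 3000 ms rep.
slow_warmup_ms, slow_rep_ms = DEFAULT_WARMUP_REP_BY_ESTIMATED_KERNEL_MS["100"]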

tritonbench/utils/parser.py

Lines changed: 4 additions & 6 deletions
@@ -7,8 +7,6 @@
     DEFAULT_ENTROPY_MAX_SAMPLES,
     DEFAULT_ENTROPY_MIN_R2,
     DEFAULT_ENTROPY_WINDOW_SIZE,
-    DEFAULT_REP,
-    DEFAULT_WARMUP,
 )
 from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
 from tritonbench.utils.gpu_utils import get_gpu_device_name
@@ -76,14 +74,14 @@ def get_parser(args=None):
     parser.add_argument(
         "--warmup",
         type=int,
-        default=DEFAULT_WARMUP,
-        help="Num of warmup runs for each benchmark run.",
+        default=None,
+        help="Warmup time in ms for each benchmark run. Default: auto by estimated kernel latency.",
     )
     parser.add_argument(
         "--rep",
         type=int,
-        default=DEFAULT_REP,
-        help="The rep time for each benchmark run.",
+        default=None,
+        help="Target measurement time in ms for each benchmark run. Default: auto by estimated kernel latency.",
     )
     parser.add_argument(
         "--autotune-warmup",

tritonbench/utils/triton_op.py

Lines changed: 18 additions & 9 deletions
@@ -29,15 +29,19 @@
 from torch.utils._pytree import tree_map
 from triton.runtime.errors import OutOfResources as TritonOutOfResources
 from tritonbench.components.do_bench import do_bench_wrapper, Latency
+from tritonbench.components.do_bench.utils import (
+    estimate_cuda_runtime_ms,
+    resolve_warmup_and_rep,
+)
 from tritonbench.components.export import export_data
 from tritonbench.components.power import PowerManagerTask
 from tritonbench.data import get_input_loader
 from tritonbench.utils.constants import (
+    DEFAULT_N_REP,
+    DEFAULT_N_WARMUP,
     DEFAULT_POWER_REPCNT,
     DEFAULT_QUANTILES,
-    DEFAULT_REP,
     DEFAULT_SLEEP,
-    DEFAULT_WARMUP,
 )
 from tritonbench.utils.cudagraph_utils import CudaGraphConfig, CudaGraphError
 from tritonbench.utils.diode_utils import (
@@ -159,7 +163,7 @@ def __exit__(self, *args, **kwargs):
         self.elapsed_ms = (end_time - self._start_time) * 1e3


-def do_bench_walltime(fn, warmup=25, rep=DEFAULT_REP):
+def do_bench_walltime(fn, warmup=None, rep=None):
     fn()
     torch.cuda.synchronize()

@@ -168,10 +172,15 @@ def do_bench_walltime(fn, warmup=25, rep=DEFAULT_REP):
             fn()
         torch.cuda.synchronize()
     estimate_ms = timer.elapsed_ms / 5
+    warmup, rep = resolve_warmup_and_rep(warmup, rep, estimate_ms)

     # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
+    if estimate_ms == 0:
+        n_warmup = DEFAULT_N_WARMUP
+        n_repeat = DEFAULT_N_REP
+    else:
+        n_warmup = max(1, int(warmup / estimate_ms))
+        n_repeat = max(1, int(rep / estimate_ms))

     # Warm-up
     for _ in range(n_warmup):
@@ -1034,8 +1043,8 @@ def benchmark_fn():

     def run(
         self,
-        warmup=DEFAULT_WARMUP,
-        rep=DEFAULT_REP,
+        warmup: int | None = None,
+        rep: int | None = None,
         quantiles=DEFAULT_QUANTILES,
         sleep=DEFAULT_SLEEP,
     ) -> None:
@@ -1901,8 +1910,8 @@ def _do_bench(
         self,
         input_id: int,
         fn_name: str,
-        warmup=DEFAULT_WARMUP,
-        rep=DEFAULT_REP,
+        warmup: int | None,
+        rep: int | None,
         repcnt=None,
         quantiles=DEFAULT_QUANTILES,
         baseline: bool = False,
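
A minimal usage sketch of do_bench_walltime under the new signature; the softmax workload is an illustrative stand-in, and with warmup and rep left as None they are resolved from the 5-iteration estimate as in the hunk above:

import torch

from tritonbench.utils.triton_op import do_bench_walltime

x = torch.randn(8192, 8192, device="cuda")
fn = lambda: torch.softmax(x, dim=-1)

# warmup/rep default to None, so the measured estimate picks the
# (100/1000/3000 ms) bucket before the wall-clock timing loop runs.
measured = do_bench_walltime(fn)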
