Commit 566b63c

galexite and abulavin authored
[RELEASE] Cherry-pick use of device-agnostic DeviceInterface in do_bench (#4733)
This PR cherry-picks a device-agnostic change to `do_bench` onto the release/3.1.x branch, using the `DeviceInterface` instead of hard-coding the device type:

* triton-lang/triton@36789d2

This will allow further device-agnostic changes in PyTorch Inductor.

---

The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because it is a cherry-pick of a commit which was landed to `main`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

Co-authored-by: Artemiy Bulavin <43750406+abulavin@users.noreply.github.com>
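The change described above swaps hard-coded `torch.cuda` calls for an interface object looked up by device-type string. A minimal sketch of that lookup pattern is below; the classes and registry here (`CudaInterface`, `XpuInterface`, `_REGISTRY`) are illustrative stand-ins, not PyTorch's real `torch._dynamo.device_interface` implementation.

```python
# Sketch of the DeviceInterface dispatch pattern: a registry maps a
# device-type string to an object exposing device-specific operations,
# so benchmarking code never has to hard-code `torch.cuda`.
# All names here are hypothetical stand-ins for illustration.

class CudaInterface:
    name = "cuda"

    def synchronize(self):
        # A real backend would call torch.cuda.synchronize() here.
        pass


class XpuInterface:
    name = "xpu"

    def synchronize(self):
        # A real backend would call its own runtime's synchronize here.
        pass


_REGISTRY = {"cuda": CudaInterface(), "xpu": XpuInterface()}


def get_interface_for_device(device_type: str):
    """Return the interface registered for `device_type`, mirroring the
    role of torch._dynamo.device_interface.get_interface_for_device."""
    try:
        return _REGISTRY[device_type]
    except KeyError:
        raise ValueError(f"no device interface registered for {device_type!r}")
```

With this shape, adding a new backend means registering one interface object rather than editing every call site that previously named `torch.cuda` directly.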
1 parent fc879bd commit 566b63c

File tree

1 file changed: +13 -10 lines changed

python/triton/testing.py

Lines changed: 13 additions & 10 deletions
```diff
@@ -79,7 +79,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
     return getattr(torch, return_mode)(times).item()
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
+             device_type="cuda"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -100,33 +101,35 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     assert return_mode in ["min", "max", "mean", "median"]
     import torch
 
+    di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+
     fn()
-    torch.cuda.synchronize()
+    di.synchronize()
 
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
     if fast_flush:
-        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)
     else:
-        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+        cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type)
 
     # Estimate the runtime of the function
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
         cache.zero_()
         fn()
     end_event.record()
-    torch.cuda.synchronize()
+    di.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
 
     # compute number of warmup and repeat
     n_warmup = max(1, int(warmup / estimate_ms))
     n_repeat = max(1, int(rep / estimate_ms))
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
     # Warm-up
     for _ in range(n_warmup):
         fn()
@@ -145,7 +148,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
         fn()
         end_event[i].record()
     # Record clocks
-    torch.cuda.synchronize()
+    di.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
```
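The patched function keeps `do_bench`'s three-phase structure: estimate the runtime from a few calls, derive warmup and repeat counts from the `warmup`/`rep` millisecond budgets, then time each repetition. The self-contained sketch below reproduces that control flow on CPU, using `time.perf_counter` in place of the device `Event` objects; it is an illustration of the loop structure, not Triton's actual code (and it omits the L2-flush and gradient-clearing steps).

```python
# Minimal sketch of do_bench's measurement structure, using wall-clock
# timing on CPU instead of device events. Hypothetical helper, for
# illustration only.
import time


def bench_sketch(fn, warmup=25, rep=100):
    fn()  # one call up front, so lazy initialization is not timed
    # Phase 1: estimate runtime from 5 calls
    t0 = time.perf_counter()
    for _ in range(5):
        fn()
    estimate_ms = (time.perf_counter() - t0) * 1000 / 5
    # Phase 2: derive iteration counts from the ms budgets
    # (the epsilon guards against division by zero for very fast fns)
    n_warmup = max(1, int(warmup / max(estimate_ms, 1e-9)))
    n_repeat = max(1, int(rep / max(estimate_ms, 1e-9)))
    for _ in range(n_warmup):
        fn()
    # Phase 3: time each repetition and reduce
    times_ms = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000)
    return sum(times_ms) / len(times_ms)  # mean runtime in ms
```

Note how the budgets work: a function estimated at 1 ms with `rep=100` gets roughly 100 timed repetitions, while a 50 ms function gets only 2, keeping total benchmark time roughly constant.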

0 commit comments