-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathbenchmark_indexing.py
More file actions
64 lines (49 loc) · 1.5 KB
/
benchmark_indexing.py
File metadata and controls
64 lines (49 loc) · 1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
import gc
import os
import torch
import triton
from qwen3_moe_fused.kernels.indexing import (
get_expert_offsets_and_idx_blocks,
get_expert_offsets_and_idx_naive,
get_expert_offsets_and_idx_parallel,
)
# Ask Triton to log the autotuning configs it selects while benchmarking.
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"

# Map a short label to each candidate implementation under test.
providers = dict(
    naive=get_expert_offsets_and_idx_naive,
    parallel=get_expert_offsets_and_idx_parallel,
    blocks=get_expert_offsets_and_idx_blocks,
)
provider_names = list(providers.keys())
@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 21)],  # token counts 1K .. 1M
            line_arg="provider",
            line_vals=provider_names,
            line_names=provider_names,
            ylabel="GB/s",
            plot_name="indexing",
            args={},
        )
    ]
)
def benchmark(N, provider):
    """Benchmark one expert-indexing implementation at a given input size.

    Args:
        N: number of token-to-expert assignments in the input tensor.
        provider: key into the module-level ``providers`` dict selecting
            which implementation to time.

    Returns:
        A ``(median, lower, upper)`` triple of throughput in GB/s, as
        expected by ``triton.testing.perf_report``.
    """
    print("N", N, "provider", provider, "begin")
    # Reduce noise from leftover allocations of previous benchmark runs.
    gc.collect()
    torch.cuda.empty_cache()
    E = 128  # number of experts; matches the randint upper bound below
    device = "cuda"
    dtype = torch.int32
    s = torch.randint(0, E, (N,), device=device, dtype=dtype)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: providers[provider](s, E), warmup=100, rep=1000, quantiles=quantiles
    )

    def gbps(elapsed_ms):
        # 3 * N * 4: presumably three int32 arrays of length N are touched
        # (input plus two outputs) — TODO confirm against the kernels.
        # ms -> GB/s: bytes / (ms * 1e6).
        return 3 * N * 4 / elapsed_ms * 1e-6

    print("N", N, "E", E, "provider", provider, "end", gbps(ms))
    # Throughput is monotone decreasing in time, so the 0.8-quantile time
    # (max_ms) yields the LOWER throughput bound and the 0.2-quantile time
    # (min_ms) the UPPER one — hence the swapped order in the return.
    return gbps(ms), gbps(max_ms), gbps(min_ms)
if __name__ == "__main__":
    # Run the benchmark suite; inference_mode disables autograd bookkeeping
    # so timings reflect only the kernels under test.
    with torch.inference_mode():
        benchmark.run(print_data=True, save_path="./")