Commit f84ac1c

fix: Fix memory bandwidth calculation in MLA benchmarks (#2479)
## 📌 Description

Summary

* Fixed incorrect memory bandwidth calculation in `testBatchMLAPagedAttentionWrapper` that was using full tensor allocations instead of the actual bytes accessed based on sequence lengths
* Updated `bench_trtllm_gen_mla.py` to use the unified `bench_gpu_time()` utility with CUPTI for consistent timing with the benchmark framework

cc @hypdeb

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Chores**
  * Improved benchmarking: switched to CUDA/CUPTI-based timing with refined iteration controls (dry-run and repeat by iterations) and optional CUDA graph support.
  * Updated performance reporting to use explicit memory accounting from actual token usage (query, KV, output), and adjusted bandwidth and FLOPs printouts for clearer, more accurate throughput metrics.
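To make the fix concrete outside the diff context, here is a minimal sketch of the accounting change with made-up shapes (the sizes and `seq_lens` below are illustrative assumptions, not the benchmark's configuration): with a padded, paged KV cache, counting allocated bytes overstates traffic, because only the tokens covered by each request's sequence length are actually read.

```python
import torch

# Illustrative sizes only -- assumptions, not the benchmark's real configuration.
batch_size, max_seq_len, head_dim = 4, 8192, 576
seq_lens = [1024, 2048, 512, 4096]  # tokens actually attended per request
kv_cache = torch.zeros(batch_size, max_seq_len, head_dim, dtype=torch.bfloat16)

# Old-style accounting: every allocated byte, padding included.
alloc_bytes = kv_cache.numel() * kv_cache.element_size()

# Fixed accounting: only the tokens the kernel actually touches.
accessed_bytes = sum(seq_lens) * head_dim * kv_cache.element_size()

print(f"allocated: {alloc_bytes / 1e6:.1f} MB, accessed: {accessed_bytes / 1e6:.1f} MB")
```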
1 parent 6ae5bfe commit f84ac1c

2 files changed

Lines changed: 43 additions & 24 deletions

benchmarks/bench_trtllm_gen_mla.py

Lines changed: 25 additions & 12 deletions
```diff
@@ -2,7 +2,7 @@
 import torch
 
 import flashinfer
-from flashinfer.testing.utils import bench_gpu_time_with_cudagraph
+from flashinfer.testing.utils import bench_gpu_time
 
 num_q_heads = 128
 qk_nope_head_dim = 128
@@ -83,7 +83,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
         bmm2_scale=1.0,
     )
     # benchmark
-    measurements = bench_gpu_time_with_cudagraph(
+    measurements = bench_gpu_time(
         lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache.unsqueeze(1),
@@ -97,27 +97,40 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
             bmm1_scale=1.0 / ((128 + 64) ** 0.5),
             bmm2_scale=1.0,
         ),
-        dry_run_time_ms=100,
-        repeat_time_ms=1000,
-    )
-    io = (
-        query.numel() * query.element_size()
-        + kv_cache.numel() * kv_cache.element_size()
+        dry_run_iters=5,
+        repeat_iters=30,
+        enable_cupti=False,
+        use_cuda_graph=True,
+        cold_l2_cache=True,
     )
     ms = np.median(measurements)
+
+    # Memory bandwidth calculation based on actual bytes accessed
+    elem_size = query.element_size()
+    # Query bytes: batch_size * q_len_per_request * num_heads * head_dim
+    q_mem_bytes = query.numel() * elem_size
+    # KV cache bytes: actual tokens accessed (sum of seq_lens), not full allocation
+    actual_kv_tokens = sum(seq_lens)
+    kv_mem_bytes = actual_kv_tokens * (kv_lora_rank + qk_rope_head_dim) * elem_size
+    # Output bytes: batch_size * q_len_per_request * num_heads * kv_lora_rank
+    o_mem_bytes = (
+        batch_size * q_len_per_request * num_q_heads * kv_lora_rank * elem_size
+    )
+    total_mem_bytes = q_mem_bytes + kv_mem_bytes + o_mem_bytes
+
     flops = (
         2
         * num_q_heads
         * (2 * kv_lora_rank + qk_rope_head_dim)
-        * sum(seq_lens)
+        * actual_kv_tokens
         * q_len_per_request
     )
     print(
         f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}"
     )
-    print(f"execution time: {ms} ms")
-    print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")
-    print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")
+    print(f"execution time: {ms:.4f} ms")
+    print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s")
+    print(f"FLOPs: {flops / ms / 1e9:.2f} TFLOPs/s")
 
 
 if __name__ == "__main__":
```
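As a sanity check on the updated printouts, the snippet below reruns the same arithmetic standalone with hypothetical inputs (the values of `ms`, `seq_lens`, and the dimensions are assumptions, not measurements), assuming the query head dimension is `kv_lora_rank + qk_rope_head_dim`:

```python
# Hypothetical inputs -- assumptions for illustration, not measured data.
num_q_heads, kv_lora_rank, qk_rope_head_dim = 128, 512, 64
batch_size, q_len_per_request = 16, 1
seq_lens = [4096] * batch_size
elem_size = 2  # bytes per element, e.g. fp16/bf16
ms = 0.25      # assumed median kernel time in milliseconds

q_mem_bytes = (
    batch_size
    * q_len_per_request
    * num_q_heads
    * (kv_lora_rank + qk_rope_head_dim)
    * elem_size
)
actual_kv_tokens = sum(seq_lens)
kv_mem_bytes = actual_kv_tokens * (kv_lora_rank + qk_rope_head_dim) * elem_size
o_mem_bytes = batch_size * q_len_per_request * num_q_heads * kv_lora_rank * elem_size
total_mem_bytes = q_mem_bytes + kv_mem_bytes + o_mem_bytes

flops = (
    2
    * num_q_heads
    * (2 * kv_lora_rank + qk_rope_head_dim)
    * actual_kv_tokens
    * q_len_per_request
)

# bytes / ms / 1e6 gives GB/s; flops / ms / 1e9 gives TFLOPs/s (same conversions as above).
print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s")
print(f"FLOPs: {flops / ms / 1e9:.2f} TFLOPs/s")
```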

benchmarks/routines/attention.py

Lines changed: 18 additions & 12 deletions
```diff
@@ -2244,20 +2244,26 @@ def run_backend_wrapper(
         actual_seq_lens_q_flat = torch.ones_like(
             actual_seq_lens_kv.flatten().to("cpu")
         )
-        o_mem_bytes = (
-            actual_seq_lens_q_flat.numel()
-            * num_qo_heads
-            * head_dim_ckv
-            * q_dtype.itemsize
+
+        # Query bytes (q_nope + q_pe): batch_size * num_heads * head_dim
+        q_mem_bytes = (
+            q_nope.numel() * q_nope.element_size()
+            + q_pe.numel() * q_pe.element_size()
         )
-        qkv_mem_bytes = sum(
-            [
-                _.numel() * _.element_size()
-                for _ in [q_nope, q_pe, ckv_cache, kpe_cache]
-            ]
+
+        # KV cache bytes: based on actual sequence lengths accessed, not full allocation
+        actual_kv_tokens = actual_seq_lens_kv_flat.sum().item()
+        kv_elem_size = ckv_cache.element_size()  # Same dtype for ckv and kpe
+        kv_mem_bytes = (
+            actual_kv_tokens * (head_dim_ckv + head_dim_kpe) * kv_elem_size
         )
-        total_mem_bytes = o_mem_bytes + qkv_mem_bytes
-        tb_per_sec = (total_mem_bytes / (median_time * 1e9)).item()
+
+        # Output bytes: batch_size * num_heads * head_dim_ckv
+        o_elem_size = q_nope.element_size()  # Output has same dtype as query
+        o_mem_bytes = batch_size * num_qo_heads * head_dim_ckv * o_elem_size
+
+        total_mem_bytes = q_mem_bytes + kv_mem_bytes + o_mem_bytes
+        tb_per_sec = total_mem_bytes / (median_time * 1e9)
         tflops_total = (
             2
             * torch.dot(
```
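The routine-level change follows the same three-term pattern. The sketch below exercises it as a self-contained calculation; the shapes, dtypes, and `median_time` value (assumed to be in milliseconds, so dividing by `median_time * 1e9` yields TB/s) are illustrative assumptions, not the harness's actual inputs:

```python
import torch

# Assumed shapes and dtypes -- illustrative only, not the routine's real inputs.
batch_size, num_qo_heads = 8, 128
head_dim_ckv, head_dim_kpe = 512, 64
q_nope = torch.zeros(batch_size, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16)
q_pe = torch.zeros(batch_size, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16)
actual_seq_lens_kv_flat = torch.tensor([1024, 2048, 512, 4096, 256, 8192, 1024, 2048])
kv_elem_size = 2     # assumed bf16 cache entries, same dtype for ckv and kpe
median_time = 0.15   # assumed median latency in milliseconds

# Query bytes: q_nope and q_pe are both read once.
q_mem_bytes = q_nope.numel() * q_nope.element_size() + q_pe.numel() * q_pe.element_size()

# KV cache bytes: only the tokens actually accessed, not the full allocation.
actual_kv_tokens = actual_seq_lens_kv_flat.sum().item()
kv_mem_bytes = actual_kv_tokens * (head_dim_ckv + head_dim_kpe) * kv_elem_size

# Output bytes: one head_dim_ckv-wide vector per query head per request.
o_mem_bytes = batch_size * num_qo_heads * head_dim_ckv * q_nope.element_size()

total_mem_bytes = q_mem_bytes + kv_mem_bytes + o_mem_bytes
tb_per_sec = total_mem_bytes / (median_time * 1e9)  # bytes / (ms * 1e9) == TB/s
print(f"{tb_per_sec:.3f} TB/s")
```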
