[Perf] Add FMHAv2 to flashinfer_benchmark.py and eliminate unnecessary H2D #2841
Changes from all commits
```diff
@@ -22,6 +22,7 @@
     if not is_lib_missing:
         raise
 from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
+from flashinfer.prefill import trtllm_fmha_v2_prefill
 from flashinfer.testing.utils import (
     attention_tb_per_sec_with_actual_seq_lens,
     attention_tflops_per_sec_with_actual_seq_lens,
```
```diff
@@ -111,6 +112,7 @@ def parse_attention_args(line, parser):
             "cutlass",
             "trtllm-gen",
             "trtllm-native",
+            "trtllm-fmha-v2",
             "trtllm-gen-native",  # Deprecated, will be removed in future
             "cute-dsl",
         ],
```
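For context, the new value plugs into the existing backend choices list. A typical invocation would look like the following; the flag names are assumed to match the benchmark's existing CLI, so treat them as an illustration rather than documentation:

```bash
# Hypothetical invocation; --routine/--backends/--refcheck are assumed
# to match flashinfer_benchmark.py's existing argument names.
python3 flashinfer_benchmark.py \
    --routine BatchPrefillWithPagedKVCacheWrapper \
    --backends fa2 trtllm-fmha-v2 \
    --refcheck
```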
```diff
@@ -936,6 +938,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and is_nvfp4_kv:
+        print("[INFO] trtllm-fmha-v2 backend does not support NVFP4. Skipping.")
+        backends.remove("trtllm-fmha-v2")

     if "cutlass" in backends:
         print("[INFO] CUTLASS backend does not support prefill. Skipping.")
```
```diff
@@ -1072,7 +1077,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .to(device)
     )

     # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
+    # Page-based indptr for FlashInfer paged attention (cumulative page counts)
     kv_indptr = (
         torch.cat(
             [
```
```diff
@@ -1086,6 +1091,17 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .int()
         .to(device)
     )
+    # Token-based indptr for TRT-LLM backends (cumulative token counts)
+    kv_token_indptr = (
+        torch.cat(
+            [
+                torch.tensor([0], device=device),
+                torch.cumsum(actual_seq_lens_kv_device.flatten(), dim=0),
+            ]
+        )
+        .int()
+        .to(device)
+    )
     kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
     for i in range(len(kv_indptr) - 1):
         start_idx = kv_indptr[i]
```
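The distinction matters because the FlashInfer paged wrappers index by page while the TRT-LLM FMHA v2 kernel expects cumulative token counts. A self-contained illustration (not the benchmark's exact code; the `page_size` and sequence lengths are made up):

```python
import torch

page_size = 16
actual_seq_lens_kv = torch.tensor([5, 40, 17])

# Page-based indptr (FlashInfer paged wrappers): cumulative page counts.
pages_per_seq = (actual_seq_lens_kv + page_size - 1) // page_size  # [1, 3, 2]
kv_indptr = torch.cat(
    [torch.tensor([0]), torch.cumsum(pages_per_seq, dim=0)]
).int()
# tensor([0, 1, 4, 6], dtype=torch.int32)

# Token-based indptr (TRT-LLM FMHA v2): cumulative token counts.
kv_token_indptr = torch.cat(
    [torch.tensor([0]), torch.cumsum(actual_seq_lens_kv, dim=0)]
).int()
# tensor([ 0,  5, 45, 62], dtype=torch.int32)
```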
```diff
@@ -1158,6 +1174,16 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
         v_quantized, _ = to_float8(v_data, kv_dtype)
         kv_cache = torch.cat([k_quantized, v_quantized], dim=1)

+    _fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0
+
+    # Ensure trtllm-fmha-v2 sees a contiguous HND-physical paged KV cache.
+    # Skip if kv_cache is not a plain Tensor (e.g., an NVFP4 packed tuple);
+    # the backend filter further down also drops trtllm-fmha-v2 in that case.
+    if "trtllm-fmha-v2" in backends and isinstance(kv_cache, torch.Tensor):
+        _fmha_v2_kv_cache = kv_cache.contiguous()
+    else:
+        _fmha_v2_kv_cache = kv_cache
+
     # Prepare wrappers (after FP8 conversion so we have correct dtypes)
     backend_wrappers = {}
     resolved_backends = {}
```
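Hoisting the `.contiguous()` call and the plain-Python scale out of the per-iteration path is what "eliminate unnecessary H2D" in the PR title refers to: neither a layout conversion nor a scalar-tensor upload happens inside the timed region. A minimal sketch of the pattern, with `run_kernel` as a hypothetical stand-in:

```python
import torch

def prepare_once(kv_cache: torch.Tensor, v_scale):
    # One-time host-side prep, done before the timing loop starts:
    # fix the layout and resolve the scale to a Python float so no
    # H2D transfer or copy happens per benchmark iteration.
    kv_contig = kv_cache.contiguous()
    bmm2_scale = v_scale if v_scale is not None else 1.0
    return kv_contig, bmm2_scale

def timed_iteration(run_kernel, q, kv_contig, bmm2_scale):
    # The measured closure only launches the kernel.
    return run_kernel(q, kv_contig, bmm2_scale=bmm2_scale)
```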
```diff
@@ -1304,6 +1330,25 @@ def run_backend_wrapper(
                 v_scale=v_scale_tensor,
                 o_data_type=o_data_type,
             )[0]
+        elif backend == "trtllm-fmha-v2":
+            _q_scale = q_scale if q_scale is not None else 1.0
+            _k_scale = k_scale if k_scale is not None else 1.0
+            return trtllm_fmha_v2_prefill(
+                qkv=(q, _fmha_v2_kv_cache),
+                input_layout="Q_PAGED_KV_HND",
+                workspace_buffer=workspace_buffer,
+                seq_lens=actual_seq_lens_kv_device.flatten(),
+                max_q_len=s_qo,
+                max_kv_len=s_kv,
+                bmm1_scale=_q_scale * _k_scale * scale,
+                bmm2_scale=_fmha_v2_bmm2_scale,
+                batch_size=batch_size,
+                cum_seq_lens_q=qo_indptr,
+                cum_seq_lens_kv=kv_token_indptr,
+                block_tables=block_tables,
+                mask_mode="causal" if causal else "padding",
+                out_dtype=o_data_type,
+            )
         else:
             print(f"[ERROR] Backend {backend} not supported")
             return None
```
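The `bmm1_scale`/`bmm2_scale` arguments fold the FP8 dequantization scales into the two attention GEMMs, under the usual convention softmax(q_scale · k_scale · QKᵀ / √d) · V · v_scale. A worked example of the arithmetic:

```python
import math

head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)

# Per-tensor dequant scales; these default to 1.0 when inputs are not FP8.
q_scale, k_scale, v_scale = 0.5, 0.25, 2.0

bmm1_scale = q_scale * k_scale * sm_scale  # folded into the Q·K^T GEMM
bmm2_scale = v_scale                       # applied after the P·V GEMM
print(bmm1_scale, bmm2_scale)  # ~0.01105, 2.0
```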
```diff
@@ -1366,9 +1411,15 @@ def run_backend_wrapper(
     tested_outputs = list(outputs.values())

     # When cases where FA2 is not available, try to find an alternative reference
-    # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native
+    # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native > trtllm-fmha-v2
     if run_refcheck and not has_reference_output and len(tested_backends) > 1:
-        reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"]
+        reference_priority = [
+            "cudnn",
+            "cudnn-native",
+            "trtllm-gen",
+            "trtllm-native",
+            "trtllm-fmha-v2",
+        ]
         for candidate in reference_priority:
             if candidate in tested_backends:
                 has_reference_output = True
```
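A self-contained sketch of the fallback selection, assuming FA2 was not among the tested backends:

```python
# Pick the first backend from the priority list that actually ran;
# it then serves as the refcheck baseline in place of FA2.
reference_priority = [
    "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native", "trtllm-fmha-v2",
]
tested_backends = ["trtllm-gen", "trtllm-fmha-v2"]  # example: no FA2 run

reference = next((b for b in reference_priority if b in tested_backends), None)
print(reference)  # "trtllm-gen" becomes the comparison baseline
```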
```diff
@@ -1598,6 +1649,12 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
+        print(
+            "[INFO] trtllm-fmha-v2 backend does not support FP8 e4m3 with "
+            "SEPARATE_Q_K_V layout. Skipping."
+        )
+        backends.remove("trtllm-fmha-v2")

     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
```
```diff
@@ -1836,6 +1893,8 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         k = (k / k_scale).to(kv_dtype)
         v = (v / v_scale).to(kv_dtype)

+    _fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0
+
     trtllm_out = None
     if "trtllm-native" in backends or "cute-dsl" in backends:
         # cute-dsl varlen kernel uses negative pointer offsets on output,
```
```diff
@@ -1944,6 +2003,24 @@ def run_backend_wrapper(
                 return_lse=True,
                 out=trtllm_out,
             )[0]
+        elif backend == "trtllm-fmha-v2":
+            _q_scale = q_scale if q_scale is not None else 1.0
+            _k_scale = k_scale if k_scale is not None else 1.0
+            return trtllm_fmha_v2_prefill(
+                qkv=(q, k, v),
+                input_layout="SEPARATE_Q_K_V",
+                workspace_buffer=workspace_buffer,
+                seq_lens=actual_seq_lens_kv_device.flatten(),
+                max_q_len=s_qo,
+                max_kv_len=s_kv,
+                bmm1_scale=_q_scale * _k_scale * scale,
+                bmm2_scale=_fmha_v2_bmm2_scale,
+                batch_size=batch_size,
+                cum_seq_lens_q=qo_indptr,
+                cum_seq_lens_kv=kv_indptr,
+                mask_mode="causal" if causal else "padding",
+                out_dtype=out_dtype,
+            )
         else:
             print(f"[ERROR] Backend {backend} not supported")
             return None
```
Comment on lines +2006 to +2023 (Contributor):

Filter out FP8 queries before the ragged FMHA-v2 call.

Suggested guard:

```diff
     if "trtllm-native" in backends:
         remove_trtllm_native = False
         if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
             head_dim_qk == 128 and head_dim_vo == 128
         ):
             print(
                 "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128 or head_dim_qk == 128 and head_dim_vo == 128. Skipping."
             )
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
+        print(
+            "[INFO] trtllm-fmha-v2 does not support FP8 query with SEPARATE_Q_K_V. Skipping."
+        )
+        backends.remove("trtllm-fmha-v2")
```
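To summarize the two integration points, here is a hedged sketch contrasting the paged and ragged call shapes; the keyword arguments mirror the diff above and are not independently verified against the `trtllm_fmha_v2_prefill` signature:

```python
from flashinfer.prefill import trtllm_fmha_v2_prefill

def paged_prefill(q, kv_cache, kv_token_indptr, block_tables, **common):
    # Paged KV: Q plus the page-table-backed cache; the kv indptr must
    # count tokens (not pages), and block_tables maps sequences to pages.
    return trtllm_fmha_v2_prefill(
        qkv=(q, kv_cache),
        input_layout="Q_PAGED_KV_HND",
        cum_seq_lens_kv=kv_token_indptr,
        block_tables=block_tables,
        **common,
    )

def ragged_prefill(q, k, v, kv_indptr, **common):
    # Ragged KV: separate Q/K/V tensors; kv_indptr is already token-based,
    # so no second indptr and no block table are needed.
    return trtllm_fmha_v2_prefill(
        qkv=(q, k, v),
        input_layout="SEPARATE_Q_K_V",
        cum_seq_lens_kv=kv_indptr,
        **common,
    )
```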