-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy pathattention.py
More file actions
2526 lines (2336 loc) · 90.7 KB
/
attention.py
File metadata and controls
2526 lines (2336 loc) · 90.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from collections import defaultdict
import numpy as np
import torch
import flashinfer
import flashinfer.decode
# Try to import cudnn for version checking.
# The flags below are probed once at import time; CUDNN_BACKEND_VERSION is
# later compared against 91701 to decide whether cuDNN FP8 prefill can run.
CUDNN_AVAILABLE = False
CUDNN_BACKEND_VERSION = 0
try:
    import cudnn

    # NOTE: the flag is set before backend_version() is called, so if that call
    # raises a tolerated OSError, CUDNN_AVAILABLE stays True while the version
    # stays 0 (which still fails any ">= version" gate).
    CUDNN_AVAILABLE = True
    CUDNN_BACKEND_VERSION = cudnn.backend_version()
except ImportError:
    # Python bindings not installed: keep the defaults above.
    pass
except OSError as e:
    # A missing/unloadable shared library (message mentions ".so" or ".dll")
    # is treated like "not installed"; any other OS error is re-raised.
    error_msg = str(e).lower()
    is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"])
    if not is_lib_missing:
        raise
from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
bench_gpu_time,
)
from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
print_perf_metrics,
is_close_stats,
filter_backends_by_compute_capability,
)
def normalize_backends(backends):
    """
    Replace deprecated backend names with their current names.

    The only alias handled today is ``"trtllm-gen-native"``, which is rewritten
    to ``"trtllm-native"``; a deprecation warning is printed each time it is
    seen. Every other name passes through unchanged and order is preserved.

    Args:
        backends: Iterable of backend name strings.

    Returns:
        A new list of backend names with deprecated aliases replaced.
    """
    renamed = []
    for name in backends:
        if name != "trtllm-gen-native":
            renamed.append(name)
            continue
        print(
            "[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. "
        )
        renamed.append("trtllm-native")
    return renamed
def run_attention_test(args):
    """
    Dispatch to the attention benchmark selected by ``args.routine``.

    Args:
        args: Parsed command line arguments; ``args.routine`` names the test.

    Returns:
        list: Result dictionaries produced by the selected test, or an empty
        list when the routine name is not recognized.
    """
    # Lambdas defer the lookup of each test function until it is selected.
    dispatch = {
        "BatchDecodeWithPagedKVCacheWrapper": lambda: testBatchDecodeWithPagedKVCacheWrapper(
            args
        ),
        "BatchPrefillWithPagedKVCacheWrapper": lambda: testBatchPrefillWithPagedKVCacheWrapper(
            args
        ),
        "BatchPrefillWithRaggedKVCacheWrapper": lambda: testBatchPrefillWithRaggedKVCacheWrapper(
            args
        ),
        "BatchMLAPagedAttentionWrapper": lambda: testBatchMLAPagedAttentionWrapper(
            args
        ),
    }
    runner = dispatch.get(args.routine)
    if runner is None:
        print(f"[ERROR] Unsupported routine: {args.routine}")
        return []
    return runner()
def parse_attention_args(line, parser):
    """
    Register attention-specific CLI options on ``parser`` and parse ``line``.

    Args:
        line: Command line arguments passed to ``parser.parse_args``.
        parser: ArgumentParser object already populated with shared arguments
            (this function reads ``args.verbose`` from those shared options).

    Returns:
        Parsed argument namespace. Deprecated backend names in ``args.backends``
        are replaced by their current names.
    """
    parser.add_argument(
        "--backends",
        type=str,
        required=False,
        nargs="+",
        default=["fa2"],
        choices=[
            "fa2",
            "fa2_tc",
            "fa3",
            "auto",
            "cudnn",
            "cudnn-native",
            "cutlass",
            "trtllm-gen",
            "trtllm-native",
            "trtllm-gen-native",  # Deprecated, will be removed in future
            "cute-dsl",
        ],
        help="Kernel backends to test. Default: fa2. backend=auto is only supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.",
    )
    parser.add_argument(
        "--page_size",
        type=int,
        required=False,
        default=0,
        help="Page size for paged attention. Required for paged attention. Ignored for non-paged attention.",
    )
    parser.add_argument(
        "--batch_size", type=int, required=True, help="Batch size of test case."
    )
    parser.add_argument(
        "--s_qo",
        type=int,
        required=False,
        default=1,
        help="Max sequence length of the query. For decode, 1 is standard decode and >1 enables speculative decode on supported backends.",
    )
    parser.add_argument(
        "--s_kv",
        type=int,
        required=True,
        help="Max sequence length of the key and value.",
    )
    parser.add_argument(
        "--num_qo_heads", type=int, required=True, help="Number of query heads."
    )
    parser.add_argument(
        "--num_kv_heads", type=int, required=True, help="Number of key and value heads."
    )
    parser.add_argument(
        "--head_dim_qk",
        type=int,
        required=False,
        help="Head dimension of the query and key for prefill and decode MHA/GQA/MQA.",
    )
    parser.add_argument(
        "--head_dim_vo",
        type=int,
        required=False,
        # Fixed typo: "MQ." -> "MQA."
        help="Head dimension of the value and output for prefill and decode MHA/GQA/MQA.",
    )
    parser.add_argument(
        "--head_dim_ckv",
        type=int,
        required=False,
        help="Head dimension of compressed kv-cache tensor (without rope).",
    )
    parser.add_argument(
        "--head_dim_kpe",
        type=int,
        required=False,
        help="Head dimension of the rope part of the kv-cache tensor.",
    )
    # The decode/prefill routines accept more than bfloat16 (e.g. fp8_e4m3 for
    # the query, fp8_e4m3 / nvfp4 for the KV cache), so the help text no longer
    # claims "only bfloat16".
    parser.add_argument(
        "--q_dtype",
        type=str,
        required=False,
        default="bfloat16",
        help="Data type of the query (e.g. bfloat16, fp8_e4m3). Supported values depend on the routine and backend.",
    )
    parser.add_argument(
        "--kv_dtype",
        type=str,
        required=False,
        default="bfloat16",
        help="Data type of the key and value (e.g. bfloat16, fp8_e4m3, nvfp4). Supported values depend on the routine and backend.",
    )
    parser.add_argument(
        "--out_dtype",
        type=str,
        required=False,
        default=None,
        help="Data type of the output. If not specified, defaults to q_dtype.",
    )
    parser.add_argument(
        "--causal",
        action="store_true",
        default=False,
        help="Causal masking. Note: not padding masking. Only used for prefill tests.",
    )
    parser.add_argument(
        "--random_actual_seq_len",
        action="store_true",
        default=False,
        help="Use random actual sequence lengths for the query and key and value. Random values are generated between 1 and maximum sequence length. If False, use maximum sequence length.",
    )

    args = parser.parse_args(line)

    # Normalize backend names (handle deprecated names)
    args.backends = normalize_backends(args.backends)

    if args.verbose >= 1:
        print(f"[INFO] {args = }")
    return args
def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len):
    """
    Produce one actual sequence length per batch entry.

    With ``random_actual_seq_len`` set, each length is drawn uniformly from
    ``[1, max_seqlen]``; otherwise every entry equals ``max_seqlen``.

    Args:
        max_seqlen: Maximum sequence length.
        batch_size: Number of sequences.
        device: Device on which the tensor is created.
        random_actual_seq_len: Whether to draw the lengths at random.

    Returns:
        int32 tensor of shape ``(batch_size, 1, 1, 1)`` holding the lengths.
    """
    shape = (batch_size, 1, 1, 1)
    if not random_actual_seq_len:
        return torch.full(shape, max_seqlen, device=device, dtype=torch.int32)
    # torch.randint's upper bound is exclusive, hence max_seqlen + 1.
    return torch.randint(
        1, max_seqlen + 1, shape, device=device, dtype=torch.int32
    )
def generate_speculative_causal_mask(batch_size, q_seq_len, device):
"""
Generate a packed causal mask for speculative decode chunks (q_len > 1).
Returned shape is [batch_size, q_seq_len, num_packed_masks_per_token * 2]
with dtype uint16, where num_packed_masks_per_token = ceil(q_seq_len / 32).
Each query row i encodes allowed attention to draft-token columns j <= i
(strictly lower-triangular with diagonal) and masks out j > i.
The innermost dimension stores packed bits (uint32 words reinterpreted as
uint16), matching the mask layout expected by decode APIs.
"""
num_packed_masks_per_token = (q_seq_len + 31) // 32
q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)
causal_bool_mask = kv_indices <= q_indices
padded_seq_len = num_packed_masks_per_token * 32
if padded_seq_len > q_seq_len:
padding = torch.zeros(
q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
)
causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)
causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)
bit_positions = torch.tensor(
[1 << i for i in range(32)], device=device, dtype=torch.int64
)
mask_uint32 = (
(causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
)
mask_uint32 = (
mask_uint32.unsqueeze(0)
.expand(batch_size, q_seq_len, num_packed_masks_per_token)
.contiguous()
)
return mask_uint32.view(torch.uint16)
def testBatchDecodeWithPagedKVCacheWrapper(args):
    """
    Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.

    Supports fa2, fa2_tc, auto, cudnn, trtllm-gen, trtllm-native backends.
    Any other requested backend is reported and skipped.

    This test:
    1. Creates paged KV cache and query tensors
    2. Runs decode attention with different backends
    3. Verifies outputs match the fa2 reference (when refcheck is enabled and
       fa2 is among the tested backends)
    4. Measures performance metrics (TFLOPS, TB/sec)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list: List of dictionaries containing performance results. Entries are
        appended only when ``args.output_path`` is set; the list is empty when
        the configuration is rejected or no backend remains after filtering.
    """
    if args.verbose >= 1:
        print("[INFO] Running testBatchDecodeWithPagedKVCacheWrapper")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    # Basic setup
    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    q_init_dtype = torch.bfloat16
    kv_init_dtype = torch.bfloat16
    rtol = 2e-1
    atol = 1e-2
    res = []

    # Handle different query data types.
    q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
    if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
        print(f"[ERROR] Unsupported q_dtype: {args.q_dtype}")
        return res

    # Handle different KV cache data types.
    is_nvfp4_kv = args.kv_dtype == "nvfp4"
    kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
    if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.uint8]:
        print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
        return res
    # NVFP4 KV requires an FP8 (e4m3) query. Checked here, before any tensors
    # are allocated or wrappers planned, so the rejection is cheap.
    if is_nvfp4_kv and q_dtype != torch.float8_e4m3fn:
        print("[ERROR] NVFP4 KV cache requires --q_dtype fp8_e4m3.")
        return res

    if args.out_dtype is not None:
        print(
            "[WARNING] --out_dtype is not yet supported for BatchDecodeWithPagedKVCacheWrapper; ignoring."
        )

    # Parse and validate backend configurations
    backends = args.backends
    page_size = args.page_size
    batch_size = args.batch_size
    s_qo = args.s_qo
    speculative_decode = s_qo > 1
    s_kv = args.s_kv
    num_qo_heads = args.num_qo_heads
    num_kv_heads = args.num_kv_heads
    head_dim_qk = args.head_dim_qk
    head_dim_vo = args.head_dim_vo if args.head_dim_vo is not None else head_dim_qk
    is_cuda_graph_compatible = not args.no_cuda_graph
    # return_lse = not args.no_lse # TO-DO: Add support for this
    run_refcheck = args.refcheck

    # The parser defaults --page_size to 0 and makes --head_dim_qk optional,
    # but this routine divides by page_size and sizes every tensor with
    # head_dim_qk, so reject those configurations up front.
    if page_size <= 0:
        print(
            f"[ERROR] --page_size must be a positive integer for paged decode; got {page_size}."
        )
        return res
    if head_dim_qk is None:
        print(
            "[ERROR] --head_dim_qk is required for BatchDecodeWithPagedKVCacheWrapper."
        )
        return res

    backends = filter_backends_by_compute_capability(backends, args.routine, device)

    # run_backend_wrapper below can only launch these backends. Anything else
    # (e.g. fa3, cutlass, cudnn-native, cute-dsl) would return None, crashing
    # the refcheck's .detach() or being benchmarked as an error-printing no-op.
    supported_backends = (
        "fa2",
        "fa2_tc",
        "auto",
        "cudnn",
        "trtllm-gen",
        "trtllm-native",
    )
    for unsupported_backend in [b for b in backends if b not in supported_backends]:
        print(
            f"[INFO] {unsupported_backend} backend is not supported for decode. Skipping."
        )
        backends.remove(unsupported_backend)

    # Check for backend-specific constraints
    if "fa2" in backends:
        remove_fa2 = False
        if speculative_decode:
            print("[INFO] FA2 backend does not support speculative decode. Skipping.")
            remove_fa2 = True
        head_grp_size = (
            num_qo_heads // num_kv_heads
        )  # If 5, FA2 backend is not supported.
        if head_grp_size == 5:
            print(
                "[INFO] FA2 backend is not supported for this configuration. Skipping."
            )
            remove_fa2 = True
        if remove_fa2:
            backends.remove("fa2")
    if "fa2_tc" in backends:
        remove_fa2_tc = False
        if speculative_decode:
            print(
                "[INFO] FA2_TC backend does not support speculative decode. Skipping."
            )
            remove_fa2_tc = True
        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ]:
            print("[INFO] FA2_TC backend does not support FP8. Skipping.")
            remove_fa2_tc = True
        if remove_fa2_tc:
            backends.remove("fa2_tc")
    if "cudnn" in backends:
        remove_cudnn = False
        if speculative_decode:
            print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
            remove_cudnn = True
        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ]:
            print("[INFO] cuDNN backend does not support FP8. Skipping.")
            remove_cudnn = True
        if remove_cudnn:
            backends.remove("cudnn")
    if "auto" in backends and speculative_decode:
        print("[INFO] auto backend is disabled for speculative decode. Skipping.")
        backends.remove("auto")

    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
    outputs = {}

    # Sample sequence lengths and create tensors
    actual_seq_lens_kv = sample_actual_seq_lens(
        s_kv, batch_size, device, args.random_actual_seq_len
    )
    sum_seq_kv = torch.sum(actual_seq_lens_kv).item()
    avg_seq_len_kv = sum_seq_kv // batch_size

    if args.verbose >= 1:
        print(f"[VERBOSE] Average actual seq len: {avg_seq_len_kv}")
    if args.verbose >= 2:
        print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }")

    # Create query tensor
    q = torch.rand(
        batch_size * s_qo,
        num_qo_heads,
        head_dim_qk,
        device=device,
        dtype=q_init_dtype,
    )
    if args.verbose >= 2:
        print(f"[VVERBOSE] {q.shape = }")

    # Create KV cache
    num_pages_per_seq = (s_kv + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    if args.verbose >= 2:
        print(f"[VVERBOSE] {num_pages_per_seq = }")
        print(f"[VVERBOSE] {total_num_pages = }")

    # Initialize KV cache with appropriate shape and stride
    kv_cache_shape = (
        total_num_pages,
        2,  # 2 for key and value
        num_kv_heads,
        page_size,
        head_dim_qk,
    )
    kv_cache = torch.randn(size=kv_cache_shape, dtype=kv_init_dtype).to(device)

    # Keep a copy for TRT-LLM which uses different strides.
    # NOTE(review): kv_cache_for_trt is maintained below but never passed to
    # run_backend_wrapper (trtllm-gen runs on the strided kv_cache) — confirm
    # whether this copy is still needed.
    if "trtllm-gen" in backends:
        kv_cache_for_trt = kv_cache.detach().clone()

    # Reinterpret the same storage with page/kv/token-major strides while
    # keeping the logical (page, 2, head, token, dim) shape.
    kv_cache = kv_cache.as_strided(
        kv_cache.shape,
        (
            2 * page_size * num_kv_heads * head_dim_qk,
            page_size * num_kv_heads * head_dim_qk,
            head_dim_qk,
            num_kv_heads * head_dim_qk,
            1,
        ),
    )
    k_cache_view, v_cache_view = kv_cache[:, 0, :, :, :], kv_cache[:, 1, :, :, :]

    if "trtllm-gen" in backends:
        # kv_cache now has different tensor stride and logical values. Copy over values to kv_cache_for_trt.
        # Result is kv_cache and kv_cache_for_trt have the same logical values but different tensor strides.
        kv_cache_for_trt.copy_(kv_cache)

    v_cache = v_cache_view.as_strided(
        v_cache_view.shape,
        (
            2 * page_size * num_kv_heads * head_dim_qk,
            head_dim_qk,
            num_kv_heads * head_dim_qk,
            1,
        ),
    )
    k_cache = k_cache_view.as_strided(
        k_cache_view.shape,
        (
            2 * page_size * num_kv_heads * head_dim_qk,
            head_dim_qk,
            num_kv_heads * head_dim_qk,
            1,
        ),
    )

    # Now initialize the page tables: each sequence gets its own contiguous
    # range of page ids, shuffled within the range.
    block_tables = torch.tensor(
        [
            [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)]
            for i in range(batch_size)
        ],
        dtype=torch.int,
        device=device,
    )
    kv_indptr = (
        torch.cat(
            [
                torch.tensor([0], device=device),
                torch.cumsum(
                    (actual_seq_lens_kv.flatten() + page_size - 1) // page_size, dim=0
                ),
            ]
        )
        .int()
        .to(device)
    )

    # kv_indptr[-1] is the total number of actual pages
    kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
    for i in range(len(kv_indptr) - 1):
        start_idx = kv_indptr[i]
        end_idx = kv_indptr[i + 1]
        kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx]

    # A full last page reports page_size rather than 0.
    kv_last_page_len = (
        torch.where(
            actual_seq_lens_kv.flatten() % page_size == 0,
            torch.full((batch_size,), page_size, device=device),
            actual_seq_lens_kv.flatten() % page_size,
        )
        .int()
        .to(device)
    )

    ragged_q = (
        torch.arange(0, batch_size + 1, device=device)
        * (s_qo * num_qo_heads * head_dim_qk)
    ).long()  # For cuDNN
    speculative_mask = (
        generate_speculative_causal_mask(batch_size, s_qo, device)
        if speculative_decode
        else None
    )

    scale = float(1.0 / (head_dim_qk**0.5))
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {kv_cache.shape = }")
        print(f"[VVERBOSE] {kv_cache.stride() = }")
        print(f"[VVERBOSE] {block_tables.shape = }")
        print(f"[VVERBOSE] {kv_indptr.shape = }")
        print(f"[VVERBOSE] {kv_indices.shape = }")
        print(f"[VVERBOSE] {kv_last_page_len.shape = }")
        print(f"[VVERBOSE] {scale = }")

    # Prepare wrappers
    backend_wrappers = {}
    resolved_backends = {}
    for backend in backends:
        if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
            plan_kv_indptr = (
                kv_indptr.clone().detach() if backend == "trtllm-gen" else kv_indptr
            )
            # Map fa2_tc to fa2 for the actual backend parameter
            # fa2_tc is a benchmark-specific name meaning "fa2 with tensor cores"
            actual_backend = "fa2" if backend == "fa2_tc" else backend
            backend_wrappers[backend] = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer,
                "HND",
                use_cuda_graph=is_cuda_graph_compatible,
                use_tensor_cores=(backend != "fa2"),
                paged_kv_indptr_buffer=plan_kv_indptr,
                paged_kv_indices_buffer=kv_indices,
                paged_kv_last_page_len_buffer=kv_last_page_len,
                backend=actual_backend,
            )
            backend_wrappers[backend].plan(
                plan_kv_indptr,
                kv_indices,
                kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim_qk,
                page_size,
                q_data_type=q_dtype,
                data_type=kv_dtype,
                block_tables=block_tables,
            )
            resolved_backends[backend] = backend_wrappers[backend]._backend
        else:
            resolved_backends[backend] = backend

    ## Prepare dtype-specific data
    k_scale, v_scale = None, None
    kv_cache_sf = None
    kv_cache_nvfp4 = None
    if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        q = q.to(q_dtype)
    if is_nvfp4_kv:
        # The FP8-query requirement was validated before allocation above.
        kv_cache_nvfp4, kv_cache_sf, k_scale, v_scale = nvfp4_quantize_paged_kv_cache(
            kv_cache[:, 0], kv_cache[:, 1]
        )
    elif kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        # Per-tensor scales map each of K and V so its amax becomes 256.
        k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
        k_scale = k_data.amax().item() / 256
        v_scale = v_data.amax().item() / 256
        k_fp8 = (k_data / k_scale).to(kv_dtype)
        v_fp8 = (v_data / v_scale).to(kv_dtype)
        kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
        if "trtllm-gen" in backends:
            k_data, v_data = torch.chunk(kv_cache_for_trt, 2, dim=1)
            k_fp8 = (k_data / k_scale).to(kv_dtype)
            v_fp8 = (v_data / v_scale).to(kv_dtype)
            kv_cache_for_trt = torch.cat([k_fp8, v_fp8], dim=1)

    def run_backend_wrapper(
        backend,
        q,
        kv_cache,
        k_cache,
        v_cache,
        workspace_buffer,
        block_tables,
        actual_seq_lens_kv,
        ragged_q,
        speculative_mask,
    ):
        """Launch one decode call for ``backend`` and return its output tensor."""
        if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
            wrapper_kv = kv_cache_nvfp4 if is_nvfp4_kv else kv_cache
            return backend_wrappers[backend].run(
                q,
                wrapper_kv,
                k_scale=k_scale,
                v_scale=v_scale,
                q_len_per_req=s_qo,
                kv_cache_sf=kv_cache_sf,
            )
        elif backend == "cudnn":
            return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
                q,
                k_cache,
                v_cache,
                scale,
                workspace_buffer,
                max_sequence_kv=s_kv,
                actual_seq_lens_kv=actual_seq_lens_kv,
                block_tables=block_tables,
                is_cuda_graph_compatible=is_cuda_graph_compatible,
                batch_offsets_q=ragged_q,
                batch_offsets_o=ragged_q,
            )
        elif backend == "trtllm-native":
            native_kv = kv_cache_nvfp4 if is_nvfp4_kv else kv_cache
            return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
                query=q.contiguous(),
                kv_cache=native_kv,
                workspace_buffer=workspace_buffer,
                block_tables=block_tables,
                seq_lens=actual_seq_lens_kv,
                max_seq_len=s_kv,
                bmm1_scale=scale if k_scale is None else k_scale * scale,
                bmm2_scale=1.0 if v_scale is None else v_scale,
                kv_layout="HND",
                backend="auto",
                q_len_per_req=s_qo,
                mask=speculative_mask,
                kv_cache_sf=kv_cache_sf,
            )
        else:
            print(f"[ERROR] Backend {backend} not supported")
            return None

    has_reference_output = False
    # Iterate over each backend:
    for cur_backend in backends:
        # Clear workspace buffer to prevent unexpected interactions between backends.
        workspace_buffer.zero_()
        if run_refcheck:
            outputs[cur_backend] = (
                run_backend_wrapper(
                    cur_backend,
                    q,
                    kv_cache,
                    k_cache,
                    v_cache,
                    workspace_buffer,
                    block_tables,
                    actual_seq_lens_kv,
                    ragged_q,
                    speculative_mask,
                )
                .detach()
                .clone()
            )
            if cur_backend == "fa2":
                has_reference_output = True
                reference_output = outputs[cur_backend]
        # Unified benchmark entry: prefer graph if compatible and not using CUPTI
        backend_times[cur_backend] = bench_gpu_time(
            fn=run_backend_wrapper,
            dry_run_iters=args.dry_run_iters,
            repeat_iters=args.num_iters,
            sleep_after_run=False,
            enable_cupti=args.use_cupti,
            use_cuda_graph=(is_cuda_graph_compatible and cur_backend != "fa2"),
            cold_l2_cache=True,
            input_args=(
                cur_backend,
                q,
                kv_cache,
                k_cache,
                v_cache,
                workspace_buffer,
                block_tables,
                actual_seq_lens_kv,
                ragged_q,
                speculative_mask,
            ),
        )

    # Perform reference check against the fa2 output.
    tested_backends = list(outputs.keys())
    tested_outputs = list(outputs.values())
    if len(tested_backends) > 1:
        if run_refcheck and has_reference_output:
            if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                if args.verbose >= 2:
                    print(
                        "[VVERBOSE] Reference output is FP8. Converting to float32 for reference check."
                    )
                reference_output = reference_output.to(torch.float32)
                tested_outputs = [output.to(torch.float32) for output in tested_outputs]
            for tested_backend, tested_output in zip(tested_backends, tested_outputs):
                (
                    num_different_elements,
                    num_elements,
                    num_different_elements_percentage,
                ) = is_close_stats(reference_output, tested_output, rtol, atol)
                if num_different_elements > 0:
                    print(
                        f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backend}: "
                        f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                    )
                    if not args.allow_output_mismatch:
                        raise AssertionError(
                            f"[ERROR] Backend {tested_backend} output mismatch"
                        )
        elif run_refcheck:
            # Previously skipped silently; make the skip visible.
            print(
                "[INFO] Reference check skipped: it compares against fa2, which did not run for this configuration."
            )

    # Compute perf metrics. The sequence-length inputs to the FLOP/byte models
    # are the same for every backend, so build them once.
    actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
    actual_seq_lens_q_flat = torch.full_like(actual_seq_lens_kv_flat, s_qo)
    for backend in backends:
        if len(backend_times[backend]) > 0:
            median_time = np.median(backend_times[backend])
            std_time = np.std(backend_times[backend])
            tflops = attention_tflops_per_sec_with_actual_seq_lens(
                actual_seq_lens_q_flat,
                actual_seq_lens_kv_flat,
                head_dim_qk,
                head_dim_vo,
                num_qo_heads,
                False,
                median_time,
            )
            tb_per_sec = attention_tb_per_sec_with_actual_seq_lens(
                actual_seq_lens_q_flat,
                actual_seq_lens_kv_flat,
                head_dim_qk,
                head_dim_vo,
                num_qo_heads,
                num_kv_heads,
                median_time,
                q_dtype=q_dtype,
                kv_dtype=kv_dtype,
                o_dtype=q_dtype,
            )
            resolved_backend = resolved_backends.get(backend, backend)
            wrapper = backend_wrappers.get(backend)
            # Report fa2 with tensor cores as "fa2_tc" so results are distinguishable.
            if (
                wrapper is not None
                and resolved_backend == "fa2"
                and wrapper.use_tensor_cores
            ):
                resolved_backend = "fa2_tc"
            display_backend = (
                f"auto({resolved_backend})" if backend == "auto" else resolved_backend
            )
            print_perf_metrics(
                display_backend, median_time, std_time, tflops, tb_per_sec
            )

            if args.output_path is not None:
                cur_res = defaultdict(str)
                cur_res["routine"] = args.routine
                cur_res["median_time"] = median_time
                cur_res["std_time"] = std_time
                cur_res["tflops"] = tflops
                cur_res["tb_per_sec"] = tb_per_sec
                cur_res["backend"] = backend
                cur_res["resolved_backend"] = resolved_backend
                cur_res["page_size"] = page_size
                cur_res["batch_size"] = batch_size
                cur_res["s_qo"] = s_qo
                cur_res["s_kv"] = s_kv
                cur_res["num_qo_heads"] = num_qo_heads
                cur_res["num_kv_heads"] = num_kv_heads
                cur_res["head_dim_qk"] = head_dim_qk
                cur_res["head_dim_vo"] = head_dim_vo
                cur_res["causal"] = False
                cur_res["q_dtype"] = q_dtype
                cur_res["kv_dtype"] = kv_dtype
                cur_res["avg_actual_seq_len"] = avg_seq_len_kv
                cur_res["random_actual_seq_len"] = args.random_actual_seq_len
                cur_res["case_tag"] = args.case_tag
                res.append(cur_res)
    return res
def testBatchPrefillWithPagedKVCacheWrapper(args):
"""
Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa3, auto, trtllm-gen, trtllm-native, and cudnn backends.
This test:
1. Creates paged KV cache and query tensors for prefill
2. Runs prefill attention with different backends
3. Verifies outputs match between backends (if refcheck enabled)
4. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
dict: Dictionary containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBatchPrefillWithPagedKVCacheWrapper")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
# Basic setup
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
q_init_dtype = torch.bfloat16
kv_init_dtype = torch.bfloat16
rtol = 2e-1
atol = 1e-2
res = []
q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
print(f"[ERROR] Unsupported q_dtype: {args.q_dtype}")
return res
is_nvfp4_kv = args.kv_dtype == "nvfp4"
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.uint8]:
print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
return res
if args.out_dtype is not None:
print(
"[WARNING] --out_dtype is not yet supported for BatchPrefillWithPagedKVCacheWrapper; ignoring."
)
# Increase tolerances for FP8 due to lower precision
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
rtol = 5e-1 # Relaxed relative tolerance for FP8
atol = 1e-1 # Relaxed absolute tolerance for FP8
# Parse and validate backend configurations
backends = args.backends
page_size = args.page_size
batch_size = args.batch_size
s_qo = args.s_qo
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim_qk = args.head_dim_qk
head_dim_vo = args.head_dim_vo
causal = args.causal
is_cuda_graph_compatible = not args.no_cuda_graph
# return_lse = not args.no_lse # TO-DO: Add support for this
run_refcheck = args.refcheck
backends = filter_backends_by_compute_capability(backends, args.routine, device)
# Check for backend-specific constraints
if "fa2" in backends:
remove_fa2 = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
print("[INFO] FA2 backend does not support FP8. Skipping.")
remove_fa2 = True
if remove_fa2:
backends.remove("fa2")
if "cudnn" in backends:
remove_cudnn = False
# cuDNN FP8 prefill requires cuDNN >= 9.17.1 (backend version 91701)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91701:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.17.1. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn backend."
)
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if "cudnn-native" in backends:
remove_cudnn_native = False
# cuDNN-native does not yet support FP8 prefill
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91701:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.17.1. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn-native backend."
)
remove_cudnn_native = True
if remove_cudnn_native:
backends.remove("cudnn-native")
if "trtllm-gen" in backends:
remove_trtllm = False
if not causal:
print("[INFO] trtllm-gen backend currently requires causal = True")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm-gen")
if "trtllm-native" in backends:
remove_trtllm_native = False
if not causal:
print("[INFO] trtllm-native backend currently requires causal = True")
remove_trtllm_native = True
if remove_trtllm_native:
backends.remove("trtllm-native")
if "cutlass" in backends:
print("[INFO] CUTLASS backend does not support prefill. Skipping.")
remove_cutlass = True
if remove_cutlass:
backends.remove("cutlass")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res
# Check for layer-specific constraints
layer_not_supported = False
if s_qo > s_kv:
print("[ERROR] s_qo > s_kv is not supported. Exiting.")
layer_not_supported = True
if layer_not_supported:
print("[ERROR] Layer not supported. Exiting.")
return res
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
# Sample sequence lengths.
# If s_qo == s_kv, then make sampled actual_seq_lens_kv the same as actual_seq_lens_q.
# IF s_qo < s_kv, then sample actual_seq_lens_kv separately. Then ensure actual_seq_lens_kv is at least as long as actual_seq_lens_q.
actual_seq_lens_q = sample_actual_seq_lens(
s_qo, batch_size, None, args.random_actual_seq_len
)
if s_qo == s_kv:
if args.verbose >= 2:
print(
"[VVERBOSE] s_qo == s_kv, making actual_seq_lens_kv the same as actual_seq_lens_q"
)
actual_seq_lens_kv = actual_seq_lens_q.clone()
else: # s_qo < s_kv
if args.verbose >= 2:
print("[VVERBOSE] s_qo < s_kv, sampling actual_seq_lens_kv")
actual_seq_lens_kv = sample_actual_seq_lens(
s_kv, batch_size, None, args.random_actual_seq_len
)
actual_seq_lens_kv = torch.maximum(actual_seq_lens_kv, actual_seq_lens_q)
avg_seq_len_q = actual_seq_lens_q.sum().item() // batch_size
avg_seq_len_kv = actual_seq_lens_kv.sum().item() // batch_size
if args.verbose >= 1:
print(f"[VERBOSE] Average actual qo seq len: {avg_seq_len_q}")
print(f"[VERBOSE] Average actual kv seq len: {avg_seq_len_kv}")
if args.verbose >= 2:
print(f"[VVERBOSE] {actual_seq_lens_q.flatten() = }")
print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }")
cumsum_s_qo = torch.sum(actual_seq_lens_q)
q = torch.randn(
cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
# Create KV cache
num_pages_per_seq = (s_kv + page_size - 1) // page_size