-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy pathgemm.py
More file actions
1942 lines (1742 loc) · 68.7 KB
/
gemm.py
File metadata and controls
1942 lines (1742 loc) · 68.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum
import flashinfer
from flashinfer.autotuner import autotune
from flashinfer.fp8_quantization import mxfp8_quantize
from flashinfer.testing.utils import (
bench_gpu_time,
dequantize_fp8,
quantize_fp8,
)
from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
print_perf_metrics,
is_close_stats,
filter_backends_by_compute_capability,
)
def run_gemm_test(args):
    """
    Dispatch to the GEMM benchmark routine named by ``args.routine``.

    Args:
        args: Parsed command line arguments; ``args.routine`` selects the test.

    Returns:
        Whatever the selected routine returns (its list of result dictionaries).

    Raises:
        ValueError: If ``args.routine`` is not one of the known GEMM routines.
    """
    routine_table = {
        "gemm_fp8_nt_groupwise": testGemmFp8NtGroupwise,
        "group_gemm_fp8_nt_groupwise": testGroupGemmFp8NtGroupwise,
        "bmm_fp8": testBmmFp8,
        "bmm_mxfp8": testBmmMxfp8,
        "mm_fp4": testMmFp4,
        "mm_mxfp8": testMmMxfp8,
        "mm_bf16": testMmBf16,
        "bmm_bf16": testBmmBf16,
    }
    selected_test = routine_table.get(args.routine)
    if selected_test is None:
        raise ValueError(f"Unsupported routine: {args.routine}")
    return selected_test(args)
def parse_gemm_args(line, parser):
    """
    Register the GEMM-specific options on ``parser`` and parse ``line``.

    Args:
        line: Sequence of command line tokens to parse.
        parser: ArgumentParser already populated with the shared options; it must
            define ``--verbose`` because that value is read after parsing.

    Returns:
        The parsed argument namespace (shared plus GEMM-specific options).
    """
    # (flag, add_argument keyword arguments). Registration order is kept so that
    # parsing and --help output are unchanged.
    gemm_options = (
        (
            "--batch_size",
            dict(
                type=int,
                required=False,
                default=1,
                help="Batch size of test case.",
            ),
        ),
        (
            "--m",
            dict(type=int, required=True, help="Number of rows in the first matrix."),
        ),
        (
            "--n",
            dict(
                type=int, required=True, help="Number of columns in the second matrix."
            ),
        ),
        (
            "--k",
            dict(
                type=int,
                required=True,
                help="Number of columns in the first matrix and number of rows in the second matrix.",
            ),
        ),
        (
            "--tile_size",
            dict(
                type=int,
                required=False,
                default=128,
                help="Tile size for the gemm operation.",
            ),
        ),
        (
            "--group_size",
            dict(
                type=int,
                required=False,
                default=1,
                help="Group size for the group gemm operation.",
            ),
        ),
        (
            "--scale_major_mode",
            dict(
                type=str,
                required=False,
                default="MN",
                choices=["MN", "K"],
                help="Scale major mode.",
            ),
        ),
        (
            "--input_dtype",
            dict(
                type=str,
                required=False,
                default="fp8_e4m3",
                help="Data type of the input.",
            ),
        ),
        (
            "--mat2_dtype",
            dict(
                type=str,
                required=False,
                default="fp8_e4m3",
                help="Data type of the mat2.",
            ),
        ),
        (
            "--out_dtype",
            dict(
                type=str,
                required=False,
                default="bfloat16",
                help="Data type of the output.",
            ),
        ),
        (
            "--mma_sm",
            dict(
                type=int,
                required=False,
                default=1,
                choices=[1, 2],
                help="How many SMs to use for the MMA operation, must be 1 or 2",
            ),
        ),
        (
            "--backends",
            dict(
                type=str,
                required=False,
                nargs="+",
                default=["cudnn"],
                choices=[
                    "cudnn",
                    "cublas",
                    "trtllm",
                    "cutlass",
                    "tgv",
                    "cublaslt",
                    "cute-dsl",
                    "auto",
                ],
                help="Kernel backends to test. Default: cudnn",
            ),
        ),
        (
            "--use_128x4_sf_layout",
            dict(
                action="store_true",
                help="Use 128x4 SF layout for the input and mat2.",
            ),
        ),
        (
            "--use_nvfp4",
            dict(
                action="store_true",
                help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.",
            ),
        ),
        (
            "--autotune",
            dict(
                action="store_true",
                default=False,
                help=(
                    "Enable autotuner warmup for supported routines (mm_fp4, bmm_fp8, bmm_mxfp8, mm_mxfp8, mm_bf16, bmm_bf16)."
                ),
            ),
        ),
        (
            "--bias",
            dict(
                action="store_true",
                default=False,
                help="Use bias (enabled for mm_bf16 with TGV backend for now)",
            ),
        ),
        (
            "--enable_pdl",
            dict(
                action="store_true",
                default=False,
                help="Enable programmatic dependent launch.",
            ),
        ),
    )
    for flag, options in gemm_options:
        parser.add_argument(flag, **options)

    # The variable must stay named ``args``: the debug f-string prints its name.
    args = parser.parse_args(line)
    if args.verbose >= 1:
        print(f"[INFO] {args = }")
    return args
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
def testGemmFp8NtGroupwise(args):
    """
    Test gemm_fp8_nt_groupwise API.

    This test:
    1. Generates random input tensors
    2. Quantizes input tensors to FP8 with per-tile scales
    3. Runs gemm_fp8_nt_groupwise
    4. Runs reference check against an einsum of the dequantized inputs
    5. Measures performance metrics (TFLOPS, TB/sec)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list[dict]: One dictionary of performance results per benchmarked backend
        (populated only when ``args.output_path`` is set). An empty list is
        returned when there is no backend to test.

    Raises:
        ValueError: If the output dtype is not bfloat16/float16.
        AssertionError: On a reference mismatch unless ``args.allow_output_mismatch``.
    """
    if args.verbose >= 1:
        print("[INFO] Running testGemmFp8NtGroupwise")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    ## Parse input arguments
    backends = args.backends
    m = args.m
    n = args.n
    k = args.k
    tile_size = args.tile_size
    scale_major_mode = args.scale_major_mode
    mma_sm = args.mma_sm
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
    if out_dtype not in [torch.bfloat16, torch.float16]:
        raise ValueError(f"Unsupported output dtype: {args.out_dtype}")
    ## Done parsing input arguments

    if "trtllm" in backends:
        remove_trtllm = False
        if scale_major_mode != "MN":
            print(
                "[INFO] trtllm only supports MN scale_major_mode, removing trtllm from backends"
            )
            remove_trtllm = True
        if k < 256:
            print("[INFO] trtllm only supports k >= 256, removing trtllm from backends")
            remove_trtllm = True
        if remove_trtllm:
            # Build a new list instead of mutating the caller-visible one.
            backends = [b for b in backends if b != "trtllm"]
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        # Return the empty result list, consistent with the earlier exit
        # (a bare ``return`` would hand None to callers).
        return res

    ## Prepare input tensors
    a_val = torch.randn((m, k), dtype=torch.float, device=device)
    # Scale b by 1/sqrt(k) to keep the magnitude of the products moderate.
    b_val = torch.randn((n, k), dtype=torch.float, device=device) / np.sqrt(k)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {a_val.shape = }")
        print(f"[VVERBOSE] {b_val.shape = }")

    if scale_major_mode == "K":
        a_scale_shape = (m, k // tile_size)
        b_scale_shape = (n // tile_size, k // tile_size)
    else:
        a_scale_shape = (k // tile_size, m)
        b_scale_shape = (k // tile_size, n // tile_size)
    # a is scaled per (1 x tile_size) row segment, b per (tile_size x tile_size) tile.
    a_tile_shape = (1, tile_size)
    b_tile_shape = (tile_size, tile_size)

    a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode)
    b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {a_fp8.shape = }")
        print(f"[VVERBOSE] {b_fp8.shape = }")
        print(f"[VVERBOSE] {a_scale.shape = }")
        print(f"[VVERBOSE] {b_scale.shape = }")

    def run_backend(backend, a_fp8, b_fp8, a_scale, b_scale):
        if backend in ["cutlass", "trtllm"]:
            return flashinfer.gemm.gemm_fp8_nt_groupwise(
                a=a_fp8,
                b=b_fp8,
                a_scale=a_scale,
                b_scale=b_scale,
                scale_major_mode=scale_major_mode,
                out_dtype=out_dtype,
                mma_sm=mma_sm,
                backend=backend,
            )
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    reference_output = None
    if run_refcheck:
        # Dequantized copies are only needed to build the reference output.
        a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode)
        b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
        reference_output = einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype)

    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
    outputs = {}
    for cur_backend in backends:
        if run_refcheck:
            outputs[cur_backend] = run_backend(
                cur_backend, a_fp8, b_fp8, a_scale, b_scale
            ).detach()
        backend_times[cur_backend] = bench_gpu_time(
            fn=run_backend,
            dry_run_iters=args.dry_run_iters,
            repeat_iters=args.num_iters,
            sleep_after_run=True,  # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
            enable_cupti=args.use_cupti,
            use_cuda_graph=is_cuda_graph_compatible,
            cold_l2_cache=True,
            input_args=(cur_backend, a_fp8, b_fp8, a_scale, b_scale),
        )

    if run_refcheck:
        for tested_backend, tested_output in outputs.items():
            num_different_elements, _num_elements, _different_pct = is_close_stats(
                reference_output, tested_output, rtol=1e-2, atol=1e-2
            )
            if num_different_elements > 0:
                print(f"[ERROR] Output tensor mismatch from backend {tested_backend}")
                if not args.allow_output_mismatch:
                    raise AssertionError(
                        f"[ERROR] Backend {tested_backend} output mismatch with {num_different_elements} elements"
                    )

    for backend in backends:
        if len(backend_times[backend]) > 0:
            median_time = np.median(backend_times[backend])
            std_time = np.std(backend_times[backend])
            problem_flops = 2 * m * n * k
            problem_bytes = (m * k + n * k) * torch.float8_e4m3fn.itemsize + (
                m * n
            ) * out_dtype.itemsize
            tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
            tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec

            print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)

            if args.output_path is not None:
                cur_res = defaultdict(str)
                cur_res["routine"] = args.routine
                cur_res["median_time"] = median_time
                cur_res["std_time"] = std_time
                cur_res["tflops"] = tflops
                cur_res["tb_per_sec"] = tb_per_sec
                cur_res["m"] = m
                cur_res["n"] = n
                cur_res["k"] = k
                cur_res["tile_size"] = tile_size
                cur_res["scale_major_mode"] = scale_major_mode
                cur_res["out_dtype"] = out_dtype
                cur_res["mma_sm"] = mma_sm
                cur_res["backend"] = backend
                cur_res["case_tag"] = args.case_tag
                res.append(cur_res)
    return res
def testGroupGemmFp8NtGroupwise(args):
    """
    Test group_gemm_fp8_nt_groupwise API.

    This test:
    1. Generates random input tensors (``group_size`` independent problems)
    2. Quantizes input tensors to FP8 with per-tile scales
    3. Runs group_gemm_fp8_nt_groupwise
    4. Runs reference check against an einsum of the dequantized inputs
    5. Measures performance metrics (TFLOPS, TB/sec)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list[dict]: One dictionary of performance results per benchmarked backend
        (populated only when ``args.output_path`` is set). An empty list is
        returned when there is no backend to test.

    Raises:
        ValueError: If the output dtype is not bfloat16/float16.
        AssertionError: On a reference mismatch unless ``args.allow_output_mismatch``.
    """
    if args.verbose >= 1:
        print("[INFO] Running testGroupGemmFp8NtGroupwise")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    ## Parse input arguments
    backends = ["cutlass"]  # Cutlass is currently the only supported backend
    m = args.m
    n = args.n
    k = args.k
    group_size = args.group_size
    tile_size = args.tile_size
    scale_major_mode = args.scale_major_mode
    mma_sm = args.mma_sm
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
    if out_dtype not in [torch.bfloat16, torch.float16]:
        raise ValueError(f"Unsupported output dtype: {args.out_dtype}")
    ## Done parsing input arguments

    ## Prepare input tensors
    # Allocate on the device selected by get_device(args), as the other
    # routines do, rather than a hardcoded "cuda".
    a_val = torch.randn((group_size * m, k), dtype=torch.float, device=device)
    # Scale b by 1/sqrt(k) to keep the magnitude of the products moderate.
    b_val = torch.randn((group_size, n, k), dtype=torch.float, device=device) / np.sqrt(
        k
    )

    if args.verbose >= 2:
        print(f"[VVERBOSE] {a_val.shape = }")
        print(f"[VVERBOSE] {b_val.shape = }")

    if scale_major_mode == "K":
        a_scale_shape = (group_size * m, k // tile_size)
        b_scale_shape = (group_size, n // tile_size, k // tile_size)
    else:
        a_scale_shape = (k // tile_size, m * group_size)
        b_scale_shape = (group_size, k // tile_size, n // tile_size)
    a_tile_shape = (1, tile_size)
    b_tile_shape = (1, tile_size, tile_size)

    a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode)
    b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode)
    # Row offsets of each group inside the stacked ``a``: [0, m, 2m, ..., group_size*m].
    m_indptr = torch.arange(0, group_size + 1, dtype=torch.int32, device=device) * m

    if args.verbose >= 2:
        print(f"[VVERBOSE] {a_fp8.shape = }")
        print(f"[VVERBOSE] {b_fp8.shape = }")
        print(f"[VVERBOSE] {a_scale.shape = }")
        print(f"[VVERBOSE] {b_scale.shape = }")
        print(f"[VVERBOSE] {m_indptr.shape = }")

    def run_backend(backend, a_fp8, b_fp8, a_scale, b_scale, m_indptr):
        if backend == "cutlass":
            return flashinfer.gemm.group_gemm_fp8_nt_groupwise(
                a=a_fp8,
                b=b_fp8,
                a_scale=a_scale,
                b_scale=b_scale,
                m_indptr=m_indptr,
                scale_major_mode=scale_major_mode,
                out_dtype=out_dtype,
                mma_sm=mma_sm,
            )
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    reference_output = None
    if run_refcheck:
        # Dequantized copies are only needed to build the reference output.
        a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode)
        b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
        reference_output = (
            einsum(
                a_dequant.view((group_size, m, k)), b_dequant, "b m k, b n k -> b m n"
            )
            .view((group_size * m, n))
            .to(out_dtype)
        )

    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
    outputs = {}
    for cur_backend in backends:
        if run_refcheck:
            outputs[cur_backend] = run_backend(
                cur_backend, a_fp8, b_fp8, a_scale, b_scale, m_indptr
            ).detach()
        backend_times[cur_backend] = bench_gpu_time(
            fn=run_backend,
            dry_run_iters=args.dry_run_iters,
            repeat_iters=args.num_iters,
            sleep_after_run=True,  # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
            enable_cupti=args.use_cupti,
            use_cuda_graph=is_cuda_graph_compatible,
            cold_l2_cache=True,
            input_args=(cur_backend, a_fp8, b_fp8, a_scale, b_scale, m_indptr),
        )

    if run_refcheck:
        for tested_backend, tested_output in outputs.items():
            num_different_elements, _num_elements, _different_pct = is_close_stats(
                reference_output, tested_output, rtol=1e-2, atol=1e-2
            )
            if num_different_elements > 0:
                print(f"[ERROR] Output tensor mismatch from backend {tested_backend}")
                if not args.allow_output_mismatch:
                    raise AssertionError(
                        f"[ERROR] Backend {tested_backend} output mismatch with {num_different_elements} elements"
                    )

    for backend in backends:
        if len(backend_times[backend]) > 0:
            median_time = np.median(backend_times[backend])
            std_time = np.std(backend_times[backend])
            problem_flops = 2 * m * n * k * group_size
            problem_bytes = (
                group_size * m * k + group_size * n * k
            ) * torch.float8_e4m3fn.itemsize + (group_size * m * n) * out_dtype.itemsize
            tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
            tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec

            print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)

            if args.output_path is not None:
                cur_res = defaultdict(str)
                cur_res["routine"] = args.routine
                cur_res["median_time"] = median_time
                cur_res["std_time"] = std_time
                cur_res["tflops"] = tflops
                cur_res["tb_per_sec"] = tb_per_sec
                cur_res["m"] = m
                cur_res["n"] = n
                cur_res["k"] = k
                cur_res["group_size"] = group_size
                cur_res["tile_size"] = tile_size
                cur_res["scale_major_mode"] = scale_major_mode
                cur_res["out_dtype"] = out_dtype
                cur_res["mma_sm"] = mma_sm
                cur_res["backend"] = backend
                cur_res["case_tag"] = args.case_tag
                res.append(cur_res)
    return res
def testBmmFp8(args):
    """
    Test bmm_fp8 API.

    This test:
    1. Generates random bf16 input tensors
    2. Quantizes input tensors to FP8 with a single per-tensor scale each
    3. Runs bmm_fp8 (optionally after an autotuner warmup)
    4. Runs reference check (cosine similarity against a bf16 torch.bmm)
    5. Measures performance metrics (TFLOPS, TB/sec)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list[dict]: One dictionary of performance results per benchmarked backend
        (populated only when ``args.output_path`` is set). An empty list is
        returned when there is no backend to test.

    Raises:
        ValueError: If the input/mat2/output dtypes are unsupported.
        AssertionError: On a reference mismatch unless ``args.allow_output_mismatch``.
    """
    if args.verbose >= 1:
        print("[INFO] Running testBmmFp8")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    ## Parse input arguments
    backends = args.backends
    batch_size = args.batch_size
    m = args.m
    n = args.n
    k = args.k
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    autotune_enabled = getattr(args, "autotune", False)
    autotune_supported_backends = [
        "cutlass",
    ]
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
    if input_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
        raise ValueError(
            f"Unsupported input dtype: {input_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2."
        )

    mat2_dtype = dtype_str_to_torch_dtype(args.mat2_dtype)
    if mat2_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
        raise ValueError(
            f"Unsupported mat2 dtype: {mat2_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2."
        )

    res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
    if res_dtype not in [torch.bfloat16, torch.float16]:
        raise ValueError(
            f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
        )
    ## Done parsing input arguments

    if autotune_enabled:
        for cur_backend in backends:
            if cur_backend not in autotune_supported_backends:
                print(f"[INFO] {cur_backend} backend does not support autotune")
        # Build a new list instead of mutating the caller-visible one.
        backends = [b for b in backends if b in autotune_supported_backends]
        if len(backends) == 0:
            print("[ERROR] No backends to test. Exiting.")
            # Return the empty result list (a bare ``return`` would hand None
            # to callers), consistent with the earlier exit.
            return res

    ## Prepare input tensors
    input_bf16 = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input_bf16, dtype=input_dtype)
    # mat2 is drawn as (batch, n, k) and viewed as (batch, k, n) via transpose.
    mat2_bf16 = torch.randn(
        [batch_size, n, k], device=device, dtype=torch.bfloat16
    ).transpose(-2, -1)
    mat2_fp8, mat2_inv_s = to_float8(mat2_bf16, dtype=mat2_dtype)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {input_fp8.shape = }")
        print(f"[VVERBOSE] {input_fp8.dtype = }")
        print(f"[VVERBOSE] {mat2_fp8.shape = }")
        print(f"[VVERBOSE] {mat2_fp8.dtype = }")
        print(f"[VVERBOSE] {input_inv_s = }")
        print(f"[VVERBOSE] {input_inv_s.dtype = }")
        print(f"[VVERBOSE] {mat2_inv_s = }")
        print(f"[VVERBOSE] {mat2_inv_s.dtype = }")

    def run_backend(backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s):
        if backend in ["cudnn", "cublas", "cutlass"]:
            return flashinfer.gemm.bmm_fp8(
                A=input_fp8,
                B=mat2_fp8,
                A_scale=input_inv_s,
                B_scale=mat2_inv_s,
                dtype=res_dtype,
                backend=backend,
            )
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    reference_output = None
    if run_refcheck:
        # Reference is the unquantized bf16 product; quantization error is
        # tolerated via the cosine-similarity threshold below.
        reference_output = torch.bmm(input_bf16, mat2_bf16)

    cache_path = getattr(args, "autotune_cache", None)
    if autotune_enabled:
        warmup_iters = (
            args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
        )
        for cur_backend in backends:
            if cur_backend in autotune_supported_backends:
                if args.verbose >= 1:
                    print(f"[INFO] Autotune warmup for bmm_fp8: {warmup_iters} iters")
                with autotune(True, cache=cache_path):
                    for _ in range(warmup_iters):
                        run_backend(
                            cur_backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s
                        )
    elif cache_path:
        with autotune(False, cache=cache_path):
            pass

    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
    outputs = {}
    for cur_backend in backends:
        if run_refcheck:
            outputs[cur_backend] = run_backend(
                cur_backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s
            ).detach()
        backend_times[cur_backend] = bench_gpu_time(
            fn=run_backend,
            dry_run_iters=args.dry_run_iters,
            repeat_iters=args.num_iters,
            sleep_after_run=True,
            enable_cupti=args.use_cupti,
            use_cuda_graph=is_cuda_graph_compatible,
            cold_l2_cache=True,
            input_args=(cur_backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s),
        )

    if run_refcheck:
        # The reference is bf16 while outputs may be bf16 or fp16; compare both
        # in float32 so the similarity is computed in a common dtype.
        reference_flat = reference_output.reshape(-1).float()
        for tested_backend, tested_output in outputs.items():
            cos_sim = F.cosine_similarity(
                reference_flat,
                tested_output.reshape(-1).float(),
                dim=0,
            )
            if cos_sim < 0.99:
                print(f"[ERROR] Output tensor mismatch from backend {tested_backend}")
                if not args.allow_output_mismatch:
                    raise AssertionError(
                        f"[ERROR] Backend {tested_backend} output mismatch with cos_sim={cos_sim}"
                    )

    for backend in backends:
        backend_name = backend + (
            "_autotune"
            if (autotune_enabled and backend in autotune_supported_backends)
            else ""
        )
        if len(backend_times[backend]) > 0:
            median_time = np.median(backend_times[backend])
            std_time = np.std(backend_times[backend])
            problem_flops = 2 * m * n * k * batch_size
            # Every batch reads its own A and B and writes its own output, so the
            # traffic scales with batch_size just like the FLOP count does.
            problem_bytes = (
                m * k * input_dtype.itemsize
                + n * k * mat2_dtype.itemsize
                + m * n * res_dtype.itemsize
            ) * batch_size
            tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
            tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec

            print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)

            if args.output_path is not None:
                cur_res = defaultdict(str)
                cur_res["batch_size"] = batch_size
                cur_res["routine"] = args.routine
                cur_res["median_time"] = median_time
                cur_res["std_time"] = std_time
                cur_res["tflops"] = tflops
                cur_res["tb_per_sec"] = tb_per_sec
                cur_res["m"] = m
                cur_res["n"] = n
                cur_res["k"] = k
                cur_res["input_dtype"] = input_dtype
                cur_res["mat2_dtype"] = mat2_dtype
                cur_res["out_dtype"] = res_dtype
                cur_res["backend"] = backend_name
                cur_res["case_tag"] = args.case_tag
                res.append(cur_res)
    return res
def testBmmMxfp8(args):
"""
Test bmm_mxfp8 API.
This test:
1. Generates random input tensors
2. Quantizes input tensors to MXFP8
3. Runs bmm_mxfp8
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
dict: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBmmMxfp8")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
## Parse input arguments
backends = args.backends
batch_size = args.batch_size
m = args.m
n = args.n
k = args.k
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = [
"cudnn",
]
res = []
backends = filter_backends_by_compute_capability(backends, args.routine, device)
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res
res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments
if getattr(args, "autotune", False):
backends_to_remove = []
for cur_backend in backends:
if cur_backend not in autotune_supported_backends:
print(f"[INFO] {cur_backend} backend does not support autotune")
backends_to_remove.append(cur_backend)
for cur_backend in backends_to_remove:
backends.remove(cur_backend)
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res
## Prepare input tensors
input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(input, is_sf_swizzled_layout=True)
mat2 = (
torch.randn([batch_size, n, k], device=device, dtype=torch.bfloat16)
.transpose(-2, -1)
.contiguous()
)
mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout=True)
if args.verbose >= 2:
print(f"[VVERBOSE] {input_mxfp8.shape = }")
print(f"[VVERBOSE] {input_mxfp8.dtype = }")
print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
print(f"[VVERBOSE] {input_scale.shape = }")
print(f"[VVERBOSE] {input_scale.dtype = }")
print(f"[VVERBOSE] {mat2_scale.shape = }")
print(f"[VVERBOSE] {mat2_scale.dtype = }")
def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend == "cudnn":
return flashinfer.gemm.bmm_mxfp8(
A=input_mxfp8,
B=mat2_mxfp8,
A_scale=input_scale,
B_scale=mat2_scale,
dtype=res_dtype,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
if run_refcheck:
reference_output = torch.bmm(input, mat2)
has_reference_output = True
cache_path = getattr(args, "autotune_cache", None)
if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
for cur_backend in backends:
if cur_backend in autotune_supported_backends:
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for bmm_mxfp8: {warmup_iters} iters")
with autotune(True, cache=cache_path):
for _ in range(warmup_iters):
run_backend(
cur_backend,
input_mxfp8,
mat2_mxfp8,
input_scale,
mat2_scale,
)
elif cache_path:
with autotune(False, cache=cache_path):
pass
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(
cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
).detach()
backend_times[cur_backend] = bench_gpu_time(
fn=run_backend,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
sleep_after_run=True,
enable_cupti=args.use_cupti,
use_cuda_graph=is_cuda_graph_compatible,
cold_l2_cache=True,
input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
)
min_cos_sim = 0.9 # TODO: check if can be increased
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
cos_sim = F.cosine_similarity(
reference_output.reshape(-1),
tested_outputs[i].reshape(-1),
dim=0,
)
if cos_sim < min_cos_sim:
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
)
if not args.allow_output_mismatch:
raise AssertionError(
f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}"
)
for backend in backends:
backend_name = backend + (
"_autotune"
if (
getattr(args, "autotune", False)
and backend in autotune_supported_backends
)
else ""
)
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k * batch_size
# MXFP8 uses fp8_e4m3fn for data (1 byte) and uint8 for scales
# Scale tensors are much smaller, so approximate as 1 byte per element for simplicity
problem_bytes = (
m * k * torch.float8_e4m3fn.itemsize
+ n * k * torch.float8_e4m3fn.itemsize
+ m * n * res_dtype.itemsize
) * batch_size
tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec
tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec
print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["batch_size"] = batch_size
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["out_dtype"] = res_dtype
cur_res["backend"] = backend_name
cur_res["case_tag"] = args.case_tag