Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/routines/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def parse_gemm_args(line, parser):
required=False,
nargs="+",
default=["cudnn"],
- choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
+ choices=["cudnn", "cublas", "trtllm", "cutlass", "cute-dsl", "auto"],
help="Kernel backends to test. Default: cudnn",
)
parser.add_argument(
Expand Down Expand Up @@ -1004,7 +1004,7 @@ def testMmFp4(args):
run_refcheck = args.refcheck
use_128x4_sf_layout = args.use_128x4_sf_layout
use_nvfp4 = args.use_nvfp4
- autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
+ autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "cute-dsl", "auto"]
res = []

res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
Expand Down Expand Up @@ -1104,7 +1104,7 @@ def run_backend(
mat2_inv_s,
mat2_inv_s_trtllm,
):
- if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
+ if backend in ["cudnn", "trtllm", "cutlass", "cute-dsl", "auto"]:
return flashinfer.gemm.mm_fp4(
a=input_fp4,
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
Expand Down
Loading
Loading