Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cutlass"],
"10.3": ["cutlass"],
"10.0": ["cutlass", "cute-dsl"],
"10.3": ["cutlass", "cute-dsl"],
"11.0": ["cutlass"],
"12.0": [],
},
Expand Down
18 changes: 12 additions & 6 deletions benchmarks/routines/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,11 @@ def testMmFp4(args):
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
alpha=(
torch.tensor([1.0], dtype=torch.float32, device=device)
if (not use_nvfp4 and backend == "cute-dsl")
else alpha
),
Comment thread
b8zhong marked this conversation as resolved.
Outdated
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
out_dtype=res_dtype,
block_size=16
if use_nvfp4
Expand Down Expand Up @@ -1129,7 +1133,11 @@ def run_backend(
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
alpha=(
torch.tensor([1.0], dtype=torch.float32, device=device)
if (not use_nvfp4 and backend == "cute-dsl")
else alpha
),
out_dtype=res_dtype,
block_size=block_size,
use_8x4_sf_layout=not use_128x4_sf_layout,
Expand Down Expand Up @@ -1289,9 +1297,7 @@ def testMmMxfp8(args):
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = [
"cutlass",
]
autotune_supported_backends = ["cutlass", "cute-dsl"]
Comment thread
b8zhong marked this conversation as resolved.
Outdated
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
Expand Down Expand Up @@ -1344,7 +1350,7 @@ def testMmMxfp8(args):
print(f"[VVERBOSE] {mat2_scale.dtype = }")

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend == "cutlass":
if backend in ["cutlass", "cute-dsl", "auto"]:
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
Expand Down
Loading
Loading