Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cutlass"],
"10.3": ["cutlass"],
"10.0": ["cutlass", "cute-dsl"],
"10.3": ["cutlass", "cute-dsl"],
"11.0": ["cutlass"],
"12.0": [],
},
Expand Down
17 changes: 11 additions & 6 deletions benchmarks/routines/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,13 @@ def testMmFp4(args):
print(f"[VVERBOSE] {mat2_fp4.dtype = }")

alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
# TODO: for MXFP4, we don't need a global scale, we should change the compile interface to make
# alpha optional.
alpha_for_cute_dsl_mxfp4 = (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why the mxfp4 cutedsl backend has to use a device tensor with value 1.0?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I removed it, I encountered some compilation errors relating to make_fake_compact_tensor(cutlass.Float32, (1,)), as I believe they still share the exact same code path. To avoid touching too much of cute-dsl, which I am not super familiar with, I left it passing 1.0 for now. But I also think it can be removed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment, something like a TODO, so that people can later understand why you added this and know that it should eventually be removed?

torch.tensor([1.0], dtype=torch.float32, device=device)
if not use_nvfp4
else None
)
# Completed preparing inputs. Now programmatically filter backends
block_size = 16 if use_nvfp4 else 32
backends_to_remove = []
Expand All @@ -1091,7 +1098,7 @@ def testMmFp4(args):
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
out_dtype=res_dtype,
block_size=16
if use_nvfp4
Expand Down Expand Up @@ -1129,7 +1136,7 @@ def run_backend(
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
out_dtype=res_dtype,
block_size=block_size,
use_8x4_sf_layout=not use_128x4_sf_layout,
Expand Down Expand Up @@ -1289,9 +1296,7 @@ def testMmMxfp8(args):
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = [
"cutlass",
]
autotune_supported_backends = ["cutlass", "cute-dsl", "auto"]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
Expand Down Expand Up @@ -1344,7 +1349,7 @@ def testMmMxfp8(args):
print(f"[VVERBOSE] {mat2_scale.dtype = }")

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend == "cutlass":
if backend in ["cutlass", "cute-dsl", "auto"]:
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
Expand Down
Loading
Loading