Skip to content

Commit b5d89aa

Browse files
committed
Add check for min m/n/k in MXFP8
1 parent 2ef199e commit b5d89aa

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

flashinfer/gemm/gemm_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,6 +2499,16 @@ def _check_mm_mxfp8_problem_size(
24992499
f"K dimension mismatch in mm_mxfp8. got {a.shape[1]=}, {b.shape[0]=}"
25002500
)
25012501

2502+
# The output may contain NaN/Inf if the dimensions are too small
2503+
min_m = 32
2504+
min_n = 128
2505+
min_k = 128
2506+
if a.shape[0] < min_m or b.shape[1] < min_n or a.shape[1] < min_k:
2507+
raise ValueError(
2508+
f"MXFP8 requires m >= {min_m}, n >= {min_n}, k >= {min_k} for CUTLASS MXFP8. "
2509+
f"got m={a.shape[0]}, n={b.shape[1]}, k={a.shape[1]}."
2510+
)
2511+
25022512
# Input dtype as returned by mxfp8_quantize_sm100
25032513
if a.dtype != torch.float8_e4m3fn:
25042514
raise ValueError(f"a must be a float8_e4m3fn tensor, got {a.dtype=}")

0 commit comments

Comments
 (0)