We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2ef199e · commit b5d89aa · Copy full SHA for b5d89aa
1 file changed
flashinfer/gemm/gemm_base.py
@@ -2499,6 +2499,16 @@ def _check_mm_mxfp8_problem_size(
2499
f"K dimension mismatch in mm_mxfp8. got {a.shape[1]=}, {b.shape[0]=}"
2500
)
2501
2502
+ # The output may contain NaN/Inf if the dimensions are too small
2503
+ min_m = 32
2504
+ min_n = 128
2505
+ min_k = 128
2506
+ if a.shape[0] < min_m or b.shape[1] < min_n or a.shape[1] < min_k:
2507
+ raise ValueError(
2508
+ f"MXFP8 requires m >= {min_m}, n >= {min_n}, k >= {min_k} for CUTLASS MXFP8. "
2509
+ f"got m={a.shape[0]}, n={b.shape[1]}, k={a.shape[1]}."
2510
+ )
2511
+
2512
# Input dtype as returned by mxfp8_quantize_sm100
2513
if a.dtype != torch.float8_e4m3fn:
2514
raise ValueError(f"a must be a float8_e4m3fn tensor, got {a.dtype=}")
0 commit comments