On B300 (SM103), the unit test file test_groupwise_scaled_gemm_mxfp4.py non-deterministically fails with NaNs in the cu130 container.
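A minimal way to exercise the flakiness, assuming the repo root inside the cu130 container, is to re-run just the failing parametrization in a loop (the node id is taken from the failure summary at the end of the log below); this is a repro sketch rather than the CI command itself:

# Repro sketch: repeatedly run only the failing parametrization, since the
# NaNs do not appear on every run. The node id is copied from the short test
# summary below; adjust the attempt count as needed.
import subprocess
import sys

NODE_ID = (
    "tests/gemm/test_groupwise_scaled_gemm_mxfp4.py::"
    "test_mxfp8_mxfp4_groupwise_group_gemm"
    "[out_dtype1-fp8_dtype0-2-2880-8192-4096]"
)

for attempt in range(1, 21):
    rc = subprocess.run([sys.executable, "-m", "pytest", "-q", NODE_ID]).returncode
    if rc != 0:
        print(f"failure reproduced on attempt {attempt}")
        break
else:
    print("no failure in 20 attempts")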
==========================================
Running: pytest --continue-on-collection-errors "tests/gemm/test_groupwise_scaled_gemm_mxfp4.py"
==========================================
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 3456 items
tests/gemm/test_groupwise_scaled_gemm_mxfp4.py ......................... [ 0%]
........................................................................ [ 2%]
........................................................................ [ 4%]
...
........................................................................ [ 56%]
........................................................................ [ 59%]
.............................................F.......................... [ 61%]
........................................................................ [ 63%]
...
........................................................................ [ 98%]
............................................... [100%]
=================================== FAILURES ===================================
_ test_mxfp8_mxfp4_groupwise_group_gemm[out_dtype1-fp8_dtype0-2-2880-8192-4096] _
m = 4096, n = 8192, k = 2880, group_size = 2, fp8_dtype = torch.float8_e4m3fn
out_dtype = torch.float16
@pytest.mark.parametrize("m", [4, 128, 256, 512, 4096, 8192])
@pytest.mark.parametrize("n", [128, 256, 512, 2879, 4096, 8192])
@pytest.mark.parametrize("k", [128, 256, 512, 2880, 4096, 8192])
@pytest.mark.parametrize("group_size", [1, 2, 4, 8])
@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_mxfp8_mxfp4_groupwise_group_gemm(
m,
n,
k,
group_size,
fp8_dtype,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
# TODO: We need to add gemm_mxfp4_nt_groupwise support for sm120/121 at some point.
if compute_capability[0] not in [10]:
pytest.skip(
"gemm_mxfp4_nt_groupwise is only supported on SM100 and SM103 GPUs."
)
torch.random.manual_seed(0)
tile_size = 32
alignment_n = 8
alignment_k = 128
a_val = torch.randn((group_size * m, k), dtype=torch.float32, device="cuda")
b_val = torch.randn(
(group_size, n, k), dtype=torch.float32, device="cuda"
) / math.sqrt(k)
n_padded = (n + alignment_n - 1) // alignment_n * alignment_n
k_padded = (k + alignment_k - 1) // alignment_k * alignment_k
if fp8_dtype == torch.float8_e4m3fn:
a_quant_mode = QuantMode.MXFP8_E4M3
elif fp8_dtype == torch.float8_e5m2:
a_quant_mode = QuantMode.MXFP8_E5M2
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
a_fp8, a_scale = quantize_tensor(a_val, tile_size, None, k_padded, a_quant_mode)
b_fp4, b_scale = quantize_tensor(
b_val, tile_size, n_padded, k_padded, QuantMode.MXFP4
)
a_scale_swizzled = swizzle_blockscale(
a_scale.unflatten(0, (group_size, m)), group_size, m, k_padded, tile_size
).flatten(0, 1)
b_scale_swizzled = swizzle_blockscale(
b_scale, group_size, n_padded, k_padded, tile_size
)
group_arange = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda")
m_indptr = group_arange * m
# Pad a_scale_swizzled according to the function compute_sm100_cutlass_group_gemm_args
# in group_gemm_mxfp4_groupwise_sm100.cuh
alignment_m_sf = 128
m_indptr_padded = (
(m_indptr + group_arange * (alignment_m_sf - 1))
// alignment_m_sf
* alignment_m_sf
)
m_sf = m_indptr_padded[1:] - m_indptr_padded[:-1]
a_scale_chunked = a_scale_swizzled.chunk(group_size, dim=0)
a_scale_chunked = [
torch.cat(
[
x,
torch.zeros(
m_sf[i] - x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device
),
]
)
for i, x in enumerate(a_scale_chunked)
]
a_scale_swizzled = torch.cat(a_scale_chunked)
out_ref = torch.empty((group_size * m, n), dtype=out_dtype, device="cuda")
for i in range(group_size):
out_ref[m * i : m * (i + 1)] = gemm_mxfp8_mxfp4_nt_groupwise_ref(
a_fp8[m * i : m * (i + 1)],
b_fp4[i],
a_scale[m * i : m * (i + 1)],
b_scale[i],
tile_size,
n,
k,
out_dtype,
)
mma_sm_list = [1, 2]
tile_m_list = [128]
tile_n_list = [64, 128, 192, 256]
tile_k_list = [128, 256]
swap_ab_list = [True, False]
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
):
out = group_gemm_mxfp4_nt_groupwise(
a_fp8,
b_fp4,
a_scale_swizzled,
b_scale_swizzled,
m_indptr,
mma_sm=mma_sm,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
swap_ab=swap_ab,
out_dtype=out_dtype,
)[:, :n]
> torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 228311 / 67108864 (0.3%)
E Greatest absolute difference: nan at index (4161, 3982) (up to 0.01 allowed)
E Greatest relative difference: nan at index (4161, 3982) (up to 0.01 allowed)
tests/gemm/test_groupwise_scaled_gemm_mxfp4.py:353: AssertionError
---------- generated xml file: /workspace/junit/tests_gemm_test_groupwise_scaled_gemm_mxfp4.py.xml ----------
=========================== short test summary info ============================
FAILED tests/gemm/test_groupwise_scaled_gemm_mxfp4.py::test_mxfp8_mxfp4_groupwise_group_gemm[out_dtype1-fp8_dtype0-2-2880-8192-4096]
================== 1 failed, 3455 passed in 79.44s (0:01:19) ===================
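Since the assertion fires inside the loop over kernel configurations, the failure above does not say which (mma_sm, tile_n, tile_k, swap_ab) combination produced the NaNs. A triage sketch, reusing the tensors prepared in the test body quoted above and assuming group_gemm_mxfp4_nt_groupwise is imported as in the test file, that records every configuration whose output contains NaNs instead of stopping at the first mismatch:

# Triage sketch (not part of the test): sweep the same kernel configurations
# as the test and collect every one whose output contains NaNs, instead of
# asserting on the first mismatch. The tensor arguments and the
# group_gemm_mxfp4_nt_groupwise call mirror the test source quoted above.
from itertools import product

import torch

def find_nan_configs(
    a_fp8, b_fp4, a_scale_swizzled, b_scale_swizzled, m_indptr, n, out_dtype
):
    bad_configs = []
    for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
        [1, 2], [128], [64, 128, 192, 256], [128, 256], [True, False]
    ):
        out = group_gemm_mxfp4_nt_groupwise(  # assumed imported as in the test file
            a_fp8,
            b_fp4,
            a_scale_swizzled,
            b_scale_swizzled,
            m_indptr,
            mma_sm=mma_sm,
            tile_m=tile_m,
            tile_n=tile_n,
            tile_k=tile_k,
            swap_ab=swap_ab,
            out_dtype=out_dtype,
        )[:, :n]
        if torch.isnan(out).any():
            bad_configs.append((mma_sm, tile_m, tile_n, tile_k, swap_ab))
    return bad_configs

That would at least show whether the NaNs track a specific mma_sm or swap_ab setting or move around between runs.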