Skip to content

[Bug] test_groupwise_scaled_gemm_mxfp4.py fails non-deterministically on B300 #2514

@bkryu

Description

@bkryu

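On B300 (SM103), running in the cu130 container, tests in test_groupwise_scaled_gemm_mxfp4.py fail non-deterministically with NaN mismatches. In the run below, 1 of 3456 parametrizations failed: test_mxfp8_mxfp4_groupwise_group_gemm with m=4096, n=8192, k=2880, group_size=2, fp8_dtype=torch.float8_e4m3fn, and out_dtype=torch.float16.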

==========================================
Running:  pytest --continue-on-collection-errors "tests/gemm/test_groupwise_scaled_gemm_mxfp4.py"
==========================================
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 3456 items
tests/gemm/test_groupwise_scaled_gemm_mxfp4.py ......................... [  0%]
........................................................................ [  2%]
........................................................................ [  4%]
...
........................................................................ [ 56%]
........................................................................ [ 59%]
.............................................F.......................... [ 61%]
........................................................................ [ 63%]
...
........................................................................ [ 98%]
...............................................                          [100%]
=================================== FAILURES ===================================
_ test_mxfp8_mxfp4_groupwise_group_gemm[out_dtype1-fp8_dtype0-2-2880-8192-4096] _
m = 4096, n = 8192, k = 2880, group_size = 2, fp8_dtype = torch.float8_e4m3fn
out_dtype = torch.float16
    @pytest.mark.parametrize("m", [4, 128, 256, 512, 4096, 8192])
    @pytest.mark.parametrize("n", [128, 256, 512, 2879, 4096, 8192])
    @pytest.mark.parametrize("k", [128, 256, 512, 2880, 4096, 8192])
    @pytest.mark.parametrize("group_size", [1, 2, 4, 8])
    @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
    def test_mxfp8_mxfp4_groupwise_group_gemm(
        m,
        n,
        k,
        group_size,
        fp8_dtype,
        out_dtype,
    ):
        compute_capability = get_compute_capability(torch.device(device="cuda"))
        # TODO: We need to add gemm_mxfp4_nt_groupwise support for sm120/121 at some point.
        if compute_capability[0] not in [10]:
            pytest.skip(
                "gemm_mxfp4_nt_groupwise is only supported on SM100 and SM103 GPUs."
            )
        torch.random.manual_seed(0)
        tile_size = 32
        alignment_n = 8
        alignment_k = 128
    
        a_val = torch.randn((group_size * m, k), dtype=torch.float32, device="cuda")
        b_val = torch.randn(
            (group_size, n, k), dtype=torch.float32, device="cuda"
        ) / math.sqrt(k)
        n_padded = (n + alignment_n - 1) // alignment_n * alignment_n
        k_padded = (k + alignment_k - 1) // alignment_k * alignment_k
    
        if fp8_dtype == torch.float8_e4m3fn:
            a_quant_mode = QuantMode.MXFP8_E4M3
        elif fp8_dtype == torch.float8_e5m2:
            a_quant_mode = QuantMode.MXFP8_E5M2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
        a_fp8, a_scale = quantize_tensor(a_val, tile_size, None, k_padded, a_quant_mode)
        b_fp4, b_scale = quantize_tensor(
            b_val, tile_size, n_padded, k_padded, QuantMode.MXFP4
        )
    
        a_scale_swizzled = swizzle_blockscale(
            a_scale.unflatten(0, (group_size, m)), group_size, m, k_padded, tile_size
        ).flatten(0, 1)
        b_scale_swizzled = swizzle_blockscale(
            b_scale, group_size, n_padded, k_padded, tile_size
        )
    
        group_arange = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda")
        m_indptr = group_arange * m
    
        # Pad a_scale_swizzled according to the function compute_sm100_cutlass_group_gemm_args
        # in group_gemm_mxfp4_groupwise_sm100.cuh
        alignment_m_sf = 128
        m_indptr_padded = (
            (m_indptr + group_arange * (alignment_m_sf - 1))
            // alignment_m_sf
            * alignment_m_sf
        )
        m_sf = m_indptr_padded[1:] - m_indptr_padded[:-1]
        a_scale_chunked = a_scale_swizzled.chunk(group_size, dim=0)
        a_scale_chunked = [
            torch.cat(
                [
                    x,
                    torch.zeros(
                        m_sf[i] - x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device
                    ),
                ]
            )
            for i, x in enumerate(a_scale_chunked)
        ]
        a_scale_swizzled = torch.cat(a_scale_chunked)
    
        out_ref = torch.empty((group_size * m, n), dtype=out_dtype, device="cuda")
        for i in range(group_size):
            out_ref[m * i : m * (i + 1)] = gemm_mxfp8_mxfp4_nt_groupwise_ref(
                a_fp8[m * i : m * (i + 1)],
                b_fp4[i],
                a_scale[m * i : m * (i + 1)],
                b_scale[i],
                tile_size,
                n,
                k,
                out_dtype,
            )
    
        mma_sm_list = [1, 2]
        tile_m_list = [128]
        tile_n_list = [64, 128, 192, 256]
        tile_k_list = [128, 256]
        swap_ab_list = [True, False]
        for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
            mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
        ):
            out = group_gemm_mxfp4_nt_groupwise(
                a_fp8,
                b_fp4,
                a_scale_swizzled,
                b_scale_swizzled,
                m_indptr,
                mma_sm=mma_sm,
                tile_m=tile_m,
                tile_n=tile_n,
                tile_k=tile_k,
                swap_ab=swap_ab,
                out_dtype=out_dtype,
            )[:, :n]
>           torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 228311 / 67108864 (0.3%)
E           Greatest absolute difference: nan at index (4161, 3982) (up to 0.01 allowed)
E           Greatest relative difference: nan at index (4161, 3982) (up to 0.01 allowed)
tests/gemm/test_groupwise_scaled_gemm_mxfp4.py:353: AssertionError
- generated xml file: /workspace/junit/tests_gemm_test_groupwise_scaled_gemm_mxfp4.py.xml -
=========================== short test summary info ============================
FAILED tests/gemm/test_groupwise_scaled_gemm_mxfp4.py::test_mxfp8_mxfp4_groupwise_group_gemm[out_dtype1-fp8_dtype0-2-2880-8192-4096]
================== 1 failed, 3455 passed in 79.44s (0:01:19) ===================

Metadata

Assignees

Labels

bug (Something isn't working), op: gemm

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions