
[Bug] Autotune Fails for W4A8 in cutlass_fused_moe #2501

@gz944367214

Description


flashinfer version: v0.6.1

Description:
I encountered a runtime error when running the cutlass_fused_moe kernel for W4A8 with autotune enabled. The failure occurs only when the autotune() context manager is active, and it does not reproduce for fp8_block_scaling or bf16_mxfp4.

Steps to Reproduce:

  1. Run the unit test tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_w4a8 without autotune — it passes correctly.
  2. Wrap the same cutlass_fused_moe call in an autotune() context — it fails:
    with autotune():
        _ = fused_moe.cutlass_fused_moe(
            x,
            selected_experts_int32,
            routing_weights,
            fc1_weights.view(torch.uint8),
            fc2_weights.view(torch.uint8),
            dtype,
            quant_scales=quant_scales,
            use_w4_group_scaling=True,
            output=flash_output,
            use_packed_weights=True,
        )

Observed Error

        with autotune():
>           _ = fused_moe.cutlass_fused_moe(
                x,
                selected_experts_int32,
                routing_weights,
                fc1_weights.view(torch.uint8),
                fc2_weights.view(torch.uint8),
                dtype,
                quant_scales=quant_scales,
                use_w4_group_scaling=True,
                output=flash_output,
                use_packed_weights=True,
            )

tests/moe/test_trtllm_cutlass_fused_moe.py:1590: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py:898: in cutlass_fused_moe
    return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe(
/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py:555: in cutlass_fused_moe
    _, gemm_tactic_1 = tuner.choose_one(
/usr/local/lib/python3.12/dist-packages/flashinfer/autotuner.py:482: in choose_one
    r(tensors, tactic=-1, do_preparation=True, **kwargs)
/usr/local/lib/python3.12/dist-packages/flashinfer/autotuner.py:217: in __call__
    return self.forward(inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py:455: in forward
    self.fused_moe_runner.run_gemm_profile(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: quant_1 && quant_2 (/workspace/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh:4530)
E   1       0x7f47e2a9135f tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 74
E   2       0x7f47e2aa132e /usr/local/lib/python3.12/dist-packages/flashinfer_jit_cache/jit_cache/fused_moe_90/fused_moe_90.so(+0x1ff32e) [0x7f47e2aa132e]
E   3       0x7f47e30ade37 tensorrt_llm::kernels::cutlass_kernels::GemmProfilerBackend::prepare(int, char*, void const*, bool, CUstream_st*) + 103
E   4       0x7f47e303e73f /usr/local/lib/python3.12/dist-packages/flashinfer_jit_cache/jit_cache/fused_moe_90/fused_moe_90.so(+0x79c73f) [0x7f47e303e73f]
E   5       0x7f47e3073280 /usr/local/lib/python3.12/dist-packages/flashinfer_jit_cache/jit_cache/fused_moe_90/fused_moe_90.so(+0x7d1280) [0x7f47e3073280]
E   6       0x7f47e303abb3 tvm::ffi::details::FunctionObjImpl<tvm::ffi::Function::FromTyped<FusedMoeRunner::GetFunction(tvm::ffi::String const&)::{lambda(tvm::ffi::TensorView, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, long, long, long, long, long, long, long, bool, bool, long, long, bool, bool, long)#1}>(FusedMoeRunner::GetFunction(tvm::ffi::String const&)::{lambda(tvm::ffi::TensorView, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, long, long, long, long, long, long, long, bool, bool, long, long, bool, bool, long)#1}&&)::{lambda(tvm::ffi::AnyView const*, int, tvm::ffi::Any*)#1}>::SafeCall(void*, TVMFFIAny const*, int, TVMFFIAny*) + 883
E   7       0x7f4c375e584c /usr/local/lib/python3.12/dist-packages/tvm_ffi/core.abi3.so(+0x5484c) [0x7f4c375e584c]
E   8             0x548e25 _PyObject_MakeTpCall + 117
E   9             0x5d71d9 _PyEval_EvalFrameDefault + 2697
E   10            0x54ca34 /usr/bin/python3.12() [0x54ca34]
E   11            0x54b055 PyObject_Call + 277
E   12            0x5db2ca _PyEval_EvalFrameDefault + 19322
E   13            0x54a73a _PyObject_Call_Prepend + 394
E   14            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   15            0x54afac PyObject_Call + 108
E   16            0x5db2ca _PyEval_EvalFrameDefault + 19322
E   17            0x54a73a _PyObject_Call_Prepend + 394
E   18            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   19            0x548eee _PyObject_MakeTpCall + 318
E   20            0x5d71d9 _PyEval_EvalFrameDefault + 2697
E   21            0x54a73a _PyObject_Call_Prepend + 394
E   22            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   23            0x54afac PyObject_Call + 108
E   24            0x5db2ca _PyEval_EvalFrameDefault + 19322
E   25            0x54a73a _PyObject_Call_Prepend + 394
E   26            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   27            0x548eee _PyObject_MakeTpCall + 318
E   28            0x5d71d9 _PyEval_EvalFrameDefault + 2697
E   29            0x54a73a _PyObject_Call_Prepend + 394
E   30            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   31            0x548eee _PyObject_MakeTpCall + 318
E   32            0x5d71d9 _PyEval_EvalFrameDefault + 2697
E   33            0x54a73a _PyObject_Call_Prepend + 394
E   34            0x5a3398 /usr/bin/python3.12() [0x5a3398]
E   35            0x548eee _PyObject_MakeTpCall + 318
E   36            0x5d71d9 _PyEval_EvalFrameDefault + 2697
E   37            0x5d571b PyEval_EvalCode + 347
E   38            0x6084c2 /usr/bin/python3.12() [0x6084c2]
E   39            0x6b44f3 /usr/bin/python3.12() [0x6b44f3]
E   40            0x6b425a _PyRun_SimpleFileObject + 426
E   41            0x6b408f _PyRun_AnyFileObject + 79
E   42            0x6bc0f5 Py_RunMain + 949
E   43            0x6bbbdd Py_BytesMain + 45
E   44      0x7f4c847c21ca /usr/lib/x86_64-linux-gnu/libc.so.6(+0x2a1ca) [0x7f4c847c21ca]
E   45      0x7f4c847c228b __libc_start_main + 139
E   46            0x657005 _start + 37

python/tvm_ffi/cython/function.pxi:923: RuntimeError

Additional Context:

  1. Other tests (e.g., test_moe_bf16_mxfp4, test_moe_fp8_block_scaling) run successfully with autotune enabled. The failure appears specific to W4A8 when combined with autotuning.

Question:
Is there a bug in the W4A8 path of the GEMM profiler, or does autotuning W4A8 require additional setup that the other quantization modes do not?
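Until the profiler assertion is resolved, one possible workaround is to enable the autotune context only for quantization modes known to profile correctly and fall back to a no-op context for W4A8, so the kernel runs with default tactics. The sketch below is illustrative only: dummy_autotune() stands in for flashinfer's autotune() context manager, and the quant_mode strings are hypothetical labels, not flashinfer API values.

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def dummy_autotune():
    # Stand-in for flashinfer's autotune() context manager (illustration only).
    yield

def tuning_context(quant_mode: str):
    # Skip autotuning for W4A8 until the quant_1 && quant_2 assertion in
    # GemmProfilerBackend::prepare is fixed; nullcontext() makes the
    # `with` block a no-op, so default GEMM tactics are used instead.
    if quant_mode == "w4a8":
        return nullcontext()
    return dummy_autotune()

with tuning_context("w4a8"):
    pass  # cutlass_fused_moe(...) would run here without profiling
```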
