
[Bug] [MXFP8 Online] AssertionError: n=64 must be divisible by 128 #18277

@vincentzed

Description


Checklist

  • I searched related issues but found no solution.
  • The bug persists in the latest version.
  • Issues without environment info and a minimal reproducible demo are hard to resolve and may receive no feedback.
  • If this is not a bug report but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.
  • Please use English. Otherwise, it will be closed.

Describe the bug

Loading safetensors checkpoint shards:  98% Completed | 40/41 [05:06<00:07,  7.30s/it]
Loading safetensors checkpoint shards: 100% Completed | 41/41 [05:14<00:00,  7.38s/it]
Loading safetensors checkpoint shards: 100% Completed | 41/41 [05:14<00:00,  7.67s/it]

[2026-02-05 03:01:29] Loading weights took 315.10 seconds
[2026-02-05 03:01:29] Load weight end. elapsed=316.39 s, type=Qwen3NextForCausalLM, dtype=torch.bfloat16, avail mem=188.87 GB, mem usage=77.95 GB.
[2026-02-05 03:01:29] Using KV cache dtype: torch.bfloat16
[2026-02-05 03:01:29] Mamba Cache is allocated. max_mamba_cache_size: 1076, conv_state size: 1.77GB, ssm_state size: 75.73GB 
[2026-02-05 03:01:29] KV Cache is allocated. #tokens: 3763592, K size: 43.07 GB, V size: 43.07 GB
[2026-02-05 03:01:29] Memory pool end. avail mem=25.27 GB
[2026-02-05 03:01:35] Init attention backend begin.
[2026-02-05 03:01:35] Using hybrid linear attention backend for hybrid GDN models.
[2026-02-05 03:01:35] CuTe DSL GDN decode enabled: False
[2026-02-05 03:01:35] Init attention backend end. elapsed=0.01 s
[2026-02-05 03:01:35] Capture cuda graph begin. This can take up to several minutes. avail mem=25.21 GB
[2026-02-05 03:01:35] Capture cuda graph bs [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 358]
Capturing batches (bs=358 avail_mem=24.22 GB):   0%|                                                                                                                                          | 0/43 [00:02<?, ?it/s]
[2026-02-05 03:01:38] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3054, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 350, in __init__
    self.init_model_worker()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 544, in init_model_worker
    self.init_tp_model_worker()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 506, in init_tp_model_worker
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 242, in __init__
    self._init_model_runner()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 325, in _init_model_runner
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 391, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 588, in initialize
    self.init_device_graphs()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 2042, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 366, in __init__
    self.capture()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 522, in capture
    _capture_one_stream()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 509, in _capture_one_stream
    ) = self.capture_one_batch_size(bs, forward, stream_idx)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 728, in capture_one_batch_size
    run_once()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 715, in run_once
    logits_output_or_pp_proxy_tensors = forward(
                                        ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 971, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 889, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 565, in forward
    hidden_states = self.linear_attn(
                    ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 417, in forward
    return self._forward(hidden_states, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 427, in _forward
    projected_states_qkvz, projected_states_ba = self._forward_input_proj(
                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 396, in _forward_input_proj
    projected_states_ba, _ = self.in_proj_ba(hidden_states)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/linear.py", line 451, in forward
    output_parallel = self.quant_method.apply(self, input_, bias)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8.py", line 608, in apply
    return triton_mxfp8_blockscaled_linear(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 679, in triton_mxfp8_blockscaled_linear
    assert n % block_n == 0, f"{n=} must be divisible by {block_n}"
           ^^^^^^^^^^^^^^^^
AssertionError: n=64 must be divisible by 128
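For context, a minimal sketch of the shape guard that fires. The divisibility check and the `n=64` / `block_n=128` values are taken from the traceback above; the standalone function here is hypothetical, mimicking the guard in `triton_mxfp8_blockscaled_linear` rather than reproducing its real signature. The `b/a` projection (`in_proj_ba`) of Qwen3-Next's linear attention has a small output dimension per rank, so it cannot satisfy a 128-wide block-scale granularity:

```python
def check_mxfp8_block_shape(n: int, block_n: int = 128) -> None:
    """Hypothetical stand-in for the shape guard in the MXFP8 Triton path:
    the output dimension n must be a multiple of the block-scale width."""
    assert n % block_n == 0, f"{n=} must be divisible by {block_n}"


# in_proj_ba's per-rank output dim is 64, which trips the guard:
try:
    check_mxfp8_block_shape(n=64)
except AssertionError as e:
    print(e)  # n=64 must be divisible by 128
```

Any layer whose (tensor-parallel-sharded) output dimension is smaller than the MXFP8 block size would hit the same assertion, which suggests such layers need either padding or a fallback to an unquantized GEMM.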

Reproduction

❯ python -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --quantization mxfp8 --fp8-gemm-backend triton --moe-runner-backend cutlass

Environment

python3 -m sglang.check_env
Python: 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA B300 SXM6 AC
GPU 0,1,2,3,4,5,6,7 Compute Capability: 10.3
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 13.0, V13.0.88
CUDA Driver Version: 580.82.07
PyTorch: 2.9.1+cu130
sglang: 0.0.0.dev1+g28e234072
sgl_kernel: 0.3.21
flashinfer_python: 0.6.2
flashinfer_cubin: 0.6.2
flashinfer_jit_cache: 0.6.2+cu130
triton: 3.5.1
transformers: 4.57.6
torchao: 0.9.0
numpy: 2.4.2
aiohttp: 3.13.3
fastapi: 0.128.0
hf_transfer: 0.1.9
huggingface_hub: 0.36.1
interegular: 0.3.3
modelscope: 1.34.0
orjson: 3.11.7
outlines: 0.1.11
packaging: 26.0
psutil: 7.2.2
pydantic: 2.12.5
python-multipart: 0.0.22
pyzmq: 27.1.0
uvicorn: 0.40.0
uvloop: 0.22.1
vllm: Module Not Found
xgrammar: 0.1.27
openai: 2.6.1
tiktoken: 0.12.0
anthropic: 0.77.0
litellm: Module Not Found
decord2: 3.0.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-239   0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-239   0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-239   0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-239   0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    0-239   0               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    0-239   0               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    0-239   0               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      0-239   0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Hypervisor vendor: KVM
ulimit soft: 1048576
