[Fix][Qwen3.5] Pass max_mamba_cache_size to mamba pool in disaggregation decode path #19002

Open

YAMY1234 wants to merge 2 commits into sgl-project:main from YAMY1234:qwen3.5_disagg
Conversation

@YAMY1234 (Contributor) commented Feb 19, 2026

Motivation

HybridMambaDecodeReqToTokenPool (used exclusively in PD disaggregation decode) ignores the --max-mamba-cache-size setting when initializing its mamba pool. The original code hardcodes size (= max_num_reqs) as the mamba pool capacity, whereas the non-disaggregation path (HybridReqToTokenPool) correctly accepts a separate mamba_size parameter (= max_mamba_cache_size).

With mamba-scheduler-strategy: extra_buffer, each request consumes 3 mamba pool slots (1 main + 2 ping-pong). One would therefore set max-mamba-cache-size = 3 × max-running-requests to provide enough slots, but in disaggregated decode the configured value is silently discarded: the mamba pool is always allocated with max_num_reqs slots, causing:

 File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/decode.py", line 667, in _pre_alloc
    req_pool_indices = self.req_to_token_pool.alloc([req])
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 526, in alloc
    req.mamba_ping_pong_track_buffer is not None
AssertionError: Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio.

For example, with --max-mamba-cache-size 750 --max-running-requests 250, the decode mamba pool is still created with only 250 slots (confirmed by logs: Mamba Cache is allocated. max_mamba_cache_size: 250), which can serve at most 83 concurrent requests before exhausting mamba slots.
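
For reference, the slot arithmetic behind these numbers (a minimal sketch; values taken from the example above):

  # extra_buffer consumes 3 mamba pool slots per request (1 main + 2 ping-pong).
  max_running_requests = 250
  slots_per_request = 3

  pool_before_fix = 250  # pool wrongly sized to max_num_reqs
  pool_after_fix = max_running_requests * slots_per_request  # 750, as configured

  print(pool_before_fix // slots_per_request)  # 83: concurrency ceiling before the fix
  print(pool_after_fix // slots_per_request)   # 250: matches max-running-requests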

Modifications

  • decode.py: Add mamba_size parameter to HybridMambaDecodeReqToTokenPool.__init__. Use it for the mamba pool capacity (size) in _init_mamba_pool, while keeping mamba_spec_state_size at size + pre_alloc_size (aligned with the non-disagg path where speculative intermediate buffers are sized by max_num_reqs, not max_mamba_cache_size).
  • model_runner_kv_cache_mixin.py: Pass mamba_size=self.server_args.max_mamba_cache_size when constructing HybridMambaDecodeReqToTokenPool (see the sketch below).
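
A minimal, self-contained sketch of the shape of the change (names follow the PR description; the real constructor takes more parameters and _init_mamba_pool's signature may differ):

  from typing import Optional

  class HybridMambaDecodeReqToTokenPool:
      def __init__(self, size: int, pre_alloc_size: int,
                   mamba_size: Optional[int] = None):
          # size       = max_num_reqs (req_to_token capacity, unchanged)
          # mamba_size = max_mamba_cache_size; falls back to size when unset
          mamba_pool_size = mamba_size if mamba_size is not None else size
          # Speculative intermediate buffers stay sized by max_num_reqs,
          # matching the non-disaggregation path.
          mamba_spec_state_size = size + pre_alloc_size
          self._init_mamba_pool(mamba_pool_size, mamba_spec_state_size)

      def _init_mamba_pool(self, pool_size: int, spec_state_size: int) -> None:
          ...  # the real code allocates mamba state tensors; placeholder here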

Accuracy Tests

Qwen3.5-397B-A17B-FP8, 1P1D PD disaggregation, TP4, extra_buffer strategy, GB200:

  Run     max-mamba-cache-size                     Code        Mamba pool allocated     Result
  948491  Not set (mamba-full-memory-ratio: 2.5)   Before fix  250 (ignores auto-calc)  CRASH
  949024  750 (explicit)                           Before fix  250 (ignores 750)        CRASH
  949062  750 (explicit)                           After fix   750 (expected)           GPQA 0.875
====================
Repeat: 8, mean: 0.875
Scores: ['0.869', '0.874', '0.864', '0.879', '0.874', '0.874', '0.864', '0.904']
====================

Repro config (on GB200):

  prefill_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

  decode_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000"
    SGLANG_HACK_SEQ_BOOTSTRAP_ROOM: "1"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

    prefill:
      served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
      model-path: "/model/"
      trust-remote-code: true

      # Parallelism
      tensor-parallel-size: 4
      data-parallel-size: 1
      expert-parallel-size: 1

      # Mamba hybrid model settings
      mamba-scheduler-strategy: "extra_buffer"
      mamba-track-interval: 128
      mamba-ssm-dtype: "bfloat16"

      # PD disaggregation
      disaggregation-mode: "prefill"
      disable-radix-cache: true

      # Memory: match AGG config (0.75), let system auto-calculate max_running_requests
      mem-fraction-static: 0.75
      chunked-prefill-size: 2048
      context-length: 262144

      load-balance-method: "round_robin"
      watchdog-timeout: 1000000

    decode:
      served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
      model-path: "/model/"
      trust-remote-code: true

      # Parallelism
      tensor-parallel-size: 4
      data-parallel-size: 1
      expert-parallel-size: 1

      # Mamba hybrid model settings
      mamba-scheduler-strategy: "extra_buffer"
      mamba-track-interval: 128
      mamba-ssm-dtype: "bfloat16"
      # extra_buffer needs 3 mamba pool slots per request (1 main + 2 ping-pong).
      max-mamba-cache-size: 750
      max-running-requests: 250

      # PD disaggregation
      disaggregation-mode: "decode"
      disable-radix-cache: true

      # Memory
      mem-fraction-static: 0.75
      chunked-prefill-size: 2048
      context-length: 262144

      watchdog-timeout: 1000000

@gemini-code-assist (Contributor bot) commented:

Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the disaggregation decode path where the mamba memory pool was not correctly sized according to the --max-mamba-cache-size configuration. Previously, the pool would default to max_num_reqs, causing out-of-memory errors and system crashes, especially with the extra_buffer scheduling strategy. The changes ensure that the mamba pool is initialized with the intended capacity, preventing resource exhaustion and improving the stability and performance of disaggregated decoding.

Highlights

  • Mamba Cache Size Fix: Addressed an issue where HybridMambaDecodeReqToTokenPool in the disaggregation decode path ignored the --max-mamba-cache-size setting, leading to incorrect mamba pool allocation.
  • Crash Prevention: Resolved crashes occurring when using the extra_buffer mamba scheduler strategy due to insufficient mamba pool slots.
  • Correct Resource Allocation: Ensured that the mamba pool capacity is now correctly configured based on max_mamba_cache_size, aligning its behavior with the non-disaggregation path.

Changelog
  • python/sglang/srt/disaggregation/decode.py
    • Added mamba_size parameter to the HybridMambaDecodeReqToTokenPool constructor.
    • Updated the _init_mamba_pool call to use the provided mamba_size for the pool's capacity, falling back to size if mamba_size is not specified.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Modified the init_memory_pool method to pass self.server_args.max_mamba_cache_size as the mamba_size argument when instantiating HybridMambaDecodeReqToTokenPool.
Activity
  • Code formatted according to pre-commit hooks.
  • Accuracy benchmark results provided, showing successful execution and expected performance after the fix.

@gemini-code-assist (bot) left a comment:

Code Review

This pull request correctly addresses a bug where max_mamba_cache_size was ignored in the disaggregation decode path, leading to crashes. The changes introduce the mamba_size parameter and use it to properly configure the mamba pool, which resolves the issue as demonstrated by the provided test results. I have one suggestion to refine the calculation of the mamba pool size for better clarity and to prevent potential memory over-allocation when pre_alloc_size is used.

@YAMY1234 changed the title from "[Fix] Pass max_mamba_cache_size to mamba pool in disaggregation decode path" to "[Fix][Qwen3.5] Pass max_mamba_cache_size to mamba pool in disaggregation decode path" on Feb 19, 2026
@YAMY1234 (Author) commented:

/tag-and-rerun-ci

Review thread on decode.py (parameters of HybridMambaDecodeReqToTokenPool.__init__, with the new mamba_size parameter added by this PR):

  speculative_num_draft_tokens: int,
  enable_mamba_extra_buffer: bool,
  pre_alloc_size: int,
  mamba_size: Optional[int] = None,
Collaborator:

It is weird that we have a size and a mamba size for HybridMambaDecodeReqToTokenPool at the same time.

@YAMY1234 (Author):

I had a similar concern when I was trying to understand this, but I found its parent class HybridReqToTokenPool also takes both size and mamba_size as separate parameters; they represent different things (max_num_reqs for req-to-token vs. max_mamba_cache_size for mamba state).
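
For context, a minimal sketch of the parent-class precedent described above (parameter list abridged; the real HybridReqToTokenPool takes many more arguments):

  class HybridReqToTokenPool:
      def __init__(self, size: int, mamba_size: int):
          self.size = size              # max_num_reqs: req_to_token capacity
          self.mamba_size = mamba_size  # max_mamba_cache_size: mamba state capacity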

Review thread on model_runner_kv_cache_mixin.py (the call site):

  pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
  if config := self.mambaish_config:
      self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
          size=max_num_reqs,
Collaborator:

Is it possible that we change this line?
CC: @yizhang2077

@YAMY1234 (Author):

If you mean changing size from max_num_reqs to max_mamba_cache_size: my understanding is that this would not be ideal at the call site, because size is also used by DecodeReqToTokenPool.__init__ (for the req_to_token tensor) and for mamba_spec_state_size; changing it would over-allocate the req_to_token pool and misalign the speculative buffer size with the non-disagg path. Happy to hear alternative suggestions though!

Collaborator:

Makes sense. I checked the logic, LGTM.

@ShangmingCai (Collaborator) left a comment:

LGTM. Should be OK to merge after a double-check from @yizhang2077
