
Fix NSA FP8 KV cache path for both-trtllm MHA one-shot #18931

Open
mmangkad wants to merge 5 commits into sgl-project:main from mmangkad:fix/nsa-fp8-kv-layout-mismatch

Conversation

@mmangkad
Contributor

mmangkad commented Feb 17, 2026

Motivation

This is a follow-up to #18389 for the NSA FP8 KV cache path when both the prefill and decode backends are trtllm.

In this setup, the MHA one-shot prefix path could still enter paged dequant and hit a layout mismatch assertion (compact trtllm KV layout vs packed dequant expectation). For the both-trtllm configuration, this dequant step is unnecessary.
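
For intuition on where the 576 and 656 in the assertion come from, here is a rough arithmetic sketch. The DeepSeek MLA dims (kv_lora_rank=512, qk_rope_head_dim=64) are standard; the packed byte layout is an assumption inferred from the error message, not a reading of the dequant kernel:

# Illustrative only: assumed per-token sizes behind "dim_quant: 576 != 656"
kv_lora_rank = 512       # MLA compressed-KV (NoPE) dim
qk_rope_head_dim = 64    # MLA RoPE dim

# Compact layout (trtllm KV pool): one fp8 element per value
compact_dim = kv_lora_rank + qk_rope_head_dim                                # 576

# Packed layout expected by the paged dequant (assumed: fp8 NoPE bytes
# + one float32 scale per 128-wide tile + bf16 RoPE bytes)
packed_dim = kv_lora_rank + (kv_lora_rank // 128) * 4 + qk_rope_head_dim * 2  # 656

print(compact_dim, packed_dim)   # 576 656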

cc @rainj-me @Fridge003

Modifications

  • Updated the FP8 branch condition in forward_mha.py to skip paged dequant when both nsa_prefill_backend and nsa_decode_backend are trtllm (see the sketch below this list).
  • In the both-trtllm case, reuse the existing direct KV fetch path.
  • No changes to NSA routing semantics or dequant kernel implementation.
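
A minimal sketch of the gate described in the first bullet above; it mirrors the condition discussed later in this thread. get_global_server_args and the names in the FP8 branch appear in the PR diff, while the else-branch helper name is a placeholder, not the actual code:

# Sketch of the FP8 branch in forward_mha.py's MHA one-shot prefix handling
server_args = get_global_server_args()
both_trtllm = (
    server_args.nsa_prefill_backend == "trtllm"
    and server_args.nsa_decode_backend == "trtllm"
)
if self.use_nsa and self.kv_cache_dtype == "fp8_e4m3" and not both_trtllm:
    # FP8 path: dequantize the NSA-specific FP8 layout to BF16
    kv_a, k_pe = self._get_mla_kv_buffer_from_fp8_for_nsa(forward_batch)
else:
    # Both-trtllm case: reuse the existing direct KV fetch path
    kv_a, k_pe = self._get_mla_kv_buffer(forward_batch)  # placeholder helper name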

Accuracy Tests

This is an example of a case that previously crashed with AssertionError: dim_quant: 576 != 656.

SGLANG_ENABLE_SPEC_V2=1 python -m sglang.launch_server \
  --model-path nvidia/DeepSeek-V3.2-NVFP4 \
  --tensor-parallel-size 4 \
  --attention-backend nsa \
  --nsa-prefill-backend trtllm \
  --nsa-decode-backend trtllm \
  --moe-runner-backend flashinfer_trtllm \
  --quantization modelopt_fp4 \
  --kv-cache-dtype fp8_e4m3 \
  --reasoning-parser deepseek-v3 \
  --tool-call-parser deepseekv32 \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --model-loader-extra-config '{"enable_multithread_load": true,"num_threads": 96}'
python benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1319 --parallel 48

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:45<00:00,  7.96it/s]
Accuracy: 0.955
Invalid: 0.000
Latency: 165.688 s
Output throughput: 786.323 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Support compact(576) and packed(656) KV layouts in paged dequant, gate dequant path by backend storage mode, and pass explicit MLA dims.
@gemini-code-assist
Contributor

Summary of Changes

Hello @mmangkad, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a regression where a mismatch in FP8 KV cache layouts caused assertion errors during MHA one-shot operations, particularly when prefix sharing was involved. The primary fix involves updating the paged dequantization logic to correctly handle both 'compact' and 'packed' FP8 KV cache formats. This ensures compatibility across different backend storage modes and prevents crashes in high-concurrency scenarios by dynamically adapting to the KV cache's structure.

Highlights

  • KV Cache Layout Support: The dequantize_k_cache_paged function now supports both compact and packed FP8 KV cache layouts, resolving a previous mismatch issue.
  • Dynamic Dimension Parameters: Explicit dim_nope and dim_rope parameters were added to dequantize_k_cache_paged to enable model-dimension aware layout parsing.
  • Conditional Dequantization: The MHA one-shot FP8 path now conditionally uses paged dequantization based on the backend's KV cache storage mode.
  • Improved Debugging: Assertion messages for unsupported KV cache layouts have been enhanced to provide clearer debugging information.


Changelog
  • python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
    • Modified dequantize_k_cache_paged to accept dim_nope and dim_rope as explicit parameters with default values (see the sketch after this changelog).
    • Removed the hardcoded assertion for dim_quant == 656 and introduced dynamic checks for packed_dim_quant and compact_dim_quant.
    • Implemented logic to handle the compact KV cache layout by directly casting FP8 to BF16.
    • Updated the assertion message for unsupported dim_quant values to provide more contextual information.
  • python/sglang/srt/layers/attention/nsa_backend.py
    • Updated the call to dequantize_k_cache_paged within forward_extend to pass dim_nope and dim_rope using self.kv_lora_rank and self.qk_rope_head_dim.
  • python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py
    • Introduced a use_paged_fp8_dequant flag to conditionally control the use of the paged FP8 dequantization path.
    • Added logic to determine use_paged_fp8_dequant based on the nsa_kv_cache_store_fp8 attribute of the attention backend.
    • Modified the _get_mla_kv_buffer_from_fp8_for_nsa call to pass dim_nope and dim_rope to the dequantization function.
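
To make the first changelog bullet concrete, here is a rough sketch of the layout-aware entry point it describes; the default dims, the layout check, and the compact-branch cast are assumptions inferred from the bullets above, not the actual implementation:

import torch

def dequantize_k_cache_paged(
    k_cache: torch.Tensor,        # (..., dim_quant) paged KV view, fp8-typed
    dim_nope: int = 512,          # kv_lora_rank (assumed default)
    dim_rope: int = 64,           # qk_rope_head_dim (assumed default)
) -> torch.Tensor:
    dim_quant = k_cache.shape[-1]
    compact_dim_quant = dim_nope + dim_rope                              # e.g. 576
    packed_dim_quant = dim_nope + (dim_nope // 128) * 4 + dim_rope * 2   # e.g. 656
    if dim_quant == compact_dim_quant:
        # Compact layout: directly cast FP8 values to BF16
        return k_cache.to(torch.bfloat16)
    assert dim_quant == packed_dim_quant, (
        f"unsupported KV layout: dim_quant={dim_quant}, "
        f"expected {compact_dim_quant} (compact) or {packed_dim_quant} (packed)"
    )
    # Packed layout: dequantize FP8 NoPE with per-tile scales, append BF16 RoPE
    raise NotImplementedError("packed-layout dequant elided in this sketch")
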
Activity
  • The pull request was created to address a regression related to FP8 KV layout mismatches under MHA one-shot mode.
  • The author provided detailed motivation, modifications, and accuracy test results demonstrating the fix.

@gemini-code-assist
Contributor

gemini-code-assist bot left a comment

Code Review

This pull request effectively addresses a regression causing a KV layout mismatch in MHA one-shot mode. The changes introduce support for both 'compact' and 'packed' FP8 layouts in dequantize_k_cache_paged by making it aware of model dimensions and adding gating logic to select the correct dequantization path. The fix appears robust and the call sites have been updated correctly. My review includes a minor suggestion to replace a magic number with a named constant for improved code clarity and maintainability.

@mmangkad
Contributor Author

/tag-and-rerun-ci

@mmangkad
Contributor Author

mmangkad commented Feb 17, 2026

/rerun-failed-ci

@rainj-me
Collaborator

rainj-me commented Feb 17, 2026

@mmangkad From my understanding, the attn backends that support MHA_ONE_SHOT are "fa3", "flashinfer", and "flashmla", per https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_common/attention_backend_handler.py#L64-L69. Since the trtllm_mla attn backend is not in that list, if we find an MHA_ONE_SHOT seq in the forward_batch, there are probably some other bugs involved.

Update: I checked the calling dependencies:

  1. forward_batch.mha_one_shot is only set at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py#L314
  2. forward_normal_one_shot_prepare depends on https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_v2.py#L1415
  3. attn_forward_method eventually depends on https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_common/attention_backend_handler.py#L125-L134
  4. handle_attention_trtllm_mla has no path that returns MHA_ONE_SHOT

@mmangkad
Contributor Author

@rainj-me my read is that this goes through the NSA-specific dispatcher: handle_attention_nsa checks backend.use_mha and returns MHA_ONE_SHOT when true, and backend.use_mha is decided in set_nsa_prefill_impl.
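
A rough sketch of the dispatch shape described above; handle_attention_nsa, use_mha, and set_nsa_prefill_impl are the names from this thread, while the enum and the fallback branch are assumptions about the handler's structure, not the actual code:

from enum import Enum, auto

class AttnForwardMethod(Enum):   # stand-in for sglang's forward-method enum
    MHA_ONE_SHOT = auto()
    MLA = auto()

def handle_attention_nsa(attn_backend, forward_batch):
    # attn_backend.use_mha is decided earlier, in set_nsa_prefill_impl
    if attn_backend.use_mha:
        # NSA routes through the MHA one-shot path here, independent of the
        # generic MHA_ONE_SHOT backend list in attention_backend_handler.py
        return AttnForwardMethod.MHA_ONE_SHOT
    return AttnForwardMethod.MLA  # fallback branch is an assumption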

@rainj-me
Collaborator

Then the fix should be to only support MHA instead of MHA_ONE_SHOT in handle_attention_nsa. The reason to avoid changing the dequant logic is that we should keep it only for triton and flashmla until trtllm-mla supports the sparse fp8 kv cache.

@mmangkad
Contributor Author

Thanks - I thought NSA and trtllm_mla are different dispatch paths, and NSA can intentionally go through MHA_ONE_SHOT via use_mha, so I kept this PR scoped to the FP8 layout mismatch fix. As I mentioned to Fridge, trtllm kernel paths still bypass this dequant path anyway. Given that, what would you prefer here: keep this scoped fix, or change NSA routing semantics in this PR?

@rainj-me
Collaborator

Fixing the kv cache dequant to handle the 576 kv cache dim may need to be reverted in the future if trtllm supports the sparse kv cache. Based on that, I believe changing the NSA routing should be a simple and sufficient fix. The only difference between MHA_ONE_SHOT and MHA is the forward_batch.mha_one_shot flag.

@mmangkad
Contributor Author

I may be misunderstanding the “dequant to 576” concern, but this change does not hardcode 576; it handles both compact and packed KV layouts using runtime dimensions, and trtllm kernel paths still bypass dequant.

Also, I don’t think switching NSA from MHA_ONE_SHOT to MHA is a simple flag swap here. In NSA, that routing changes the prepare/core path and one-shot prefix-KV fetch behavior (prefix+current indices). So that looks like a broader routing/behavior change, while this PR is scoped to the KV layout mismatch regression. If I’m interpreting your intent incorrectly, I’m happy to align.

@rainj-me
Collaborator

rainj-me commented Feb 17, 2026

I just tested; the following change, instead of the dequant, should work.

diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py
index 6eff360c4..db803c6c1 100644
--- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py
+++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py
@@ -215,7 +215,10 @@ class DeepseekMHAForwardMixin:
             forward_batch.mha_one_shot
             and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            if self.use_nsa and self.kv_cache_dtype == "fp8_e4m3":
+            if self.use_nsa and self.kv_cache_dtype == "fp8_e4m3" and (
+                not get_global_server_args().nsa_decode_backend == 'trtllm'
+                or not get_global_server_args().nsa_prefill_backend == 'trtllm'
+            ):
                 # FP8 path: dequantize NSA-specific FP8 format to BF16
                 kv_a, k_pe = self._get_mla_kv_buffer_from_fp8_for_nsa(forward_batch)
             else:

Testing with radix cache enabled:

# 1st run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.980
Invalid: 0.000
Latency: 14.141 s
Output throughput: 666.624 token/s

# 2nd run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.970
Invalid: 0.000
Latency: 2.770 s
Output throughput: 3334.405 token/s

# 3rd run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.970
Invalid: 0.000
Latency: 2.587 s
Output throughput: 3450.028 token/s

Testing without radix cache enabled:

# 1st run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.980
Invalid: 0.000
Latency: 12.903 s
Output throughput: 695.574 token/s

# 2nd run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.970
Invalid: 0.000
Latency: 10.022 s
Output throughput: 901.345 token/s

# 3rd run
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 100 --parallel 100 --port 30000
Accuracy: 0.980
Invalid: 0.000
Latency: 4.226 s
Output throughput: 2129.826 token/s

@mmangkad
Contributor Author

@rainj-me thanks, nice find - I’ll revert my changes and apply this gate for the both-trtllm FP8 path.

@mmangkad changed the title from "Fix NSA FP8 KV layout mismatch under MHA one-shot" to "Fix NSA FP8 KV cache path for both-trtllm MHA one-shot" on Feb 18, 2026
@rainj-me
Collaborator

/rerun-failed-ci

1 similar comment


Labels

quant (LLM Quantization), run-ci
