[FEAT][SpecDecode] Add DP attention support for DFLASH speculative decoding #29506
EanWang211123 wants to merge 2 commits
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
Code Review
This pull request adds support for DP attention in DFLASH speculative decoding. It automatically enables enable_dp_lm_head when DP attention is enabled, runs draft initialization and operations within the attention TP group context, and handles IDLE/DECODE ranks during prefill and verification to prevent deadlocks. Additionally, it handles right-padded hidden states when using MoE Expert Parallelism. The reviewer identified a critical runtime bug where self.draft_tp_context is assigned to empty_context when DP attention is disabled, which will raise a TypeError because empty_context does not accept the positional argument passed to it in multiple places. The reviewer suggested using a lambda wrapper to safely ignore the argument.
    self.draft_tp_context = (
        draft_tp_context if server_args.enable_dp_attention else empty_context
    )
When enable_dp_attention is False, self.draft_tp_context is assigned to empty_context. However, empty_context is a 0-argument context manager, whereas self.draft_tp_context is called with a positional argument (e.g., self.draft_tp_context(self.draft_model_runner.tp_group)) in multiple places (lines 314, 328, 1562). This will raise a TypeError at runtime and crash the server on startup when DP attention is disabled. Using a lambda wrapper like lambda _: empty_context() ensures that the single positional argument is safely ignored.
Suggested change:
    - self.draft_tp_context = (
    -     draft_tp_context if server_args.enable_dp_attention else empty_context
    - )
    + self.draft_tp_context = (
    +     draft_tp_context if server_args.enable_dp_attention else lambda _: empty_context()
    + )
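The failure mode the reviewer describes can be reproduced in isolation. The sketch below is a minimal illustration, not the project's code: `empty_context` and `draft_tp_context` here are stand-ins written for this example, assuming (as the review states) that the real `empty_context` is a zero-argument context manager and that the worker always calls the stored factory with one positional argument. `pick_context` and `enter_with_group` are hypothetical helper names.

```python
from contextlib import contextmanager


@contextmanager
def empty_context():
    """Stand-in (assumption): a no-op context manager that takes no arguments."""
    yield


@contextmanager
def draft_tp_context(tp_group):
    """Stand-in (assumption): a context manager that takes a TP group."""
    yield tp_group


def pick_context(enable_dp_attention: bool, wrap: bool):
    """Choose the context factory the way the reviewed snippet does.

    With wrap=True the fallback is the reviewer's lambda, which accepts
    and ignores the single positional argument.
    """
    if enable_dp_attention:
        return draft_tp_context
    return (lambda _: empty_context()) if wrap else empty_context


def enter_with_group(context_factory, tp_group) -> bool:
    """Call the factory with one positional argument, as the worker does."""
    with context_factory(tp_group):
        return True


if __name__ == "__main__":
    # Unwrapped fallback: calling empty_context("group") raises TypeError.
    try:
        enter_with_group(pick_context(False, wrap=False), "group")
    except TypeError as exc:
        print("unwrapped fallback failed:", exc)

    # Wrapped fallback: the lambda swallows the argument, so entry succeeds.
    print("wrapped fallback ok:", enter_with_group(pick_context(False, wrap=True), "group"))
```

The lambda keeps every call site unchanged. An alternative with the same effect would be a small named no-op context manager that accepts and ignores arbitrary arguments, which tends to be easier to read in tracebacks.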
Motivation
DFLASH speculative decoding previously rejected --enable-dp-attention at startup. This blocks deployments that combine DFLASH with data-parallel attention (e.g. --tensor-parallel-size 4 --dp-size 2 --enable-dp-attention), which is a common setup for large MoE models like GLM-5.
EAGLE3 already supports DP attention by running each draft worker inside the attention TP group (attn_tp_group). DFLASH should follow the same pattern, but it has additional constraints: it materializes target hidden states directly into the draft KV cache (instead of re-running a draft extend forward), and it performs draft greedy sampling over the target lm_head. These paths need explicit alignment with DP/EP padding, CUDA graph capture modes, and full-TP target verify collectives.
Modifications
speculative_hook.py
- Stop rejecting enable_dp_attention at startup.
- Automatically enable enable_dp_lm_head when DFLASH runs with DP attention, so draft greedy sampling's vocab-parallel all_gather stays within the attention TP group (matching lm_head sharding). Without this, a global-TP all_gather mixes tokens across DP groups and deadlocks when a peer DP group is IDLE.
dflash_worker_v2.py
Draft worker initialization (mirrors EAGLE3 + dp_attention)
- […] dp_attention on the draft server args (draft is dense; keeps KV row count aligned with out_cache_loc).
- Run draft initialization inside draft_tp_context(get_attention_tp_group()) so KV head partitioning matches token_to_kv_pool.row_dim (both use attn_tp_size).
- Run init_attention_backends, init_cuda_graphs, and runtime draft forward + greedy sampling in draft_tp_context.
Prefill / extend path
- Handle the case where is_extend_in_batch=True is broadcast globally but the local rank is IDLE/DECODE (avoids missing extend_lens / prefix_lens).
- Handle right-padded hidden_states before writing into draft KV when moe_ep_size > 1 (fixes cache_loc vs target_hidden length mismatch on non-aligned token counts).
Decode path
- […] capture_hidden_mode=FULL to match active ranks. A NULL mode mismatch triggers a CUDA graph recapture whose internal barrier active ranks never enter, causing deadlock.
- Run _greedy_sample_from_vocab_parallel_head inside draft_tp_context so its all_gather uses the attention TP group.
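To illustrate why the all_gather for draft greedy sampling has to be scoped to the attention TP group, here is a toy, single-process sketch of vocab-parallel greedy sampling. It is not the PR's _greedy_sample_from_vocab_parallel_head: the function names (`local_best`, `all_gather`, `greedy_sample`) are hypothetical, ranks are simulated as plain Python lists, and the collective is replaced by a list copy. Each simulated rank holds one shard of the vocabulary logits for its own DP group's token.

```python
from typing import List, Sequence, Tuple


def local_best(shard_logits: Sequence[float], vocab_offset: int) -> Tuple[float, int]:
    """Return (max logit, global token id) for one rank's vocab shard."""
    local_idx = max(range(len(shard_logits)), key=lambda i: shard_logits[i])
    return shard_logits[local_idx], vocab_offset + local_idx


def all_gather(values_per_rank: List[Tuple[float, int]]) -> List[Tuple[float, int]]:
    """Toy all_gather: every rank in the group sees every rank's value."""
    return list(values_per_rank)


def greedy_sample(group_shards: List[Sequence[float]]) -> int:
    """Pick the greedy token over a vocab sharded across one group's ranks."""
    candidates = []
    offset = 0
    for shard in group_shards:
        candidates.append(local_best(shard, offset))
        offset += len(shard)
    gathered = all_gather(candidates)
    # Highest logit wins; ties go to the lowest token id.
    _, best_token = max(gathered, key=lambda c: (c[0], -c[1]))
    return best_token


if __name__ == "__main__":
    # Two DP groups, each with two attention-TP ranks holding half the vocab.
    dp_group_a = [[0.1, 2.0], [0.5, 0.3]]
    dp_group_b = [[0.0, 0.1], [3.0, 0.2]]
    print(greedy_sample(dp_group_a))  # token for group A's request
    print(greedy_sample(dp_group_b))  # token for group B's request
```

Each DP group works on different tokens, so the gather must only combine shards that belong to the same group. In a real multi-process run, a gather over the global TP group would also require every rank to participate, which is the situation the PR text describes as deadlocking when a peer DP group is IDLE.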
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
CI States
Latest PR Test (Base): ❌ Run #28285509303
Latest PR Test (Extra): ❌ Run #28285509277