[Spec Decoding] Add DFlash model and proposer #1868
aaronzhfeng wants to merge 1 commit into vllm-project:main from
Conversation
This is a large PR. Can we break it down into several small PRs to make review easier?
Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
Sorry about the large PR. The model, proposer, and attention kernel are tightly coupled (the proposer calls the model's forward pass, and the model uses the attention kernel), so splitting them further would leave each PR non-functional on its own. All files here are new additions with no changes to existing code, which should make it easier to review. I broke the original PR down into 3:
PRs 2 and 3 coming shortly.
kyuyeunk left a comment
Hi @aaronzhfeng! Thank you for the contribution. A couple of questions:
- For those who aren't familiar with DFlash (like myself), can you give a brief overview and maybe a link where we can find more info?
- Is my understanding correct that this feature is not available in vLLM's PyTorch model implementation? If so, is there a way for a backend that uses vLLM's model implementation to leverage this spec decoding?
- Can you share a sample command people can use to try out this feature while going through the review process?
Thanks for taking a look!

DFlash overview: DFlash is a block-diffusion speculative decoding method that predicts multiple tokens in parallel using discrete diffusion, instead of generating them one at a time autoregressively. Given a context, the draft model takes a block of masked/noise positions and denoises them in a single forward pass to produce K candidate tokens simultaneously. This makes drafting O(1) in block size rather than O(K). Paper: "DFlash: Block Diffusion for Flash Speculative Decoding" (Chen et al., arXiv:2602.06036). The reference GPU implementation is at https://github.com/z-lab/dflash.

PyTorch/vLLM availability: Right now there is no DFlash support in vLLM's PyTorch backend. The DFlash authors have confirmed vLLM integration is still in progress on their end (see z-lab/dflash#6). SGLang has DFlash support via sgl-project/sglang#16818, but this PR would be the first DFlash integration in the vLLM ecosystem. It targets the JAX/TPU backend specifically, since the draft model uses non-causal attention, which required a different attention path from the standard causal pipeline. A PyTorch port is feasible but not in scope for this PR.

Sample command: The unit tests in this PR can be run without a full serving setup:

```
pytest tests/models/jax/test_qwen3_dflash_attention.py
pytest tests/models/jax/test_qwen3_dflash.py
pytest tests/spec_decode/test_dflash.py
```

End-to-end serving requires the pipeline integration in PR #1869 (already open). Once both are merged, with Qwen3-4B on a TPU v5p-8:

```
python -m tpu_inference.entrypoint \
  --model Qwen/Qwen3-4B \
  --speculative_config '{"model": "z-lab/Qwen3-4B-DFlash-b16", "num_speculative_tokens": 15, "method": "dflash", "draft_tensor_parallel_size": 1}'
```
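To make the O(1)-vs-O(K) drafting claim concrete, here is a toy numpy sketch (a hypothetical stand-in model, not this PR's code) contrasting autoregressive drafting, which needs K forward calls, with block-diffusion drafting, which denoises all K masked slots in a single call:

```python
import numpy as np

rng = np.random.default_rng(0)
D, VOCAB, K = 8, 50, 4           # hidden size, vocab size, speculative block size
W = rng.normal(size=(D, VOCAB))  # toy "draft model": a single linear head

def forward(hidden):             # one model call over all positions
    return hidden @ W            # [T, VOCAB] logits

ctx = rng.normal(size=(3, D))    # embeddings of the accepted context

# Autoregressive drafting: K forward calls, one draft token per call.
ar_calls, h = 0, ctx
for _ in range(K):
    logits = forward(h)
    ar_calls += 1
    tok = int(np.argmax(logits[-1]))             # greedy next token
    h = np.vstack([h, rng.normal(size=(1, D))])  # stand-in for embed(tok)

# Block-diffusion drafting: append K masked slots, denoise them in ONE call.
masked_block = np.zeros((K, D))                   # embeddings of K [MASK] slots
logits = forward(np.vstack([ctx, masked_block]))  # single forward pass
draft_tokens = np.argmax(logits[-K:], axis=-1)    # K candidates at once
dd_calls = 1

print(ar_calls, dd_calls, len(draft_tokens))  # prints "4 1 4"
```

The real draft model is of course a transformer rather than a linear head, but the call-count contrast is the point: drafting cost no longer scales with the number of speculative tokens.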
Can you elaborate on what feature is missing from the existing attention implementation such that it requires its own separate code? If it's due to bi-directional attention, we already have an implementation for that.
```python
@functools.partial(jax.jit, static_argnames=("max_query_len", ))
def dflash_concat_attention(
```
In general, I think this function is lacking a lot of comments explaining what each line does.
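As a reading aid for reviewers, here is a heavily commented single-head numpy sketch of what a "concat attention" of this shape typically computes. The structure follows the description of `dflash_concat_attention` (draft queries attending to cached context plus the draft block itself), but this simplification is hypothetical and is not the jitted TPU kernel under review:

```python
import numpy as np

def concat_attention_sketch(q, ctx_k, ctx_v, blk_k, blk_v):
    """Single-head toy sketch (hypothetical; not the PR's TPU kernel).

    q:            [K, D] queries for the K draft (masked) positions
    ctx_k, ctx_v: [T, D] cached keys/values for the accepted context
    blk_k, blk_v: [K, D] keys/values for the draft block itself
    """
    # Every draft query may see the whole context AND the whole block
    # (non-causal within the block), so K/V are simply concatenated.
    k = np.concatenate([ctx_k, blk_k])       # [T+K, D]
    v = np.concatenate([ctx_v, blk_v])       # [T+K, D]
    # Scaled dot-product scores; note that NO causal mask is applied.
    scores = q @ k.T / np.sqrt(q.shape[-1])  # [K, T+K]
    # Row-wise softmax over all T+K visible positions.
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v                             # [K, D] attended output

rng = np.random.default_rng(0)
T, K, D = 5, 3, 4
out = concat_attention_sketch(
    rng.normal(size=(K, D)),
    rng.normal(size=(T, D)), rng.normal(size=(T, D)),
    rng.normal(size=(K, D)), rng.normal(size=(K, D)))
print(out.shape)  # prints "(3, 4)"
```

Comments of roughly this density in the kernel itself would address the review concern.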
@Lumosis can you help take a look at the spec decoding part?
Description
Add DFlash draft model and proposer for block-diffusion speculative decoding on JAX/TPU. DFlash predicts multiple tokens in parallel using discrete diffusion, unlike Eagle3's autoregressive drafting. This follows the same proposer pattern as Eagle3.
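The Eagle3-style proposer flow mentioned above can be sketched as follows. The method names mirror this PR's `DFlashProposer` (`prepare_inputs`, `propose`), but the bodies are hypothetical stand-ins for illustration, not the real implementation:

```python
import numpy as np

class DFlashProposerSketch:
    """Illustrative skeleton of the proposer flow (hypothetical)."""

    def __init__(self, draft_forward, num_speculative_tokens):
        self.draft_forward = draft_forward  # denoises a masked block
        self.k = num_speculative_tokens

    def prepare_inputs(self, context_ids):
        # Append K mask placeholders (-1 here) after the accepted context.
        return np.concatenate([context_ids, -np.ones(self.k, dtype=int)])

    def propose(self, context_ids):
        ids = self.prepare_inputs(context_ids)
        logits = self.draft_forward(ids)    # ONE pass, [len(ids), VOCAB]
        # Greedy-sample all K draft tokens from the denoised block.
        return np.argmax(logits[-self.k:], axis=-1)

# Toy draft model: vocab of 10, logits favor (position mod 10).
def toy_forward(ids):
    T, V = len(ids), 10
    logits = np.zeros((T, V))
    logits[np.arange(T), np.arange(T) % V] = 1.0
    return logits

proposer = DFlashProposerSketch(toy_forward, num_speculative_tokens=3)
print(proposer.propose(np.array([1, 2, 3])))  # prints "[3 4 5]"
```

The target model then verifies the K drafts in one pass, exactly as with Eagle3; only the drafting step differs.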
This is PR 1 of 3 for DFlash support:
New files:

- tpu_inference/models/jax/dflash.py -- DFlash draft model (DFlashForCausalLM)
- tpu_inference/models/jax/qwen3_dflash.py -- Qwen3-specific DFlash variant with attention
- tpu_inference/layers/common/dflash_attention_interface.py -- dflash_concat_attention kernel
- tpu_inference/spec_decode/jax/dflash.py -- DFlashProposer (prepare_inputs, propose, sampling)
- tests/models/jax/test_qwen3_dflash_attention.py -- DFlash attention unit tests
- tests/models/jax/test_qwen3_dflash.py -- target layer ID selection tests
- tests/spec_decode/test_dflash.py -- proposer sampling tests

Tests

- tests/models/jax/test_qwen3_dflash_attention.py
- tests/models/jax/test_qwen3_dflash.py
- tests/spec_decode/test_dflash.py

Checklist