Skip to content

Commit d696144

Browse files
committed
examples(streaming-rdma): add Qwen3-8B DFlash streaming examples (single + multi-node)
Mirror the EAGLE3 streaming examples for DFlash: same NIXL-RDMA hidden-states transport and dispatch (train_eagle_streaming.sh), but the DFlash recipe and a different capture set. DFlash extracts build_target_layer_ids(36,5)=[1,9,17,25,33] as the draft fc input plus the final layer for self-logit distillation, so vLLM captures [2,10,18,26,34,36] (each target id +1, plus final layer 36). No code change needed: the streaming dataset already emits the same dict shape DFlash's offline path consumes. The streaming corpus is prompt-only (the serve generates the response and we capture its hidden states), so answer_only_loss is false (train over the full sequence, as in the EAGLE3 streaming example), and report_to is set to none since the dflash recipe defaults to tensorboard, which is absent in the serve container. Validated single-node on oci-nrt H100: serve captures 6 layers, trainer RDMA-fetches and DFlash converges (loss 11.0 -> 6.7 in 400 steps) and exports. Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent ed395ab commit d696144

2 files changed

Lines changed: 194 additions & 0 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# DFlash streaming speculative decoding pipeline for Qwen3-8B.
2+
#
3+
# Same streaming transport as the EAGLE3 example (hf_streaming_eagle3.yaml): a live
4+
# `vllm serve` captures the target model's hidden states and moves them to the trainer
5+
# over NIXL RDMA (no disk round-trip). DFlash just consumes a different set of captured
6+
# layers and trains a block-diffusion draft instead of an autoregressive one.
7+
#
8+
# 3-step pipeline:
9+
# task_0: Build input conversations (jsonl)
10+
# task_1: Streaming train — vllm serve + DFlash trainer; hidden states over NIXL RDMA
11+
# task_2: vLLM smoke test with DFlash speculative decoding
12+
#
13+
# task_1 uses nodes=2: node 0 runs vllm serve, node 1 the trainer. Tasks share
14+
# /scratchspace to pass artifacts.
15+
#
16+
# Usage:
17+
# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_streaming_dflash.yaml --yes
18+
19+
job_name: Qwen3-8B_DFlash_streaming
20+
pipeline:
21+
allow_to_fail: false
22+
skip: false
23+
note:
24+
25+
global_vars:
26+
hf_model: /hf-local/Qwen/Qwen3-8B
27+
28+
# Step 1: Build input conversations
29+
task_0:
30+
script: common/eagle3/make_dataset.sh
31+
args:
32+
- -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml
33+
- --full-conversations
34+
slurm_config:
35+
_factory_: "slurm_factory"
36+
nodes: 1
37+
ntasks_per_node: 1
38+
gpus_per_node: 1
39+
container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10
40+
41+
# Step 2: Streaming DFlash training — node 0 vllm serve, node 1 trainer.
42+
# DFlash extracts 5 target layers (build_target_layer_ids(36,5)=[1,9,17,25,33], the
43+
# draft's fc input) plus the final layer for self-logit distillation. vLLM's capture
44+
# ids are those +1 -> [2,10,18,26,34], plus final layer 36.
45+
task_1:
46+
script: common/eagle3/train_eagle_streaming.sh
47+
args:
48+
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
49+
- model.model_name_or_path=<<global_vars.hf_model>>
50+
- data.mode=streaming
51+
- data.data_path=/scratchspace/data/train.jsonl
52+
- training.output_dir=/scratchspace/dflash
53+
- training.training_seq_len=4096
54+
- training.disable_tqdm=true
55+
# Streaming corpus is prompt-only (the serve generates the response and we
56+
# capture its hidden states), so there is no assistant span to mask -> train
57+
# over the full sequence, same as the EAGLE3 streaming example.
58+
- training.answer_only_loss=false
59+
- training.num_train_epochs=1
60+
# dflash.yaml sets report_to=tensorboard, which hard-fails if tensorboard
61+
# isn't in the serve container; the streaming trainer doesn't need it.
62+
- training.report_to=none
63+
- dflash.dflash_block_size=16
64+
- dflash.dflash_num_anchors=512
65+
- dflash.dflash_loss_decay_factor=7
66+
- dflash.dflash_mask_token_id=151669
67+
- dflash.dflash_architecture_config.num_hidden_layers=5
68+
environment:
69+
- HF_MODEL_CKPT: <<global_vars.hf_model>>
70+
# No spaces: nemo_run emits unquoted `export FOO=value`, so spaces would split.
71+
- EAGLE_CAPTURE_IDS: "[2,10,18,26,34,36]"
72+
- SERVE_TP: "1"
73+
# DFlash uses a custom modeling file; export must trust remote code.
74+
- EXPORT_EXTRA_ARGS: "--trust_remote_code"
75+
slurm_config:
76+
_factory_: "slurm_factory"
77+
nodes: 2
78+
ntasks_per_node: 1
79+
gpus_per_node: 1
80+
container: vllm/vllm-openai:latest
81+
82+
# Step 3: vLLM smoke test (DFlash, uses exported checkpoint from training)
83+
task_2:
84+
script: common/specdec/vllm_smoke_test.sh
85+
environment:
86+
- HF_MODEL_CKPT: <<global_vars.hf_model>>
87+
- DRAFT_MODEL: /scratchspace/export
88+
- SPEC_METHOD: "dflash"
89+
- NUM_SPEC_TOKENS: "7"
90+
- MIN_ACCEPTANCE_LENGTH: "1.2"
91+
slurm_config:
92+
_factory_: "slurm_factory"
93+
container: "vllm/vllm-openai:nightly"
94+
nodes: 1
95+
ntasks_per_node: 1
96+
gpus_per_node: 1
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# DFlash streaming speculative decoding pipeline for Qwen3-8B — MULTI-NODE.
2+
#
3+
# Same streaming transport / dispatch as hf_streaming_eagle3_multi_node.yaml: task_1
4+
# splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; hidden
5+
# states move serve -> trainer over NIXL RDMA. DFlash just consumes a different set of
6+
# captured layers and trains a block-diffusion draft instead of an autoregressive one.
7+
# See common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding.
8+
#
9+
# 3-step pipeline:
10+
# task_0: Build input conversations (jsonl)
11+
# task_1: Streaming train — 2 serve nodes (2 GPU, TP=2) + 2 trainer nodes (2 GPU)
12+
# task_2: vLLM smoke test with DFlash speculative decoding
13+
#
14+
# Usage:
15+
# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_streaming_dflash_multi_node.yaml --yes
16+
17+
job_name: Qwen3-8B_DFlash_streaming_multi_node
18+
pipeline:
19+
allow_to_fail: false
20+
skip: false
21+
note:
22+
23+
global_vars:
24+
hf_model: /hf-local/Qwen/Qwen3-8B
25+
26+
# Step 1: Build input conversations
27+
task_0:
28+
script: common/eagle3/make_dataset.sh
29+
args:
30+
- -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml
31+
- --full-conversations
32+
slurm_config:
33+
_factory_: "slurm_factory"
34+
nodes: 1
35+
ntasks_per_node: 1
36+
gpus_per_node: 1
37+
container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10
38+
39+
# Step 2: Streaming DFlash training — 2 serve replicas (TP=2) + 2 trainer nodes (2 GPU each).
40+
# DFlash extracts 5 target layers (build_target_layer_ids(36,5)=[1,9,17,25,33], the
41+
# draft's fc input) plus the final layer for self-logit distillation. vLLM's capture
42+
# ids are those +1 -> [2,10,18,26,34], plus final layer 36.
43+
task_1:
44+
script: common/eagle3/train_eagle_streaming.sh
45+
args:
46+
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
47+
- model.model_name_or_path=<<global_vars.hf_model>>
48+
- data.mode=streaming
49+
- data.data_path=/scratchspace/data/train.jsonl
50+
- training.output_dir=/scratchspace/dflash
51+
- training.training_seq_len=4096
52+
- training.disable_tqdm=true
53+
# Streaming corpus is prompt-only (the serve generates the response and we
54+
# capture its hidden states), so there is no assistant span to mask -> train
55+
# over the full sequence, same as the EAGLE3 streaming example.
56+
- training.answer_only_loss=false
57+
- training.num_train_epochs=1
58+
# dflash.yaml sets report_to=tensorboard, which hard-fails if tensorboard
59+
# isn't in the serve container; the streaming trainer doesn't need it.
60+
- training.report_to=none
61+
- dflash.dflash_block_size=16
62+
- dflash.dflash_num_anchors=512
63+
- dflash.dflash_loss_decay_factor=7
64+
- dflash.dflash_mask_token_id=151669
65+
- dflash.dflash_architecture_config.num_hidden_layers=5
66+
environment:
67+
- HF_MODEL_CKPT: <<global_vars.hf_model>>
68+
# No spaces: nemo_run emits `export FOO=value` unquoted.
69+
- EAGLE_CAPTURE_IDS: "[2,10,18,26,34,36]"
70+
- SERVE_TP: "2"
71+
# K serve replica nodes (Slurm nodes 0..K-1); the rest are trainers.
72+
- SERVE_NODES: "2"
73+
# Per-rank in-flight fetches; keep low so the cold serve isn't flooded past its execute-model timeout.
74+
- STREAMING_NUM_WORKERS: "4"
75+
# DFlash uses a custom modeling file; export must trust remote code.
76+
- EXPORT_EXTRA_ARGS: "--trust_remote_code"
77+
slurm_config:
78+
_factory_: "slurm_factory"
79+
nodes: 4
80+
ntasks_per_node: 1
81+
gpus_per_node: 2
82+
container: vllm/vllm-openai:latest
83+
84+
# Step 3: vLLM smoke test (DFlash, uses exported checkpoint from training)
85+
task_2:
86+
script: common/specdec/vllm_smoke_test.sh
87+
environment:
88+
- HF_MODEL_CKPT: <<global_vars.hf_model>>
89+
- DRAFT_MODEL: /scratchspace/export
90+
- SPEC_METHOD: "dflash"
91+
- NUM_SPEC_TOKENS: "7"
92+
- MIN_ACCEPTANCE_LENGTH: "1.2"
93+
slurm_config:
94+
_factory_: "slurm_factory"
95+
container: "vllm/vllm-openai:nightly"
96+
nodes: 1
97+
ntasks_per_node: 1
98+
gpus_per_node: 1

0 commit comments

Comments
 (0)