
Commit cb23607

Authored by litianjian and zhangbiao.168
[algo] feat: support router replay (#4101)
### What does this PR do?

This PR introduces draft **Router Replay** support in verl. Inspired by recent research in **MoE Reinforcement Learning** ([2510.11370](https://arxiv.org/abs/2510.11370), [2507.18071](https://arxiv.org/abs/2507.18071)), this implementation supports **Router Replay (R2)** and **Rollout Router Replay (R3)**. R2 records routing token selections during `log probability computation` and replays the expert selections during the policy update. R3 records during `model inference` and replays during RL post-training.

The initial version supports **Router Replay** with the `Megatron` backend, including comprehensive support for distributed training strategies (**DP, TP, EP, ETP, PP, and re-compute**). The current implementation uses a patch-based approach. Once the upstream PR [NVIDIA/Megatron-LM#2101](NVIDIA/Megatron-LM#2101) is merged or provides corresponding interfaces, the patch can be removed and replaced with official API integration.

## Usage Tutorial

### Basic Configuration

To enable Router Replay functionality, add the following configuration to your trainer config.

#### Method 1: Trainer Configuration

Add the following to your trainer config:

```yaml
router_replay:
  enabled: true
  mode: "R2" # Options: "R2", "R3"
```

#### Method 2: Launch Script Configuration

Add the following parameter to your launch script:

```bash
# In your launch script
actor_rollout_ref.actor.router_replay.mode="R2"
```

### R2 Mode Usage

1. **Enable R2 mode** in the configuration
2. **Record phase**: during log probability computation, routing selections are automatically recorded
3. **Replay phase**: during the policy update, the recorded expert selections are replayed

### R3 Mode Usage

1. **Enable R3 mode** in the configuration
2. **Record phase**: during model inference, routing decisions are captured
3. **Replay phase**: during RL post-training, the recorded routing data is used

## In Progress

R2
- [ ] FSDP backend

R3
- [x] vLLM Rollout
- [ ] SGLang Rollout

---------

Co-authored-by: litianjian <[email protected]>
Co-authored-by: zhangbiao.168 <[email protected]>
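To make the record/replay idea concrete, below is a minimal, self-contained sketch of an R2-style flow: a top-k MoE router records its expert selections on one pass (analogous to the log-probability pass) and reuses them on a later pass (analogous to the policy update). The class names `RouterReplayBuffer` and `ToyRouter` and the buffer layout are illustrative assumptions, not verl's implementation or the Megatron patch added by this PR.

```python
# Toy illustration of the R2 record/replay idea; names and shapes are assumptions, not verl APIs.
import torch


class RouterReplayBuffer:
    """Stores one [num_tokens, topk] tensor of expert indices per recorded router call."""

    def __init__(self):
        self.records: list[torch.Tensor] = []
        self.cursor = 0

    def record(self, topk_idx: torch.Tensor) -> None:
        self.records.append(topk_idx.detach().cpu())

    def replay(self) -> torch.Tensor:
        idx = self.records[self.cursor]
        self.cursor += 1
        return idx


class ToyRouter(torch.nn.Module):
    """A bare top-k gate that can record or replay its expert selection."""

    def __init__(self, hidden: int, num_experts: int, topk: int, buffer: RouterReplayBuffer):
        super().__init__()
        self.gate = torch.nn.Linear(hidden, num_experts, bias=False)
        self.topk = topk
        self.buffer = buffer

    def forward(self, x: torch.Tensor, mode: str = "disabled"):
        logits = self.gate(x)  # [num_tokens, num_experts]
        if mode == "replay":
            # Reuse the previously recorded expert selection instead of recomputing top-k.
            topk_idx = self.buffer.replay().to(x.device)
        else:
            topk_idx = logits.topk(self.topk, dim=-1).indices
            if mode == "record":
                self.buffer.record(topk_idx)
        # Gate weights still come from the current logits, so gradients flow through the gate.
        topk_weight = torch.softmax(logits.gather(-1, topk_idx), dim=-1)
        return topk_idx, topk_weight


if __name__ == "__main__":
    buf = RouterReplayBuffer()
    router = ToyRouter(hidden=16, num_experts=8, topk=2, buffer=buf)
    tokens = torch.randn(4, 16)
    idx_rec, _ = router(tokens, mode="record")  # e.g. the log-probability pass
    idx_rep, _ = router(tokens, mode="replay")  # e.g. the policy-update pass
    assert torch.equal(idx_rec.cpu(), idx_rep.cpu())
```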

File tree

20 files changed: +1295, -9 lines


examples/router_replay/README.md

Lines changed: 72 additions & 0 deletions
# Router Replay

Router Replay is an advanced routing replay functionality within the Verl framework designed for Mixture of Experts (MoE) models. It enables deterministic training by recording and replaying routing decisions, ensuring consistent model behavior across training runs.

## Key Features

### Multiple Operating Modes
- **`disabled`**: Router replay functionality is completely disabled
- **`R2`**: Standard router replay mode for recording and replaying routing decisions
- **`R3`**: Rollout-specific router replay mode optimized for reinforcement learning workflows

### Core Capabilities
- **Seamless Integration**: Works with reinforcement learning pipelines including PPO
- **Distributed Training Support**: Compatible with multi-GPU and multi-node training environments
- **Flexible Configuration**: Easy to configure via YAML files or command-line parameters

## Configuration

### RouterReplayConfig Parameters

```yaml
router_replay:
  mode: "disabled"   # Available options: disabled, R2, R3
  record_file: null  # Path for recording routing decisions
  replay_file: null  # Path for replaying recorded decisions
```

## Quick Start Guide

### Enabling R2 Mode

#### Configuration File Method
Add the following to your training configuration:

```yaml
actor:
  router_replay:
    mode: "R2"
```

#### Command Line Method
Enable R2 mode via command-line parameters:

```bash
actor_rollout_ref.actor.router_replay.mode="R2"
actor_rollout_ref.rollout.enable_rollout_routing_replay=True
```

### Enabling R3 Mode

#### Configuration File Method
Configure both actor and rollout settings:

```yaml
# Actor configuration
router_replay:
  mode: "R3"

# Rollout configuration
enable_rollout_routing_replay: True
```

#### Command Line Method
Enable R3 mode via command-line parameters:

```bash
actor_rollout_ref.actor.router_replay.mode="R3"
actor_rollout_ref.rollout.enable_rollout_routing_replay=True
```

R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284.
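The `record_file` and `replay_file` fields above imply that recorded routing decisions can be persisted between runs, but this commit does not show the on-disk format. Purely as a hypothetical sketch, persistence could look like the following, assuming the recorded decisions are an integer tensor of top-k expert indices:

```python
# Hypothetical persistence helpers; the actual format behind RouterReplayConfig.record_file
# and replay_file is not specified in this commit.
from pathlib import Path

import torch


def save_routing_decisions(path: str, routed_experts: torch.Tensor) -> None:
    # routed_experts: [batch, seq_len, num_layers, topk] integer expert indices (assumed layout)
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(routed_experts.cpu(), path)


def load_routing_decisions(path: str, device: str = "cpu") -> torch.Tensor:
    return torch.load(path, map_location=device)


if __name__ == "__main__":
    experts = torch.randint(0, 8, (2, 128, 4, 2))  # toy [batch, seq, layers, topk]
    save_routing_decisions("/tmp/router_replay/record.pt", experts)
    assert torch.equal(experts, load_routing_decisions("/tmp/router_replay/record.pt"))
```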
Lines changed: 108 additions & 0 deletions
set -x

NODES=1

# R2: enable routing replay
# R3: enable rollout routing replay
# If enabling R3, please set actor_rollout_ref.rollout.enable_rollout_routing_replay=True
# R3 example is based on vllm related pr https://github.com/vllm-project/vllm/pull/5322

ROUTING_REPLAY_MODE="R2"

DIST_CKPT_PATH=""
HF_MODEL_PATH=""
TRAIN_DATA_PATH=""
TEST_DATA_PATH=""

export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
PP=1
VPP=None
TP=2
EP=8
ETP=1
VLLM_INFER_TP=2
offload=True
gpu_memory_utilization=0.65
bs=8
micro_bs=3
use_dynamic_bsz=True
max_prompt_length=1024
max_response_length=1024
ppo_mini_batch_size=8
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))

exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${VLLM_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    algorithm.adv_estimator=grpo \
    data.train_files=$TRAIN_DATA_PATH \
    data.val_files=$TEST_DATA_PATH \
    data.train_batch_size=$bs \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.use_fused_kernels=True \
    actor_rollout_ref.model.path=$HF_MODEL_PATH \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE} \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
    actor_rollout_ref.actor.megatron.param_offload=${offload} \
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
    actor_rollout_ref.actor.megatron.grad_offload=${offload} \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_INFER_TP \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.actor.megatron.use_mbridge=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.ref.megatron.param_offload=${offload} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_grpo_example_gsm8k_math' \
    trainer.experiment_name="$exper_name" \
    trainer.nnodes=$NODES \
    trainer.n_gpus_per_node=8 \
    trainer.save_freq=-1 \
    trainer.test_freq=10 \
    trainer.total_training_steps=50000 \
    trainer.balance_batch=False \
    trainer.val_before_train=False 2>&1
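As a sanity check on the parallel sizes in this script, the standard Megatron-LM relationships (assuming context parallel size 1) give the following decomposition of the single 8-GPU node; this is back-of-the-envelope arithmetic, not output produced by the script:

```python
# How the parallel sizes above map onto trainer.nnodes=1 x trainer.n_gpus_per_node=8,
# using the usual Megatron-LM relationships and assuming context parallel size 1.
world_size = 1 * 8  # NODES * n_gpus_per_node
TP, PP, EP, ETP = 2, 1, 8, 1

dense_dp = world_size // (TP * PP)         # data-parallel replicas of the dense (non-expert) params -> 4
expert_dp = world_size // (EP * ETP * PP)  # data-parallel replicas of the expert params -> 1

print(f"dense DP = {dense_dp}, expert DP = {expert_dp}")
```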

verl/experimental/agent_loop/agent_loop.py

Lines changed: 26 additions & 0 deletions
@@ -134,6 +134,8 @@ class AgentLoopOutput(BaseModel):
    """Response mask, 1 for LLM generated token, 0 for tool response token."""
    response_logprobs: Optional[list[float]] = None
    """Log probabilities for the response tokens."""
+   routed_experts: Optional[Any] = None
+   """Routed experts for the total tokens."""
    multi_modal_data: Optional[dict[str, Any]] = None
    """Multi-modal data for multi-modal tools."""
    reward_score: Optional[float] = None
@@ -165,6 +167,8 @@ class _InternalAgentLoopOutput(AgentLoopOutput):
    """Padded attention mask."""
    response_logprobs: Optional[torch.Tensor] = None
    """Padded log probabilities for the response tokens."""
+   routed_experts: Optional[torch.Tensor] = None
+   """Padded routed experts for the total tokens."""
    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None
    """Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw)."""
    extra_fields: dict[str, Any] = {}
@@ -487,6 +491,25 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
        attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
        input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)

+       routed_experts = None
+       if output.routed_experts is not None:
+           total_length = input_ids.shape[1]
+           length, layer_num, topk_num = output.routed_experts.shape
+           experts_tensor = torch.from_numpy(output.routed_experts)
+           routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype)
+
+           # Calculate start position: left padding means original prompt starts at the end
+           start_pos = prompt_output["input_ids"].shape[1] - len(output.prompt_ids)
+           end_pos = min(start_pos + length, total_length)
+
+           # Add boundary checks for robustness
+           if start_pos < 0 or end_pos > total_length:
+               raise ValueError(
+                   f"Invalid position range: start_pos={start_pos}, end_pos={end_pos}, total_length={total_length}"
+               )
+
+           routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0)
+
        # Handle multi-modal inputs and position_ids calculation
        # Only support Qwen2VLImageProcessor for multi-modal processing currently
        # TODO: support other multi-modal inputs
@@ -560,6 +583,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
            response_mask=response_mask,
            attention_mask=attention_mask,
            response_logprobs=response_logprobs,
+           routed_experts=routed_experts,
            multi_modal_inputs=multi_modal_inputs,
            multi_modal_data=output.multi_modal_data,
            reward_score=output.reward_score,
@@ -580,6 +604,8 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
        optional_outputs = {}
        if inputs[0].response_logprobs is not None:
            optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)
+       if inputs[0].routed_experts is not None:
+           optional_outputs["routed_experts"] = torch.cat([input.routed_experts for input in inputs], dim=0)

        batch = TensorDict(
            {
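For reference, the placement logic in the hunk above can be reproduced in isolation: with left-padded prompts, the recorded per-token expert indices are written into a zero-initialized tensor starting at the offset where the real prompt begins. The shapes and values below are toy assumptions, not verl data structures:

```python
# Standalone reproduction of the left-padding placement for routed experts (toy shapes).
import numpy as np
import torch

padded_prompt_len, real_prompt_len, response_len = 6, 4, 3
total_length = padded_prompt_len + response_len
layer_num, topk_num = 2, 2

# Recorded routing covers only the real prompt + response tokens.
recorded = np.random.randint(0, 8, size=(real_prompt_len + response_len, layer_num, topk_num))

experts_tensor = torch.from_numpy(recorded)
routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype)

# Left padding: the real prompt starts at the end of the padded prompt region.
start_pos = padded_prompt_len - real_prompt_len
end_pos = min(start_pos + recorded.shape[0], total_length)
routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0)

# Tokens in the padding region keep all-zero expert indices.
assert routed_experts[:, :start_pos].eq(0).all()
```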

verl/experimental/agent_loop/single_turn_agent_loop.py

Lines changed: 5 additions & 0 deletions
@@ -73,6 +73,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
            response_ids=output.token_ids[: self.response_length],
            response_mask=response_mask[: self.response_length],
            response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None,
+           routed_experts=(
+               output.routed_experts[: len(prompt_ids) + self.response_length]
+               if output.routed_experts is not None
+               else None
+           ),
            multi_modal_data={"image": image_data} if image_data is not None else {},
            num_turns=2,
            metrics=metrics,

verl/experimental/agent_loop/tool_agent_loop.py

Lines changed: 3 additions & 0 deletions
@@ -233,6 +233,9 @@ async def _handle_generating_state(
        if output.log_probs:
            agent_data.response_logprobs += output.log_probs

+       if output.routed_experts is not None:
+           agent_data.routed_experts = output.routed_experts
+
        # Check termination conditions
        if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
            return AgentState.TERMINATED

verl/models/mcore/util.py

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ def postprocess_packed_seqs(
    if cp_size > 1:
        # output shape: [1, packed_len, hidden_dim]
        # need to gather across cp group and concatenate in sequence dimension
-       output_list = [torch.empty_like(output) for _ in range(cp_size)]
+       output_list = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
@@ -159,7 +159,7 @@ def postprocess_packed_seqs(
        half_seqlen = s_len_padded_chunk // 2
        s_len = seq_lens_cpu[i]
        s_len_padded = s_len_padded_chunk * cp_size
-       tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
+       tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype)
        for j in range(cp_size):
            o = output_list[j][0]
            # split to 2 chunks
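The second hunk matters because `torch.empty` allocates in the global default dtype (typically `float32`), while `torch.empty_like` already inherits the source tensor's dtype; without the explicit `dtype=output.dtype`, a bf16 `output` would be copied into a float32 buffer. A quick standalone illustration (not verl code):

```python
import torch

output = torch.randn(4, 8, dtype=torch.bfloat16)

buf_like = torch.empty_like(output)                        # inherits bfloat16
buf_plain = torch.empty(output.shape)                      # falls back to the default dtype (float32)
buf_typed = torch.empty(output.shape, dtype=output.dtype)  # explicitly bfloat16

print(buf_like.dtype, buf_plain.dtype, buf_typed.dtype)    # torch.bfloat16 torch.float32 torch.bfloat16
```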

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 11 additions & 0 deletions
@@ -122,6 +122,11 @@ actor_rollout_ref:
        _target_: verl.utils.profiler.config.TorchMemoryToolConfig
        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
        stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+    router_replay:
+      _target_: verl.workers.config.RouterReplayConfig
+      mode: disabled
+      record_file: null
+      replay_file: null
    data_loader_seed: 42
    load_weight: true
  ref:
@@ -157,6 +162,11 @@ actor_rollout_ref:
        _target_: verl.utils.profiler.config.TorchMemoryToolConfig
        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
        stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+    router_replay:
+      _target_: verl.workers.config.RouterReplayConfig
+      mode: disabled
+      record_file: null
+      replay_file: null
    megatron:
      _target_: verl.workers.config.McoreEngineConfig
      param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False}
@@ -261,6 +271,7 @@ actor_rollout_ref:
    skip_rollout: false
    skip_dump_dir: /tmp/rollout_dump
    skip_tokenizer_init: true
+    enable_rollout_routing_replay: false
  profiler:
    _target_: verl.utils.profiler.ProfilerConfig
    tool: ${oc.select:global_profiler.tool,null}

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 11 additions & 0 deletions
@@ -109,6 +109,11 @@ actor_rollout_ref:
        _target_: verl.utils.profiler.config.TorchMemoryToolConfig
        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
        stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+    router_replay:
+      _target_: verl.workers.config.RouterReplayConfig
+      mode: disabled
+      record_file: null
+      replay_file: null
    grad_clip: 1.0
    ulysses_sequence_parallel_size: 1
    entropy_from_logits_with_chunking: false
@@ -147,6 +152,11 @@ actor_rollout_ref:
        _target_: verl.utils.profiler.config.TorchMemoryToolConfig
        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
        stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+    router_replay:
+      _target_: verl.workers.config.RouterReplayConfig
+      mode: disabled
+      record_file: null
+      replay_file: null
    fsdp_config:
      _target_: verl.workers.config.FSDPEngineConfig
      wrap_policy:
@@ -249,6 +259,7 @@ actor_rollout_ref:
    skip_rollout: false
    skip_dump_dir: /tmp/rollout_dump
    skip_tokenizer_init: true
+    enable_rollout_routing_replay: false
  profiler:
    _target_: verl.utils.profiler.ProfilerConfig
    tool: ${oc.select:global_profiler.tool,null}
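These generated defaults are composed by Hydra/OmegaConf at launch time, and a dotted command-line override simply replaces the corresponding leaf value. A standalone OmegaConf illustration of that merge, independent of verl's actual trainer entrypoint:

```python
from omegaconf import OmegaConf

# Mirror of the relevant slice of the generated defaults shown above.
base = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "actor": {"router_replay": {"mode": "disabled", "record_file": None, "replay_file": None}},
            "rollout": {"enable_rollout_routing_replay": False},
        }
    }
)

# Same effect as passing the dotted overrides on the command line.
override = OmegaConf.from_dotlist(
    [
        "actor_rollout_ref.actor.router_replay.mode=R2",
        "actor_rollout_ref.rollout.enable_rollout_routing_replay=True",
    ]
)

cfg = OmegaConf.merge(base, override)
print(cfg.actor_rollout_ref.actor.router_replay.mode)               # R2
print(cfg.actor_rollout_ref.rollout.enable_rollout_routing_replay)  # True
```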

verl/trainer/config/actor/actor.yaml

Lines changed: 20 additions & 0 deletions
@@ -220,3 +220,23 @@ profiler:

  # Stack trace depth for memory allocations
  stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+
+# Router replay configuration for MoE models
+router_replay:
+
+  # Target dataclass for this configuration
+  _target_: verl.workers.config.RouterReplayConfig
+
+  # Router replay mode: disabled, R2, R3
+  # - R2: Use R2 routing strategy (record mode)
+  # - R3: Use R3 routing strategy (record mode)
+  mode: disabled
+
+  # File path to save recorded routing decisions
+  # Required when mode is 'record', 'R2', or 'R3'
+  record_file: null
+
+  # File path to load recorded routing decisions for replay
+  # Required when mode is 'replay'
+  replay_file: null
+
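The `_target_` above points at `verl.workers.config.RouterReplayConfig`, whose definition is not included in this diff. Purely as an assumption, a dataclass consistent with these three YAML fields might look like the sketch below; the real class may differ:

```python
# Hypothetical reconstruction of the config dataclass; the real definition lives in
# verl.workers.config and may differ (field names here only mirror the YAML keys).
from dataclasses import dataclass
from typing import Optional


@dataclass
class RouterReplayConfig:
    mode: str = "disabled"             # one of: "disabled", "R2", "R3"
    record_file: Optional[str] = None  # where recorded routing decisions are written
    replay_file: Optional[str] = None  # where recorded routing decisions are read from

    @property
    def enabled(self) -> bool:
        return self.mode in ("R2", "R3")
```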

verl/trainer/config/ref/ref.yaml

Lines changed: 19 additions & 0 deletions
@@ -100,3 +100,22 @@ profiler:

  # Stack trace depth for memory allocations
  stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+
+# Router replay configuration for MoE models
+router_replay:
+
+  # Target dataclass for this configuration
+  _target_: verl.workers.config.RouterReplayConfig
+
+  # Router replay mode: disabled, R2, R3
+  # - R2: Use R2 routing strategy (record mode)
+  # - R3: Use R3 routing strategy (record mode)
+  mode: disabled
+
+  # File path to save recorded routing decisions
+  # Required when mode is 'record', 'R2', or 'R3'
+  record_file: null
+
+  # File path to load recorded routing decisions for replay
+  # Required when mode is 'replay'
+  replay_file: null
