[trainer,ckpt,rollout] fix: wake up rollout replicas when actor update is skipped by critic warmup#5590
Conversation
When free_cache_engine=True, PPO sleeps rollout replicas after each rollout. Normally the actor update path calls checkpoint_manager.update_weights(), which also wakes the rollout engine back up for the next step.

However, when the actor update is skipped because trainer.critic_warmup > global_steps, that wake-up never happens. As a result, the next rollout may start with the inference engine still in sleep state and trigger downstream vLLM runtime errors.

This change adds a lightweight checkpoint-manager wake-up path and invokes it at the end of the PPO step only when the actor update is skipped by critic warmup. For colocated and standalone replicas it directly calls replica.wake_up(); for hybrid replicas it falls back to update_weights(global_steps).
Code Review
This pull request correctly fixes a bug where rollout replicas were not woken up when actor updates were skipped during critic warmup. The introduction of a dedicated wake_up_replicas method is a clean solution. The logic to handle hybrid replicas by falling back to a full weight update is sound, and the changes in the trainer correctly invoke this new path. The addition of a test case to cover this scenario is also a great inclusion. I have one suggestion to improve the robustness of the new API.
```python
if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
    await self.update_weights(global_steps=global_steps)
    return
await self._direct_wake_up_replicas()
```
The global_steps parameter is optional (global_steps: int = None), which could lead to issues. If wake_up_replicas is called without global_steps for a manager with hybrid replicas, it will fall back to self.update_weights(global_steps=None). This could cause downstream errors if components like the learning rate scheduler expect an integer step count and do not handle None.
To improve robustness and make the API contract clearer, it's better to enforce that global_steps is provided when hybrid replicas are present.
```diff
 if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
+    if global_steps is None:
+        raise ValueError("`global_steps` must be provided for hybrid replicas.")
     await self.update_weights(global_steps=global_steps)
     return
 await self._direct_wake_up_replicas()
```
What does this PR do?
Fixes a PPO trainer bug on the `critic_warmup` skip-actor path.

When `actor_rollout_ref.rollout.free_cache_engine=True`, rollout replicas are put into sleep state after each rollout to release weights / KV cache. In the normal path, the step later enters actor update and calls `checkpoint_manager.update_weights()`, which also wakes the rollout engine back up.

However, if the actor update is skipped because `trainer.critic_warmup > global_steps`, that `update_weights()` call is skipped as well. The rollout engine then remains asleep into the next step, which can surface as downstream vLLM / Ascend NPU runtime failures instead of a clear logical error.

This PR adds a lightweight rollout wake-up path and calls it only on the `critic_warmup` skip-actor branch.

Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If the PR is breaking, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.

Test
I encountered this bug on Ascend NPU, so I validated the fix on Ascend NPU only. I currently have access only to Ascend NPU hardware and do not have an NVIDIA GPU environment available, so I could not verify this fix on NVIDIA GPU locally.
Before this fix, the trainer crashed after step 1 (logs: link; crash at line 1051).
After this fix, training finishes step 2 (logs: link; step 2 completes at line 914).
API and Usage Example
No API or config change.
Design & Code Changes
- Adds `checkpoint_manager.wake_up_replicas(global_steps=None)` as a lightweight recovery path.
- Colocated and standalone replicas are woken directly via `replica.wake_up()`.
- Hybrid replicas keep the existing `update_weights(global_steps)` semantics via `checkpoint_manager.update_weights(global_steps)`.
- In the trainer, when `trainer.critic_warmup > global_steps` and `free_cache_engine=True`, the lightweight wake-up path is called at the end of the step.

Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`. Not related.