Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/e2e_ascend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ jobs:
ray stop --force
bash tests/special_npu/run_qwen2_5_05b_grpo.sh
rm -rf $HOME/ckpts
- name: Running gsm8k e2e training tests with GRPO on ASCEND NPU (skip actor update via critic warmup)
run: |
ray stop --force
bash tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh
rm -rf $HOME/ckpts
- name: Running gsm8k e2e training tests with GRPO on ASCEND NPU (MindSpeed backend)
run: |
ray stop --force
Expand Down
48 changes: 48 additions & 0 deletions tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# E2E GRPO smoke test on ASCEND NPU: trainer.critic_warmup=10 combined with
# trainer.total_training_steps=2 means the actor update is skipped on every
# step, exercising the rollout wake-up path that does not go through
# update_weights().
set -x

# Model can be overridden from the environment; defaults to a locally cached copy.
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
MODEL_PATH=${MODEL_PATH:-${HOME}/.cache/models/${MODEL_ID}}

# NOTE: the leading '+' on the cudagraph_mode override below is Hydra's
# append syntax (adds a key absent from the default config) — do not remove it.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=16 \
    data.max_prompt_length=512 \
    data.max_response_length=128 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.optim.lr=5e-7 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_torch_compile=False \
    actor_rollout_ref.ref.use_torch_compile=False \
    +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_AND_PIECEWISE" \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=2 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=10 \
    trainer.logger=console \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_5_05b_grpo_critic_warmup_skip_actor' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=-1 \
    trainer.total_epochs=1 \
    trainer.total_training_steps=2 "$@"
27 changes: 22 additions & 5 deletions verl/checkpoint_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.ray_utils import auto_await
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class
from verl.workers.rollout import BaseRollout, get_rollout_class
from verl.workers.rollout.replica import RolloutMode, RolloutReplica


class TensorMeta(TypedDict):
Expand Down Expand Up @@ -395,11 +396,27 @@ async def sleep_replicas(self):
"""Sleep all rollout replicas: free weight and kv_cache device memory."""
await asyncio.gather(*[r.sleep() for r in self.replicas])

@auto_await
async def wake_up_replicas(self):
"""Resume all rollout replicas: recover kv_cache and weights device memory."""
async def _direct_wake_up_replicas(self):
    """Wake every rollout replica's inference engine directly, with no weight sync."""
    wake_tasks = [replica.wake_up() for replica in self.replicas]
    await asyncio.gather(*wake_tasks)

@auto_await
async def wake_up_replicas(self, global_steps: int = None):
    """Resume rollout replicas without forcing a full actor update when possible.

    Colocated and standalone rollout replicas only need their inference engines
    woken, so they are resumed directly. Hybrid rollout servers recover through
    ``update_weights()``, so when any hybrid replica is present we fall back to
    a full weight update.

    Args:
        global_steps: The trainer's global step count. Required when any replica
            runs in hybrid mode, since the fallback weight update forwards it to
            components that expect an integer step (e.g. the LR scheduler).

    Raises:
        ValueError: If hybrid replicas are present and ``global_steps`` is None,
            to avoid silently propagating ``None`` into ``update_weights()``.
    """
    if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
        if global_steps is None:
            raise ValueError("`global_steps` must be provided when hybrid rollout replicas are present.")
        await self.update_weights(global_steps=global_steps)
        return
    await self._direct_wake_up_replicas()
Comment on lines +415 to +418
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The global_steps parameter is optional (global_steps: int = None), which could lead to issues. If wake_up_replicas is called without global_steps for a manager with hybrid replicas, it will fall back to self.update_weights(global_steps=None). This could cause downstream errors if components like the learning rate scheduler expect an integer step count and do not handle None.

To improve robustness and make the API contract clearer, it's better to enforce that global_steps is provided when hybrid replicas are present.

Suggested change
if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
await self.update_weights(global_steps=global_steps)
return
await self._direct_wake_up_replicas()
if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
if global_steps is None:
raise ValueError("`global_steps` must be provided for hybrid replicas.")
await self.update_weights(global_steps=global_steps)
return
await self._direct_wake_up_replicas()


@auto_await
async def update_weights(self, global_steps: int = None):
"""Update weights from trainer to rollout replicas.
Expand Down Expand Up @@ -439,7 +456,7 @@ async def update_weights(self, global_steps: int = None):
)

# 7. resume replicas to recover kv_cache (for free_cache_engine scenarios)
await self.wake_up_replicas()
await self._direct_wake_up_replicas()

# 8. resume all unfinished requests for partial rollout
await asyncio.gather(*[r.resume_generation() for r in self.replicas])
6 changes: 6 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,12 @@ def fit(self):

actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
elif self.config.actor_rollout_ref.rollout.free_cache_engine:
# Rollout replicas were put to sleep after generation. When actor update is
# skipped by critic warmup, there is no trailing update_weights() call to
# wake them back up for the next step.
with marked_timer("wake_up_rollout", timing_raw, color="red"):
self.checkpoint_manager.wake_up_replicas(self.global_steps)

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
Expand Down