diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml
index d4ea77ad143..769356c0a1c 100644
--- a/.github/workflows/e2e_ascend.yml
+++ b/.github/workflows/e2e_ascend.yml
@@ -109,6 +109,11 @@ jobs:
           ray stop --force
           bash tests/special_npu/run_qwen2_5_05b_grpo.sh
           rm -rf $HOME/ckpts
+      - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU (skip actor update via critic warmup)
+        run: |
+          ray stop --force
+          bash tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh
+          rm -rf $HOME/ckpts
       - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU (MindSpeed backend)
         run: |
           ray stop --force
diff --git a/tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh b/tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh
new file mode 100644
index 00000000000..0da5105cba0
--- /dev/null
+++ b/tests/special_npu/run_qwen2_5_05b_grpo_critic_warmup_skip_actor.sh
@@ -0,0 +1,48 @@
+set -x
+
+MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
+MODEL_PATH=${MODEL_PATH:-${HOME}/.cache/models/${MODEL_ID}}
+
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=grpo \
+    data.train_files=$HOME/data/gsm8k/train.parquet \
+    data.val_files=$HOME/data/gsm8k/test.parquet \
+    data.train_batch_size=16 \
+    data.max_prompt_length=512 \
+    data.max_response_length=128 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    actor_rollout_ref.actor.optim.lr=5e-7 \
+    actor_rollout_ref.model.use_remove_padding=False \
+    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.actor.use_torch_compile=False \
+    actor_rollout_ref.ref.use_torch_compile=False \
+    +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_AND_PIECEWISE" \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.rollout.enable_chunked_prefill=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.free_cache_engine=True \
+    actor_rollout_ref.rollout.n=2 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.critic_warmup=10 \
+    trainer.logger=console \
+    trainer.project_name='verl_grpo_example_gsm8k' \
+    trainer.experiment_name='qwen2_5_05b_grpo_critic_warmup_skip_actor' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=-1 \
+    trainer.total_epochs=1 \
+    trainer.total_training_steps=2 $@
diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py
index 6b3a7cd2584..c203a2b3ba1 100644
--- a/verl/checkpoint_engine/base.py
+++ b/verl/checkpoint_engine/base.py
@@ -24,7 +24,8 @@
 from verl.utils.distributed import initialize_global_process_group_ray
 from verl.utils.ray_utils import auto_await
 from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
-from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class
+from verl.workers.rollout import BaseRollout, get_rollout_class
+from verl.workers.rollout.replica import RolloutMode, RolloutReplica
 
 
 class TensorMeta(TypedDict):
@@ -395,11 +396,27 @@ async def sleep_replicas(self):
         """Sleep all rollout replicas: free weight and kv_cache device memory."""
         await asyncio.gather(*[r.sleep() for r in self.replicas])
 
-    @auto_await
-    async def wake_up_replicas(self):
-        """Resume all rollout replicas: recover kv_cache and weights device memory."""
+    async def _direct_wake_up_replicas(self):
+        """Directly wake rollout replicas without any fallback weight sync."""
         await asyncio.gather(*[r.wake_up() for r in self.replicas])
 
+    @auto_await
+    async def wake_up_replicas(self, global_steps: int = None):
+        """Resume rollout replicas without forcing a full actor update when possible.
+
+        For colocated and standalone rollout replicas, we can directly wake the inference
+        engines. Hybrid rollout servers recover through ``update_weights()``, so we fall
+        back to that path when needed.
+
+        Args:
+            global_steps: The trainer's global step count, forwarded when the fallback
+                weight update is required.
+        """
+        if any(replica.rollout_mode == RolloutMode.HYBRID for replica in self.replicas):
+            await self.update_weights(global_steps=global_steps)
+            return
+        await self._direct_wake_up_replicas()
+
     @auto_await
     async def update_weights(self, global_steps: int = None):
         """Update weights from trainer to rollout replicas.
@@ -439,7 +456,7 @@ async def update_weights(self, global_steps: int = None):
         )
 
         # 7. resume replicas to recover kv_cache (for free_cache_engine scenarios)
-        await self.wake_up_replicas()
+        await self._direct_wake_up_replicas()
 
         # 8. resume all unfinished requests for partial rollout
         await asyncio.gather(*[r.resume_generation() for r in self.replicas])
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e178ffc143d..099d724fa95 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -1534,6 +1534,12 @@ def fit(self):
                     actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                     metrics.update(actor_output_metrics)
+                elif self.config.actor_rollout_ref.rollout.free_cache_engine:
+                    # Rollout replicas were put to sleep after generation. When the actor
+                    # update is skipped during critic warmup, there is no trailing
+                    # update_weights() call to wake them back up for the next step.
+                    with marked_timer("wake_up_rollout", timing_raw, color="red"):
+                        self.checkpoint_manager.wake_up_replicas(self.global_steps)
 
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
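
Two notes on the new test script. First, the line beginning with ++actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode is not a typo: the first + is the diff marker, the second is Hydra's append syntax, which adds a key that is absent from the base config (a plain key=value override would fail on a missing key). Second, with trainer.critic_warmup=10 and trainer.total_training_steps=2, the actor update is skipped on every step of the test, which is exactly the code path the new wake-up branch is meant to cover.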
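
To reason about the new wake-up dispatch in verl/checkpoint_engine/base.py in isolation, here is a minimal, self-contained Python sketch. The stub classes and the main() driver are hypothetical stand-ins, not verl code; only the any(... == RolloutMode.HYBRID ...) dispatch mirrors the patch, and the non-HYBRID enum members are assumed from the docstring's mention of colocated and standalone modes.

import asyncio
from enum import Enum


class RolloutMode(Enum):
    # HYBRID is taken from the patch; the other member names are assumptions.
    COLOCATED = "colocated"
    STANDALONE = "standalone"
    HYBRID = "hybrid"


class StubReplica:
    """Hypothetical stand-in exposing just what the dispatch needs."""

    def __init__(self, mode: RolloutMode):
        self.rollout_mode = mode

    async def wake_up(self):
        print(f"{self.rollout_mode.value}: engine woken directly")


class StubCheckpointManager:
    def __init__(self, replicas):
        self.replicas = replicas

    async def update_weights(self, global_steps=None):
        # Hybrid servers recover weights and kv_cache through the full
        # weight-sync path, so waking them is folded into this call.
        print(f"fallback: full weight sync at step {global_steps}")

    async def _direct_wake_up_replicas(self):
        await asyncio.gather(*[r.wake_up() for r in self.replicas])

    async def wake_up_replicas(self, global_steps=None):
        # A single hybrid replica forces the fallback for the whole group,
        # mirroring the any(...) check in the patch.
        if any(r.rollout_mode == RolloutMode.HYBRID for r in self.replicas):
            await self.update_weights(global_steps=global_steps)
            return
        await self._direct_wake_up_replicas()


async def main():
    colocated = StubCheckpointManager([StubReplica(RolloutMode.COLOCATED) for _ in range(2)])
    await colocated.wake_up_replicas(global_steps=1)  # direct wake-up

    mixed = StubCheckpointManager(
        [StubReplica(RolloutMode.COLOCATED), StubReplica(RolloutMode.HYBRID)]
    )
    await mixed.wake_up_replicas(global_steps=1)  # falls back to update_weights


asyncio.run(main())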
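
And a compressed view of why the ray_trainer.py branch exists: with free_cache_engine=True, replicas sleep right after generation and are normally woken as a side effect of the trailing update_weights() inside the actor update; when critic warmup skips that update, nothing else would wake them. The trace below is a hypothetical sketch, not the real fit() loop; fit_step, StubManager, and the print statements are invented for illustration.

class StubManager:
    """Hypothetical stand-in for the checkpoint manager's sleep/wake API."""

    def sleep_replicas(self):
        print("  replicas asleep (weights + kv_cache freed)")

    def update_weights(self, global_steps):
        print(f"  weights synced at step {global_steps}; replicas woken as a side effect")

    def wake_up_replicas(self, global_steps):
        print(f"  replicas explicitly woken at step {global_steps}")


def fit_step(global_steps, critic_warmup, free_cache_engine, mgr):
    print(f"step {global_steps}: generate_sequences()")
    if free_cache_engine:
        mgr.sleep_replicas()  # replicas sleep right after generation

    if critic_warmup <= global_steps:
        print("  update_actor()")
        mgr.update_weights(global_steps)  # trailing sync also wakes replicas
    elif free_cache_engine:
        # Actor update skipped by critic warmup: without this branch the
        # replicas would still be asleep when the next step generates.
        mgr.wake_up_replicas(global_steps)


mgr = StubManager()
fit_step(global_steps=1, critic_warmup=10, free_cache_engine=True, mgr=mgr)   # warmup: explicit wake
fit_step(global_steps=12, critic_warmup=10, free_cache_engine=True, mgr=mgr)  # normal: wake via weight sync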