Skip to content

Commit b620922

Browse files
author
tangmengcheng
committed
profiler bug fix for agent loop
1 parent c12b0cf commit b620922

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

verl/experimental/agent_loop/agent_loop.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,14 @@ def sleep(self):
837837
"""Sleep all rollout replica instances."""
838838
self._run_all([replica.sleep() for replica in self.rollout_replicas])
839839

840+
def start_profile(self, **kwargs):
841+
"""Start profiling on all replicas."""
842+
# Calls replica.start_profile(**kwargs) once per entry in self.rollout_replicas,
# forwarding kwargs unchanged, and hands the resulting list to self._run_all in one call.
self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
843+
844+
def stop_profile(self):
845+
"""Stop profiling on all replicas."""
846+
# Calls replica.stop_profile() once per entry in self.rollout_replicas (no arguments)
# and hands the resulting list to self._run_all in one call.
self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
847+
840848
def _run_all(self, tasks: list[asyncio.Task]):
841849
async def run_all():
842850
await asyncio.gather(*tasks)

verl/trainer/ppo/ray_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,8 @@ def _start_profiling(self, do_profile: bool) -> None:
904904
self.critic_wg.start_profile(profile_step=self.global_steps)
905905
if self.use_rm:
906906
self.rm_wg.start_profile(profile_step=self.global_steps)
907+
if hasattr(self, "async_rollout_manager") and self.async_rollout_manager is not None:
908+
self.async_rollout_manager.start_profile(role="rollout", profile_step=self.global_steps)
907909

908910
def _stop_profiling(self, do_profile: bool) -> None:
909911
"""Stop profiling for all worker groups if profiling is enabled."""
@@ -915,6 +917,8 @@ def _stop_profiling(self, do_profile: bool) -> None:
915917
self.critic_wg.stop_profile()
916918
if self.use_rm:
917919
self.rm_wg.stop_profile()
920+
if hasattr(self, "async_rollout_manager") and self.async_rollout_manager is not None:
921+
self.async_rollout_manager.stop_profile()
918922

919923
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
920924
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,16 @@ async def sleep(self):
440440
async def wait_for_requests_to_drain(self):
441441
# Thin async pass-through: awaits self.engine.wait_for_requests_to_drain() and returns nothing itself.
await self.engine.wait_for_requests_to_drain()
442442

443+
async def start_profile(self, **kwargs):
    """Start profiling on all workers.

    Issues ``worker.start_profile.remote(**kwargs)`` for every handle in
    ``self.workers`` and awaits all of them together.

    This is a coroutine, like the sibling ``sleep`` and
    ``wait_for_requests_to_drain`` methods. Callers that go through
    ``.remote()`` are unaffected; an in-process caller must ``await`` it.

    Args:
        **kwargs: Forwarded unchanged to each worker's ``start_profile``.
    """
    if not self.workers:
        # Nothing to profile (empty or unset worker list).
        return
    # Await the refs rather than calling the blocking ray.get(): this class has
    # async methods, so a blocking wait here would stall the event loop that
    # serves every other request until all workers have answered.
    await asyncio.gather(*[worker.start_profile.remote(**kwargs) for worker in self.workers])
447+
448+
async def stop_profile(self):
    """Stop profiling on all workers.

    Issues ``worker.stop_profile.remote()`` for every handle in
    ``self.workers`` and awaits all of them together.

    This is a coroutine, like the sibling ``sleep`` and
    ``wait_for_requests_to_drain`` methods. Callers that go through
    ``.remote()`` are unaffected; an in-process caller must ``await`` it.
    """
    if not self.workers:
        # Nothing to stop (empty or unset worker list).
        return
    # Await the refs rather than calling the blocking ray.get(): this class has
    # async methods, so a blocking wait here would stall the event loop that
    # serves every other request until all workers have answered.
    await asyncio.gather(*[worker.stop_profile.remote() for worker in self.workers])
452+
443453

444454
@ray.remote(num_cpus=1)
445455
class vLLMHttpServer(vLLMHttpServerBase):
@@ -559,6 +569,14 @@ async def sleep(self):
559569
await self.servers[0].wait_for_requests_to_drain.remote()
560570
await asyncio.gather(*[server.sleep.remote() for server in self.servers])
561571

572+
async def start_profile(self, **kwargs):
573+
"""Start profiling on all servers."""
574+
# Issues server.start_profile.remote(**kwargs) for every entry in self.servers,
# forwarding kwargs unchanged, and awaits all of the results together.
await asyncio.gather(*[server.start_profile.remote(**kwargs) for server in self.servers])
575+
576+
async def stop_profile(self):
577+
"""Stop profiling on all servers."""
578+
# Issues server.stop_profile.remote() (no arguments) for every entry in self.servers
# and awaits all of the results together.
await asyncio.gather(*[server.stop_profile.remote() for server in self.servers])
579+
562580

563581
def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
564582
"""Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the

0 commit comments

Comments
 (0)