
Commit d721a17

Author: tangmengcheng
Commit message: profiler bug fix for agent loop
Parent: c12b0cf

File tree: 5 files changed, +41 -8 lines


verl/experimental/agent_loop/agent_loop.py

Lines changed: 10 additions & 2 deletions
@@ -781,7 +781,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         Returns:
             DataProto: Output batch.
         """
-
+        self.start_profile(async_start=True)
         if self.config.actor_rollout_ref.rollout.free_cache_engine:
             self.wake_up()
         if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
@@ -803,7 +803,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         # calculate performance metrics
         metrics = [output.meta_info.pop("metrics") for output in outputs]  # List[List[Dict[str, str]]]
         timing = self._performance_metrics(metrics, output)
-
+        self.stop_profile()
         output.meta_info = {"timing": timing, **outputs[0].meta_info}
         return output
 
@@ -837,6 +837,14 @@ def sleep(self):
         """Sleep all rollout replica instances."""
         self._run_all([replica.sleep() for replica in self.rollout_replicas])
 
+    def start_profile(self, **kwargs):
+        """Start profiling on all replicas."""
+        self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    def stop_profile(self):
+        """Stop profiling on all replicas."""
+        self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
+
     def _run_all(self, tasks: list[asyncio.Task]):
         async def run_all():
             await asyncio.gather(*tasks)
verl/trainer/ppo/ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -897,7 +897,7 @@ def _load_checkpoint(self):
     def _start_profiling(self, do_profile: bool) -> None:
         """Start profiling for all worker groups if profiling is enabled."""
         if do_profile:
-            self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
+            self.actor_rollout_wg.start_profile(profile_step=self.global_steps)
             if self.use_reference_policy:
                 self.ref_policy_wg.start_profile(profile_step=self.global_steps)
             if self.use_critic:
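Dropping role="e2e" at the call site is safe because NPUProfiler.start reads its arguments via kwargs.get(...) with None defaults, and (per the mstx_profile.py hunk below) now forces the role to "e2e" itself whenever it is not in discrete mode. A toy sketch of that contract; note that in verl the discrete flag lives on the profiler instance (self.discrete), and is shown as a kwarg here only to keep the sketch self-contained:

    def start(**kwargs):
        # Missing keys fall back to None, so callers may omit them.
        role = kwargs.get("role", None)
        profile_step = kwargs.get("profile_step", None)
        discrete = kwargs.get("discrete", False)
        # Non-discrete (end-to-end) profiling names itself:
        prof_role = role if discrete else "e2e"
        return prof_role, profile_step

    assert start(profile_step=3) == ("e2e", 3)
    assert start(role="rollout", discrete=True) == ("rollout", None)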

verl/utils/profiler/mstx_profile.py

Lines changed: 11 additions & 4 deletions
@@ -192,24 +192,31 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
     def start(self, **kwargs):
         role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
         profile_step = str(profile_step) if profile_step is not None else None
+        self.async_start = kwargs.get("async_start", False)
         if self.enable and self.this_rank:
             self.this_step = True
-            if not self.discrete and NPUProfiler._define_count == 0:
+            if (not self.discrete or self.async_start) and NPUProfiler._define_count == 0:
+                if not self.discrete:
+                    prof_role = "e2e"
+                    prof_step = profile_step
+                else:
+                    prof_role = role
+                    prof_step = None
                 self.profile_npu = get_npu_profiler(
                     contents=self.profile_contents,
                     profile_level=self.profile_level,
                     profile_save_path=self.profile_save_path,
                     analysis=self.analysis,
-                    role=role,
-                    profile_step=profile_step,
+                    role=prof_role,
+                    profile_step=prof_step,
                 )
                 self.profile_npu.start()
                 NPUProfiler._define_count += 1
 
     def stop(self):
         if self.enable and self.this_rank:
             self.this_step = False
-            if not self.discrete and NPUProfiler._define_count == 1:
+            if (not self.discrete or self.async_start) and NPUProfiler._define_count == 1:
                 self.profile_npu.step()
                 self.profile_npu.stop()
                 NPUProfiler._define_count -= 1
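This hunk is the core of the fix: the guard that creates (and later stops) the whole-window NPU profiler now also fires when async_start is set, so discrete-mode agent-loop rollouts still get profiled. The role/step pair handed to get_npu_profiler depends on which path triggered it. A pure-function sketch of that selection logic (function and parameter names are illustrative, not verl API):

    def select_role_and_step(discrete: bool, async_start: bool,
                             role: str | None, profile_step: str | None):
        """Return (role, step) for the profiler, or None if none starts."""
        if discrete and not async_start:
            return None  # discrete mode without async_start: no e2e window
        if not discrete:
            return "e2e", profile_step  # end-to-end profiling, tagged by step
        return role, None               # discrete + async_start: caller's role

    assert select_role_and_step(False, False, None, "7") == ("e2e", "7")
    assert select_role_and_step(True, True, "rollout", "7") == ("rollout", None)
    assert select_role_and_step(True, False, "rollout", "7") is None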

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 18 additions & 0 deletions
@@ -440,6 +440,16 @@ async def sleep(self):
     async def wait_for_requests_to_drain(self):
         await self.engine.wait_for_requests_to_drain()
 
+    def start_profile(self, **kwargs):
+        """Start profiling on all workers."""
+        if self.workers:
+            ray.get([worker.start_profile.remote(**kwargs) for worker in self.workers])
+
+    def stop_profile(self):
+        """Stop profiling on all workers."""
+        if self.workers:
+            ray.get([worker.stop_profile.remote() for worker in self.workers])
+
 
 @ray.remote(num_cpus=1)
 class vLLMHttpServer(vLLMHttpServerBase):
@@ -559,6 +569,14 @@ async def sleep(self):
         await self.servers[0].wait_for_requests_to_drain.remote()
         await asyncio.gather(*[server.sleep.remote() for server in self.servers])
 
+    async def start_profile(self, **kwargs):
+        """Start profiling on all servers."""
+        await asyncio.gather(*[server.start_profile.remote(**kwargs) for server in self.servers])
+
+    async def stop_profile(self):
+        """Stop profiling on all servers."""
+        await asyncio.gather(*[server.stop_profile.remote() for server in self.servers])
+
 
 def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
     """Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@
 from verl.utils.distributed import initialize_global_process_group_ray
 from verl.utils.import_utils import deprecated
 from verl.utils.model import get_lora_rank_from_adapter
-from verl.utils.profiler import GPUMemoryLogger
+from verl.utils.profiler import GPUMemoryLogger, mark_start_range, mark_end_range
 from verl.utils.ray_utils import ray_noset_visible_devices
 from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
 from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge
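The only visible change here is the import: mark_start_range and mark_end_range are verl's lightweight range markers (NVTX-style on GPU, mstx on NPU). The hunk does not show where they are called, so the following is only an assumed usage shape following the common start/end-handle convention; the generate_step name and the range message are illustrative, not taken from this diff:

    from verl.utils.profiler import mark_start_range, mark_end_range

    def generate_step():
        handle = mark_start_range(message="rollout/generate")  # open a named range
        try:
            ...  # run generation
        finally:
            mark_end_range(handle)  # close the range even on error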
