Skip to content

Commit b76689b

Browse files
committed
discrete profiler support agent loop
1 parent f11af2d commit b76689b

File tree

5 files changed

+71
-0
lines changed

5 files changed

+71
-0
lines changed

verl/experimental/agent_loop/agent_loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
819819

820820
# Fix for Issue #4147: Always call wake_up() to ensure weight sync
821821
# The wake_up()/sleep() methods internally check free_cache_engine
822+
self._start_profile(role="rollout_generate")
822823
self.wake_up()
823824
if self.reward_model_manager:
824825
self.reward_model_manager.wake_up()
@@ -841,6 +842,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
841842
timing = self._performance_metrics(metrics, output)
842843

843844
output.meta_info = {"timing": timing, **outputs[0].meta_info}
845+
self._stop_profile()
844846
return output
845847

846848
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -877,6 +879,14 @@ def clear_kv_cache(self):
877879
"""Clear all rollout kv cache, but don`t sleep."""
878880
self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
879881

882+
def _start_profile(self, **kwargs):
883+
"""Start profiling on all rollout replicas."""
884+
self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
885+
886+
def _stop_profile(self):
887+
"""Stop profiling on all rollout replicas."""
888+
self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
889+
880890
def _run_all(self, tasks: list[asyncio.Task]):
881891
async def run_all():
882892
await asyncio.gather(*tasks)

verl/utils/profiler/mstx_profile.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,43 @@ def stop(self):
214214
self.profile_npu.stop()
215215
NPUProfiler._define_count -= 1
216216

217+
def capture_start(self, **kwargs):
218+
"""Start an on-demand profiling segment."""
219+
if not (self.enable and self.this_step):
220+
return
221+
222+
message = kwargs.get("message")
223+
role = kwargs.get("role")
224+
profile_name = message or role
225+
226+
if not self.discrete:
227+
self._capture_range_id = mark_start_range(message=profile_name)
228+
else:
229+
self.capture_profiler_npu = get_npu_profiler(
230+
contents=self.profile_contents,
231+
profile_level=self.profile_level,
232+
profile_save_path=self.profile_save_path,
233+
analysis=self.analysis,
234+
role=role,
235+
)
236+
self.capture_profiler_npu.start()
237+
self._capture_range_id = mark_start_range(message=profile_name)
238+
239+
def capture_stop(self):
240+
"""Stop the on-demand profiling segment."""
241+
if not (self.enable and self.this_step):
242+
return
243+
244+
# End manual range
245+
if hasattr(self, "_capture_range_id"):
246+
mark_end_range(self._capture_range_id)
247+
del self._capture_range_id
248+
249+
if self.discrete and getattr(self, "capture_profiler_npu", None):
250+
self.capture_profiler_npu.step()
251+
self.capture_profiler_npu.stop()
252+
del self.capture_profiler_npu
253+
217254
def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
218255
"""Decorate a Worker member function to profile the current rank in the current training step.
219256

verl/utils/profiler/profile.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ def start(self, **kwargs):
227227
def stop(self):
228228
return getattr(self._impl, "stop", lambda: None)()
229229

230+
def capture_start(self, **kwargs):
231+
return getattr(self._impl, "capture_start", lambda **_: None)(**kwargs)
232+
233+
def capture_stop(self):
234+
return getattr(self._impl, "capture_stop", lambda: None)()
235+
230236
@classmethod
231237
def annotate(
232238
cls,

verl/workers/fsdp_workers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,6 +1939,16 @@ async def sleep(self):
19391939
await self.trainer_mode()
19401940
return True
19411941

1942+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
1943+
async def start_capture_profile(self, **kwargs):
1944+
self.profiler.capture_start(**kwargs)
1945+
return True
1946+
1947+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
1948+
async def stop_capture_profile(self):
1949+
self.profiler.capture_stop()
1950+
return True
1951+
19421952
# ============================ vLLM related ============================
19431953

19441954
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)

verl/workers/rollout/replica.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ async def clear_kv_cache(self):
216216
"""reset kv cache in each rollout server."""
217217
await asyncio.gather(*[server.clear_kv_cache.remote() for server in self.servers])
218218

219+
async def start_profile(self, **kwargs):
220+
"""Start profiling on all workers."""
221+
await asyncio.gather(*[worker.start_capture_profile.remote(**kwargs) for worker in self.workers])
222+
223+
async def stop_profile(self):
224+
"""Stop profiling on all workers."""
225+
await asyncio.gather(*[worker.stop_capture_profile.remote() for worker in self.workers])
226+
219227

220228
class RolloutReplicaRegistry:
221229
"""Factory for managing rollout replica implementations."""

0 commit comments

Comments
 (0)