Skip to content

Commit 8486b44

Browse files
committed
discrete profiler support agent loop
1 parent f11af2d commit 8486b44

File tree

5 files changed

+74
-1
lines changed

5 files changed

+74
-1
lines changed

verl/experimental/agent_loop/agent_loop.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
819819

820820
# Fix for Issue #4147: Always call wake_up() to ensure weight sync
821821
# The wake_up()/sleep() methods internally check free_cache_engine
822+
self._start_profile(role="rollout_generate")
822823
self.wake_up()
823824
if self.reward_model_manager:
824825
self.reward_model_manager.wake_up()
@@ -839,8 +840,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
839840
# calculate performance metrics
840841
metrics = [output.meta_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]]
841842
timing = self._performance_metrics(metrics, output)
842-
843843
output.meta_info = {"timing": timing, **outputs[0].meta_info}
844+
self._stop_profile()
844845
return output
845846

846847
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -877,6 +878,14 @@ def clear_kv_cache(self):
877878
"""Clear all rollout kv cache, but don't sleep."""
878879
self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
879880

881+
def _start_profile(self, **kwargs):
    """Fan a ``start_profile`` call out to every rollout replica.

    All keyword arguments are passed through unchanged to each replica's
    ``start_profile``; the resulting calls are handed to ``self._run_all``.
    """
    pending = []
    for replica in self.rollout_replicas:
        pending.append(replica.start_profile(**kwargs))
    self._run_all(pending)
884+
885+
def _stop_profile(self):
    """Fan a ``stop_profile`` call out to every rollout replica.

    One ``stop_profile`` call is created per entry in ``self.rollout_replicas``
    and the whole batch is handed to ``self._run_all``.
    """
    pending = []
    for replica in self.rollout_replicas:
        pending.append(replica.stop_profile())
    self._run_all(pending)
888+
880889
def _run_all(self, tasks: list[asyncio.Task]):
881890
async def run_all():
882891
await asyncio.gather(*tasks)

verl/utils/profiler/mstx_profile.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,46 @@ def stop(self):
214214
self.profile_npu.stop()
215215
NPUProfiler._define_count -= 1
216216

217+
def capture_start(self, **kwargs):
    """Manually open a profiling capture outside the ``start``/``stop`` step flow.

    This is a no-op unless both ``self.enable`` and ``self.this_step`` are already
    truthy; it does not modify ``this_step`` itself.

    In discrete mode a dedicated NPU profiler is created and started first (kept on
    ``self._manual_profiler``). In both modes a range is then opened with
    ``mark_start_range`` and its handle is kept on ``self._manual_range`` so that
    ``capture_stop`` can close it.

    Args:
        **kwargs: Optional ``message`` and ``role`` entries. ``message`` names the
            range and falls back to ``role`` when missing or empty; ``role`` is also
            passed to ``get_npu_profiler`` in discrete mode.
    """
    if not (self.enable and self.this_step):
        return

    message = kwargs.get("message")
    role = kwargs.get("role")
    profile_name = message or role

    if self.discrete:
        # The dedicated profiler must be running before the range is opened,
        # so that the range is recorded inside this capture.
        self._manual_profiler = get_npu_profiler(
            contents=self.profile_contents,
            profile_level=self.profile_level,
            profile_save_path=self.profile_save_path,
            analysis=self.analysis,
            role=role,
        )
        self._manual_profiler.start()

    # Opened in both modes; previously duplicated in each branch.
    self._manual_range = mark_start_range(message=profile_name)
241+
242+
def capture_stop(self):
    """Close a capture previously opened by ``capture_start``.

    Does nothing unless both ``self.enable`` and ``self.this_step`` are truthy.
    Otherwise it ends the stored manual range (if one exists) and, in discrete
    mode, steps and stops the stored manual profiler (if one exists). Each
    stored handle is removed from the instance once it has been closed.
    """
    if not self.enable or not self.this_step:
        return

    # Close the manually opened range, if any.
    if hasattr(self, "_manual_range"):
        open_range = self._manual_range
        mark_end_range(open_range)
        del self._manual_range

    # Only discrete mode owns a dedicated profiler instance.
    if not self.discrete or not hasattr(self, "_manual_profiler"):
        return
    manual = self._manual_profiler
    manual.step()
    manual.stop()
    del self._manual_profiler
256+
217257
def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
218258
"""Decorate a Worker member function to profile the current rank in the current training step.
219259

verl/utils/profiler/profile.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ def start(self, **kwargs):
227227
def stop(self):
228228
return getattr(self._impl, "stop", lambda: None)()
229229

230+
def capture_start(self, **kwargs):
    """Delegate to the backend's ``capture_start`` when it provides one.

    Keyword arguments are forwarded unchanged and the backend's return value is
    returned. If ``self._impl`` has no ``capture_start`` attribute, the call is
    a no-op returning ``None``.
    """
    impl = self._impl
    try:
        handler = impl.capture_start
    except AttributeError:
        return None
    return handler(**kwargs)
232+
233+
def capture_stop(self):
    """Delegate to the backend's ``capture_stop`` when it provides one.

    Returns whatever the backend returns. If ``self._impl`` has no
    ``capture_stop`` attribute, the call is a no-op returning ``None``.
    """
    impl = self._impl
    try:
        handler = impl.capture_stop
    except AttributeError:
        return None
    return handler()
235+
230236
@classmethod
231237
def annotate(
232238
cls,

verl/workers/fsdp_workers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,6 +1939,16 @@ async def sleep(self):
19391939
await self.trainer_mode()
19401940
return True
19411941

1942+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
async def start_capture_profile(self, **kwargs):
    """Forward a manual capture-start request to this worker's profiler.

    All keyword arguments are passed to ``self.profiler.capture_start``.
    Always returns ``True``.
    """
    profiler = self.profiler
    profiler.capture_start(**kwargs)
    return True
1946+
1947+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
async def stop_capture_profile(self):
    """Forward a manual capture-stop request to this worker's profiler.

    Calls ``self.profiler.capture_stop()`` and always returns ``True``.
    """
    profiler = self.profiler
    profiler.capture_stop()
    return True
1951+
19421952
# ============================ vLLM related ============================
19431953

19441954
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)

verl/workers/rollout/replica.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ async def clear_kv_cache(self):
216216
"""reset kv cache in each rollout server."""
217217
await asyncio.gather(*[server.clear_kv_cache.remote() for server in self.servers])
218218

219+
async def start_profile(self, **kwargs):
    """Issue ``start_capture_profile`` on every worker and wait for all of them.

    Keyword arguments are forwarded unchanged to each worker's remote call.
    """
    pending = [worker.start_capture_profile.remote(**kwargs) for worker in self.workers]
    await asyncio.gather(*pending)
222+
223+
async def stop_profile(self):
    """Issue ``stop_capture_profile`` on every worker and wait for all of them."""
    pending = [worker.stop_capture_profile.remote() for worker in self.workers]
    await asyncio.gather(*pending)
226+
219227

220228
class RolloutReplicaRegistry:
221229
"""Factory for managing rollout replica implementations."""

0 commit comments

Comments
 (0)