Skip to content

Commit 9204dd5

Browse files
discrete profiler support agent loop
Co-authored-by: tardis-key <[email protected]>
1 parent f11af2d commit 9204dd5

File tree

9 files changed

+212
-106
lines changed

9 files changed

+212
-106
lines changed

recipe/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,9 @@ async def sleep(self):
325325

326326
async def clear_kv_cache(self):
327327
await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
328+
329+
async def start_profile(self, **kwargs):
330+
await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
331+
332+
async def stop_profile(self):
333+
await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])

recipe/one_step_off_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@ async def sleep(self):
6262

6363
async def clear_kv_cache(self):
6464
await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
65+
66+
async def start_profile(self, **kwargs):
67+
await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
68+
69+
async def stop_profile(self):
70+
await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])

verl/experimental/agent_loop/agent_loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
819819

820820
# Fix for Issue #4147: Always call wake_up() to ensure weight sync
821821
# The wake_up()/sleep() methods internally check free_cache_engine
822+
self.start_profile(role="agent_loop_rollout_generate")
822823
self.wake_up()
823824
if self.reward_model_manager:
824825
self.reward_model_manager.wake_up()
@@ -841,6 +842,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
841842
timing = self._performance_metrics(metrics, output)
842843

843844
output.meta_info = {"timing": timing, **outputs[0].meta_info}
845+
self.stop_profile()
844846
return output
845847

846848
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -877,6 +879,14 @@ def clear_kv_cache(self):
877879
"""Clear all rollout kv cache, but don`t sleep."""
878880
self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
879881

882+
def start_profile(self, **kwargs):
883+
"""Start profiling on all rollout replicas."""
884+
self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
885+
886+
def stop_profile(self):
887+
"""Stop profiling on all rollout replicas."""
888+
self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
889+
880890
def _run_all(self, tasks: list[asyncio.Task]):
881891
async def run_all():
882892
await asyncio.gather(*tasks)

verl/utils/profiler/mstx_profile.py

Lines changed: 68 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -173,46 +173,59 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
173173
config = ProfilerConfig(ranks=[], enable=False)
174174
if not tool_config:
175175
assert not config.enable, "tool_config must be set when profiler is enabled"
176-
self.enable: bool = config.enable
177-
if not config.enable:
178-
return
179-
self.this_step: bool = False
180176
self.discrete: bool = tool_config.discrete
181-
self.this_rank: bool = False
182-
self.profile_npu = None
177+
self.e2e_profile_npu = None
178+
self.capture_profiler_npu = None
183179
self.profile_contents = tool_config.contents
184180
self.profile_level = tool_config.level
185181
self.profile_save_path = config.save_path
186182
self.analysis = tool_config.analysis
187-
if config.all_ranks:
188-
self.this_rank = True
189-
elif config.ranks:
190-
self.this_rank = rank in config.ranks
191-
192-
def start(self, **kwargs):
193-
role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
194-
profile_step = str(profile_step) if profile_step is not None else None
195-
if self.enable and self.this_rank:
196-
self.this_step = True
197-
if not self.discrete and NPUProfiler._define_count == 0:
198-
self.profile_npu = get_npu_profiler(
199-
contents=self.profile_contents,
200-
profile_level=self.profile_level,
201-
profile_save_path=self.profile_save_path,
202-
analysis=self.analysis,
203-
role=role,
204-
profile_step=profile_step,
205-
)
206-
self.profile_npu.start()
207-
NPUProfiler._define_count += 1
208-
209-
def stop(self):
210-
if self.enable and self.this_rank:
211-
self.this_step = False
212-
if not self.discrete and NPUProfiler._define_count == 1:
213-
self.profile_npu.step()
214-
self.profile_npu.stop()
215-
NPUProfiler._define_count -= 1
183+
184+
def start_e2e_profiler(self, **kwargs):
185+
role = kwargs.get("role", None)
186+
if not self.discrete and NPUProfiler._define_count == 0:
187+
self.profile_npu = get_npu_profiler(
188+
contents=self.profile_contents,
189+
profile_level=self.profile_level,
190+
profile_save_path=self.profile_save_path,
191+
analysis=self.analysis,
192+
role=role,
193+
)
194+
self.profile_npu.start()
195+
NPUProfiler._define_count += 1
196+
197+
def stop_e2e_profiler(self):
198+
if not self.discrete and NPUProfiler._define_count == 1:
199+
self.profile_npu.step()
200+
self.profile_npu.stop()
201+
NPUProfiler._define_count -= 1
202+
203+
def start_capture_profiler(self, **kwargs):
204+
"""Start an on-demand profiling segment."""
205+
role = kwargs.get("role", "")
206+
207+
if self.discrete:
208+
self.capture_profiler_npu = get_npu_profiler(
209+
contents=self.profile_contents,
210+
profile_level=self.profile_level,
211+
profile_save_path=self.profile_save_path,
212+
analysis=self.analysis,
213+
role=role,
214+
)
215+
self.capture_profiler_npu.start()
216+
217+
self._capture_range_id = mark_start_range(message=role)
218+
219+
def stop_capture_profiler(self):
220+
"""Stop the on-demand profiling segment."""
221+
if hasattr(self, "_capture_range_id"):
222+
mark_end_range(self._capture_range_id)
223+
del self._capture_range_id
224+
225+
if self.discrete and getattr(self, "capture_profiler_npu", None):
226+
self.capture_profiler_npu.step()
227+
self.capture_profiler_npu.stop()
228+
del self.capture_profiler_npu
216229

217230
def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
218231
"""Decorate a Worker member function to profile the current rank in the current training step.
@@ -230,42 +243,33 @@ def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **
230243
def decorator(func):
231244
@functools.wraps(func)
232245
def wrapper(*args, **kwargs_inner):
233-
if not self.enable:
234-
return func(*args, **kwargs_inner)
235-
236246
profile_name = message or func.__name__
237247
discrete_mode = self.discrete
238-
profile_enable = self.this_step and self.enable
239-
240-
if not profile_enable:
241-
return func(*args, **kwargs_inner)
242-
243-
if profile_enable:
244-
if not discrete_mode:
245-
mark_range = mark_start_range(message=profile_name)
246-
else:
247-
profile_npu = get_npu_profiler(
248-
contents=self.profile_contents,
249-
profile_level=self.profile_level,
250-
profile_save_path=self.profile_save_path,
251-
analysis=self.analysis,
252-
role=role,
253-
)
254-
profile_npu.start()
255-
mark_range = mark_start_range(message=profile_name)
248+
249+
if not discrete_mode:
250+
mark_range = mark_start_range(message=profile_name)
251+
else:
252+
profile_npu = get_npu_profiler(
253+
contents=self.profile_contents,
254+
profile_level=self.profile_level,
255+
profile_save_path=self.profile_save_path,
256+
analysis=self.analysis,
257+
role=role,
258+
)
259+
profile_npu.start()
260+
mark_range = mark_start_range(message=profile_name)
256261

257262
result = func(*args, **kwargs_inner)
258263

259-
if profile_enable:
260-
if not discrete_mode:
261-
mark_end_range(mark_range)
262-
else:
263-
mark_end_range(mark_range)
264-
profile_npu.step()
265-
profile_npu.stop()
264+
if not discrete_mode:
265+
mark_end_range(mark_range)
266+
else:
267+
mark_end_range(mark_range)
268+
profile_npu.step()
269+
profile_npu.stop()
266270

267271
return result
268272

269273
return wrapper
270274

271-
return decorator
275+
return decorator

verl/utils/profiler/nvtx_profile.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -129,25 +129,23 @@ def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Opt
129129
self.enable = config.enable
130130
if not config.enable:
131131
return
132-
self.this_step: bool = False
133132
self.discrete: bool = tool_config.discrete
134-
self.this_rank: bool = False
135-
if config.all_ranks:
136-
self.this_rank = True
137-
elif config.ranks:
138-
self.this_rank = rank in config.ranks
139-
140-
def start(self, **kwargs):
141-
if self.enable and self.this_rank:
142-
self.this_step = True
143-
if not self.discrete:
144-
torch.cuda.profiler.start()
145-
146-
def stop(self):
147-
if self.enable and self.this_rank:
148-
self.this_step = False
149-
if not self.discrete:
150-
torch.cuda.profiler.stop()
133+
134+
def start_e2e_profiler(self, **kwargs):
135+
if not self.discrete:
136+
torch.cuda.profiler.start()
137+
138+
def stop_e2e_profiler(self):
139+
if not self.discrete:
140+
torch.cuda.profiler.stop()
141+
142+
def start_capture_profiler(self, **kwargs):
143+
if self.discrete:
144+
torch.cuda.profiler.start()
145+
146+
def stop_capture_profiler(self):
147+
if self.discrete:
148+
torch.cuda.profiler.stop()
151149

152150
def annotate(
153151
self,
@@ -176,22 +174,17 @@ def annotate(
176174
def decorator(func):
177175
@functools.wraps(func)
178176
def wrapper(*args, **kwargs_inner):
179-
if not self.enable:
180-
return func(*args, **kwargs_inner)
181-
182177
profile_name = message or func.__name__
183178

184-
if self.this_step:
185-
if self.discrete:
186-
torch.cuda.profiler.start()
187-
mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
179+
if self.discrete:
180+
torch.cuda.profiler.start()
181+
mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
188182

189183
result = func(*args, **kwargs_inner)
190184

191-
if self.this_step:
192-
mark_end_range(mark_range)
193-
if self.discrete:
194-
torch.cuda.profiler.stop()
185+
mark_end_range(mark_range)
186+
if self.discrete:
187+
torch.cuda.profiler.stop()
195188

196189
return result
197190

verl/utils/profiler/profile.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import functools
1616
import os
17+
from tabnanny import check
1718
from typing import Callable, Optional
1819

1920
import torch
@@ -183,6 +184,8 @@ class DistProfiler:
183184
- torch_memory: Torch CUDA memory snapshot dump
184185
"""
185186

187+
_this_step = False
188+
186189
def __init__(
187190
self, rank: int, config: Optional[ProfilerConfig] = None, tool_config: Optional[object] = None, **kwargs
188191
):
@@ -192,6 +195,7 @@ def __init__(
192195

193196
self._impl = None
194197
self._tool = getattr(config, "tool", None)
198+
self._enable = config.enable
195199

196200
# Normalize rank selection
197201
self._this_rank = False
@@ -201,7 +205,7 @@ def __init__(
201205
self._this_rank = rank in config.ranks
202206
else:
203207
# default rank 0 if enabled but ranks unspecified
204-
self._this_rank = (rank == 0) if config.enable else False
208+
self._this_rank = (rank == 0) if self._enable else False
205209

206210
# Lazy import to avoid circular deps
207211
if self._tool == "nsys":
@@ -221,11 +225,24 @@ def __init__(
221225
# Fallback to a no-op impl
222226
self._impl = _NoOpProfiler()
223227

224-
def start(self, **kwargs):
225-
return getattr(self._impl, "start", lambda **_: None)(**kwargs)
228+
def check_enable(self):
229+
return self._enable and self._this_rank and DistProfiler._this_step
226230

227-
def stop(self):
228-
return getattr(self._impl, "stop", lambda: None)()
231+
def start_e2e_profiler(self, **kwargs):
232+
if self.check_enable():
233+
return getattr(self._impl, "start_e2e_profiler", lambda **_: None)(**kwargs)
234+
235+
def stop_e2e_profiler(self):
236+
if self.check_enable():
237+
return getattr(self._impl, "stop_e2e_profiler", lambda: None)()
238+
239+
def start_capture_profiler(self, **kwargs):
240+
if self.check_enable():
241+
return getattr(self._impl, "start_capture_profiler", lambda **_: None)(**kwargs)
242+
243+
def stop_capture_profiler(self):
244+
if self.check_enable():
245+
return getattr(self._impl, "stop_capture_profiler", lambda: None)()
229246

230247
@classmethod
231248
def annotate(
@@ -242,6 +259,8 @@ def wrapper(self_instance, *args, **kwargs_inner):
242259
profiler = getattr(self_instance, "profiler", None)
243260
if not profiler:
244261
return func(self_instance, *args, **kwargs_inner)
262+
elif not profiler.check_enable():
263+
return func(self_instance, *args, **kwargs_inner)
245264

246265
impl = profiler._impl
247266
if hasattr(impl, "annotate"):
@@ -361,11 +380,33 @@ def __init__(self, profiler: DistProfiler):
361380
from verl.single_controller.base.decorator import Dispatch, register
362381

363382
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
364-
def start_profile(self, **kwargs) -> None:
383+
def start_e2e_profile(self, **kwargs) -> None:
384+
"""Start profiling for the current rank in the current training step."""
385+
self.profiler.start_e2e_profiler(**kwargs)
386+
387+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
388+
def stop_e2e_profile(self) -> None:
389+
"""Stop profiling for the current rank in the current training step."""
390+
self.profiler.stop_e2e_profiler()
391+
392+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
393+
def start_capture_profile(self, **kwargs) -> None:
365394
"""Start profiling for the current rank in the current training step."""
366395
self.profiler.start(**kwargs)
367396

368397
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
369-
def stop_profile(self) -> None:
398+
def stop_capture_profile(self) -> None:
370399
"""Stop profiling for the current rank in the current training step."""
371-
self.profiler.stop()
400+
self.profiler.stop_capture_profiler()
401+
402+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
403+
async def start_async_rollout_profile(self, **kwargs):
404+
"""Start an async rollout profiling segment."""
405+
self.rollout.start_async_rollout_profile(**kwargs)
406+
return True
407+
408+
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
409+
async def stop_async_rollout_profile(self):
410+
"""Stop the async rollout profiling segment."""
411+
self.rollout.stop_async_rollout_profile()
412+
return True

0 commit comments

Comments
 (0)