Skip to content

Commit e033c06

Browse files
Discrete profiler support for the agent loop
Co-authored-by: tardis-key <[email protected]>
1 parent 2fd6591 commit e033c06

File tree

12 files changed

+220
-134
lines changed

12 files changed

+220
-134
lines changed

recipe/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,9 @@ async def sleep(self):
325325

326326
async def clear_kv_cache(self):
    """Clear the KV cache on every rollout replica, running all requests concurrently."""
    tasks = [replica.clear_kv_cache() for replica in self.rollout_replicas]
    await asyncio.gather(*tasks)
328+
329+
async def start_profile(self, **kwargs):
    """Start profiling on every rollout replica concurrently, forwarding all kwargs."""
    tasks = [replica.start_profile(**kwargs) for replica in self.rollout_replicas]
    await asyncio.gather(*tasks)
331+
332+
async def stop_profile(self):
    """Stop profiling on every rollout replica concurrently."""
    tasks = [replica.stop_profile() for replica in self.rollout_replicas]
    await asyncio.gather(*tasks)

recipe/one_step_off_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@ async def sleep(self):
6262

6363
async def clear_kv_cache(self):
    """Concurrently clear the rollout KV cache on all replicas."""
    await asyncio.gather(*(r.clear_kv_cache() for r in self.rollout_replicas))
65+
66+
async def start_profile(self, **kwargs):
    """Concurrently start profiling on all rollout replicas, forwarding kwargs."""
    await asyncio.gather(*(r.start_profile(**kwargs) for r in self.rollout_replicas))
68+
69+
async def stop_profile(self):
    """Concurrently stop profiling on all rollout replicas."""
    await asyncio.gather(*(r.stop_profile() for r in self.rollout_replicas))

verl/experimental/agent_loop/agent_loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
807807

808808
# Fix for Issue #4147: Always call wake_up() to ensure weight sync
809809
# The wake_up()/sleep() methods internally check free_cache_engine
810+
self.start_profile(role="agent_loop_rollout_generate")
810811
self.wake_up()
811812
if self.reward_model_manager:
812813
self.reward_model_manager.wake_up()
@@ -829,6 +830,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
829830
timing = self._performance_metrics(metrics, output)
830831

831832
output.meta_info = {"timing": timing, **outputs[0].meta_info}
833+
self.stop_profile()
832834
return output
833835

834836
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -865,6 +867,14 @@ def clear_kv_cache(self):
865867
"""Clear all rollout kv cache, but don`t sleep."""
866868
self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
867869

870+
def start_profile(self, **kwargs):
    """Start profiling on all rollout replicas.

    Blocks until every replica has acknowledged the start (via ``_run_all``).
    """
    coros = [replica.start_profile(**kwargs) for replica in self.rollout_replicas]
    self._run_all(coros)
873+
874+
def stop_profile(self):
    """Stop profiling on all rollout replicas.

    Blocks until every replica has acknowledged the stop (via ``_run_all``).
    """
    coros = [replica.stop_profile() for replica in self.rollout_replicas]
    self._run_all(coros)
877+
868878
def _run_all(self, tasks: list[asyncio.Task]):
869879
async def run_all():
870880
await asyncio.gather(*tasks)

verl/trainer/ppo/ray_trainer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -990,24 +990,24 @@ def _load_checkpoint(self):
990990
def _start_profiling(self, do_profile: bool) -> None:
991991
"""Start profiling for all worker groups if profiling is enabled."""
992992
if do_profile:
993-
self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
993+
self.actor_rollout_wg.start_e2e_profiler(role="e2e", profile_step=self.global_steps)
994994
if self.use_reference_policy:
995-
self.ref_policy_wg.start_profile(profile_step=self.global_steps)
995+
self.ref_policy_wg.start_e2e_profiler(profile_step=self.global_steps)
996996
if self.use_critic:
997-
self.critic_wg.start_profile(profile_step=self.global_steps)
997+
self.critic_wg.start_e2e_profiler(profile_step=self.global_steps)
998998
if self.use_rm and not self.use_reward_loop:
999-
self.rm_wg.start_profile(profile_step=self.global_steps)
999+
self.rm_wg.start_e2e_profiler(profile_step=self.global_steps)
10001000

10011001
def _stop_profiling(self, do_profile: bool) -> None:
10021002
"""Stop profiling for all worker groups if profiling is enabled."""
10031003
if do_profile:
1004-
self.actor_rollout_wg.stop_profile()
1004+
self.actor_rollout_wg.stop_e2e_profiler()
10051005
if self.use_reference_policy:
1006-
self.ref_policy_wg.stop_profile()
1006+
self.ref_policy_wg.stop_e2e_profiler()
10071007
if self.use_critic:
1008-
self.critic_wg.stop_profile()
1008+
self.critic_wg.stop_e2e_profiler()
10091009
if self.use_rm and not self.use_reward_loop:
1010-
self.rm_wg.stop_profile()
1010+
self.rm_wg.stop_e2e_profiler()
10111011

10121012
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
10131013
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""

verl/utils/profiler/mstx_profile.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -173,46 +173,58 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
173173
config = ProfilerConfig(ranks=[], enable=False)
174174
if not tool_config:
175175
assert not config.enable, "tool_config must be set when profiler is enabled"
176-
self.enable: bool = config.enable
177-
if not config.enable:
178-
return
179-
self.this_step: bool = False
180176
self.discrete: bool = tool_config.discrete
181-
self.this_rank: bool = False
182-
self.profile_npu = None
177+
self.e2e_profile_npu = None
183178
self.profile_contents = tool_config.contents
184179
self.profile_level = tool_config.level
185180
self.profile_save_path = config.save_path
186181
self.analysis = tool_config.analysis
187-
if config.all_ranks:
188-
self.this_rank = True
189-
elif config.ranks:
190-
self.this_rank = rank in config.ranks
191-
192-
def start(self, **kwargs):
193-
role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
194-
profile_step = str(profile_step) if profile_step is not None else None
195-
if self.enable and self.this_rank:
196-
self.this_step = True
197-
if not self.discrete and NPUProfiler._define_count == 0:
198-
self.profile_npu = get_npu_profiler(
199-
contents=self.profile_contents,
200-
profile_level=self.profile_level,
201-
profile_save_path=self.profile_save_path,
202-
analysis=self.analysis,
203-
role=role,
204-
profile_step=profile_step,
205-
)
206-
self.profile_npu.start()
207-
NPUProfiler._define_count += 1
208-
209-
def stop(self):
210-
if self.enable and self.this_rank:
211-
self.this_step = False
212-
if not self.discrete and NPUProfiler._define_count == 1:
213-
self.profile_npu.step()
214-
self.profile_npu.stop()
215-
NPUProfiler._define_count -= 1
182+
183+
def start_e2e_profiler(self, **kwargs):
    """Begin an end-to-end NPU capture (non-discrete mode only).

    Only the first caller creates and starts the profiler; while a capture is
    already open (``_define_count`` != 0) this is a no-op.
    """
    # NOTE(review): indentation reconstructed from a diff — the count increment
    # is assumed to sit inside the guard, mirroring stop_e2e_profiler's decrement.
    if self.discrete or NPUProfiler._define_count != 0:
        return
    self.e2e_profile_npu = get_npu_profiler(
        contents=self.profile_contents,
        profile_level=self.profile_level,
        profile_save_path=self.profile_save_path,
        analysis=self.analysis,
        role=kwargs.get("role", None),
    )
    self.e2e_profile_npu.start()
    NPUProfiler._define_count += 1
195+
196+
def stop_e2e_profiler(self):
    """Stop the end-to-end NPU capture opened by ``start_e2e_profiler``.

    Only acts when exactly one capture is open (``_define_count`` == 1) and
    the profiler is not in discrete mode.
    """
    if self.discrete or NPUProfiler._define_count != 1:
        return
    self.e2e_profile_npu.step()
    self.e2e_profile_npu.stop()
    NPUProfiler._define_count -= 1
201+
202+
def start_capture_profiler(self, **kwargs):
    """Start an on-demand profiling segment.

    In discrete mode a fresh NPU profiler is created and started for this
    segment; in all modes an mstx range tagged with *role* is opened.
    """
    role = kwargs.get("role", "")
    if self.discrete:
        self.capture_profiler_npu = get_npu_profiler(
            contents=self.profile_contents,
            profile_level=self.profile_level,
            profile_save_path=self.profile_save_path,
            analysis=self.analysis,
            role=role,
        )
        self.capture_profiler_npu.start()
    self._capture_range_id = mark_start_range(message=role)
217+
218+
def stop_capture_profiler(self):
    """Stop the on-demand profiling segment.

    Closes the mstx range if one is open, then, in discrete mode, steps and
    stops the per-segment NPU profiler and drops it.
    """
    if hasattr(self, "_capture_range_id"):
        mark_end_range(self._capture_range_id)
        del self._capture_range_id
    if self.discrete and getattr(self, "capture_profiler_npu", None):
        profiler = self.capture_profiler_npu
        profiler.step()
        profiler.stop()
        del self.capture_profiler_npu
216228

217229
def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
218230
"""Decorate a Worker member function to profile the current rank in the current training step.
@@ -230,42 +242,33 @@ def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **
230242
def decorator(func):
231243
@functools.wraps(func)
232244
def wrapper(*args, **kwargs_inner):
233-
if not self.enable:
234-
return func(*args, **kwargs_inner)
235-
236245
profile_name = message or func.__name__
237246
discrete_mode = self.discrete
238-
profile_enable = self.this_step and self.enable
239-
240-
if not profile_enable:
241-
return func(*args, **kwargs_inner)
242-
243-
if profile_enable:
244-
if not discrete_mode:
245-
mark_range = mark_start_range(message=profile_name)
246-
else:
247-
profile_npu = get_npu_profiler(
248-
contents=self.profile_contents,
249-
profile_level=self.profile_level,
250-
profile_save_path=self.profile_save_path,
251-
analysis=self.analysis,
252-
role=role,
253-
)
254-
profile_npu.start()
255-
mark_range = mark_start_range(message=profile_name)
247+
248+
if not discrete_mode:
249+
mark_range = mark_start_range(message=profile_name)
250+
else:
251+
profile_npu = get_npu_profiler(
252+
contents=self.profile_contents,
253+
profile_level=self.profile_level,
254+
profile_save_path=self.profile_save_path,
255+
analysis=self.analysis,
256+
role=role,
257+
)
258+
profile_npu.start()
259+
mark_range = mark_start_range(message=profile_name)
256260

257261
result = func(*args, **kwargs_inner)
258262

259-
if profile_enable:
260-
if not discrete_mode:
261-
mark_end_range(mark_range)
262-
else:
263-
mark_end_range(mark_range)
264-
profile_npu.step()
265-
profile_npu.stop()
263+
if not discrete_mode:
264+
mark_end_range(mark_range)
265+
else:
266+
mark_end_range(mark_range)
267+
profile_npu.step()
268+
profile_npu.stop()
266269

267270
return result
268271

269272
return wrapper
270273

271-
return decorator
274+
return decorator

verl/utils/profiler/nvtx_profile.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -129,25 +129,25 @@ def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Opt
129129
self.enable = config.enable
130130
if not config.enable:
131131
return
132-
self.this_step: bool = False
133132
self.discrete: bool = tool_config.discrete
134-
self.this_rank: bool = False
135-
if config.all_ranks:
136-
self.this_rank = True
137-
elif config.ranks:
138-
self.this_rank = rank in config.ranks
139-
140-
def start(self, **kwargs):
141-
if self.enable and self.this_rank:
142-
self.this_step = True
143-
if not self.discrete:
144-
torch.cuda.profiler.start()
145-
146-
def stop(self):
147-
if self.enable and self.this_rank:
148-
self.this_step = False
149-
if not self.discrete:
150-
torch.cuda.profiler.stop()
133+
134+
def start_e2e_profiler(self, **kwargs):
    """Start the CUDA profiler for an end-to-end capture (no-op in discrete mode)."""
    if self.discrete:
        return
    torch.cuda.profiler.start()
137+
138+
def stop_e2e_profiler(self):
    """Stop the CUDA profiler for an end-to-end capture (no-op in discrete mode)."""
    if self.discrete:
        return
    torch.cuda.profiler.stop()
141+
142+
def start_capture_profiler(self, **kwargs):
    """Start an on-demand profiling segment (only active in discrete mode)."""
    if not self.discrete:
        return
    torch.cuda.profiler.start()
146+
147+
def stop_capture_profiler(self):
    """Stop the on-demand profiling segment (only active in discrete mode)."""
    if not self.discrete:
        return
    torch.cuda.profiler.stop()
151151

152152
def annotate(
153153
self,
@@ -176,22 +176,17 @@ def annotate(
176176
def decorator(func):
177177
@functools.wraps(func)
178178
def wrapper(*args, **kwargs_inner):
179-
if not self.enable:
180-
return func(*args, **kwargs_inner)
181-
182179
profile_name = message or func.__name__
183180

184-
if self.this_step:
185-
if self.discrete:
186-
torch.cuda.profiler.start()
187-
mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
181+
if self.discrete:
182+
torch.cuda.profiler.start()
183+
mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
188184

189185
result = func(*args, **kwargs_inner)
190186

191-
if self.this_step:
192-
mark_end_range(mark_range)
193-
if self.discrete:
194-
torch.cuda.profiler.stop()
187+
mark_end_range(mark_range)
188+
if self.discrete:
189+
torch.cuda.profiler.stop()
195190

196191
return result
197192

verl/utils/profiler/profile.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ class DistProfiler:
183183
- torch_memory: Torch CUDA memory snapshot dump
184184
"""
185185

186+
_this_step = False
187+
186188
def __init__(
187189
self, rank: int, config: Optional[ProfilerConfig] = None, tool_config: Optional[object] = None, **kwargs
188190
):
@@ -192,6 +194,7 @@ def __init__(
192194

193195
self._impl = None
194196
self._tool = getattr(config, "tool", None)
197+
self._enable = config.enable
195198

196199
# Normalize rank selection
197200
self._this_rank = False
@@ -201,7 +204,7 @@ def __init__(
201204
self._this_rank = rank in config.ranks
202205
else:
203206
# default rank 0 if enabled but ranks unspecified
204-
self._this_rank = (rank == 0) if config.enable else False
207+
self._this_rank = (rank == 0) if self._enable else False
205208

206209
# Lazy import to avoid circular deps
207210
if self._tool == "nsys":
@@ -221,11 +224,26 @@ def __init__(
221224
# Fallback to a no-op impl
222225
self._impl = _NoOpProfiler()
223226

224-
def start(self, **kwargs):
225-
return getattr(self._impl, "start", lambda **_: None)(**kwargs)
227+
def check_enable(self):
    """Return whether profiling is active right now.

    True only when profiling is enabled in config, this rank is selected,
    and a profiled step is currently in flight. Mirrors the short-circuit
    ``and`` chain, so the first falsy operand is returned as-is.
    """
    if not self._enable:
        return self._enable
    if not self._this_rank:
        return self._this_rank
    return DistProfiler._this_step
226229

227-
def stop(self):
228-
return getattr(self._impl, "stop", lambda: None)()
230+
def start_e2e_profiler(self, **kwargs):
    """Mark the step as profiled and start the backend e2e profiler if this rank is selected."""
    # Every rank flips the class-level step flag so annotate()-decorated calls
    # know a profiled step is in flight; only enabled ranks touch the backend.
    DistProfiler._this_step = True
    if self.check_enable():
        start = getattr(self._impl, "start_e2e_profiler", lambda **_: None)
        return start(**kwargs)
234+
235+
def stop_e2e_profiler(self):
    """Clear the step flag and stop the backend e2e profiler if this rank is selected."""
    # NOTE(review): _this_step is only reset when profiling was active on this
    # rank; on disabled ranks it stays set, but check_enable() gates on rank
    # and enable anyway. _this_step is shared class state — confirm intended.
    if not self.check_enable():
        return
    DistProfiler._this_step = False
    stop = getattr(self._impl, "stop_e2e_profiler", lambda: None)
    return stop()
239+
240+
def start_capture_profiler(self, **kwargs):
    """Start the backend's on-demand capture when profiling is active on this rank."""
    if not self.check_enable():
        return None
    fn = getattr(self._impl, "start_capture_profiler", lambda **_: None)
    return fn(**kwargs)
243+
244+
def stop_capture_profiler(self):
    """Stop the backend's on-demand capture when profiling is active on this rank."""
    if not self.check_enable():
        return None
    fn = getattr(self._impl, "stop_capture_profiler", lambda: None)
    return fn()
229247

230248
@classmethod
231249
def annotate(
@@ -240,7 +258,7 @@ def decorator(func):
240258
@functools.wraps(func)
241259
def wrapper(self_instance, *args, **kwargs_inner):
242260
profiler = getattr(self_instance, "profiler", None)
243-
if not profiler:
261+
if not profiler or not profiler.check_enable():
244262
return func(self_instance, *args, **kwargs_inner)
245263

246264
impl = profiler._impl
@@ -361,11 +379,11 @@ def __init__(self, profiler: DistProfiler):
361379
from verl.single_controller.base.decorator import Dispatch, register
362380

363381
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
364-
def start_profile(self, **kwargs) -> None:
382+
def start_e2e_profiler(self, **kwargs) -> None:
365383
"""Start profiling for the current rank in the current training step."""
366-
self.profiler.start(**kwargs)
384+
self.profiler.start_e2e_profiler(**kwargs)
367385

368386
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
369-
def stop_profile(self) -> None:
387+
def stop_e2e_profiler(self) -> None:
370388
"""Stop profiling for the current rank in the current training step."""
371-
self.profiler.stop()
389+
self.profiler.stop_e2e_profiler()

0 commit comments

Comments
 (0)