
Commit 54f19ce

discrete profiler support for agent loop
Co-authored-by: tardis-key <[email protected]>
1 parent 2fd6591 commit 54f19ce

13 files changed: +185 −133 lines

recipe/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
@@ -325,3 +325,9 @@ async def sleep(self):
 
     async def clear_kv_cache(self):
         await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
+
+    async def start_profile(self, **kwargs):
+        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    async def stop_profile(self):
+        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])

recipe/one_step_off_policy/agent_loop/agent_loop.py

Lines changed: 6 additions & 0 deletions
@@ -62,3 +62,9 @@ async def sleep(self):
 
     async def clear_kv_cache(self):
         await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
+
+    async def start_profile(self, **kwargs):
+        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    async def stop_profile(self):
+        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
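
Both recipes add the same pair of methods, so their rollout managers can fan a profiler start or stop out to every rollout replica concurrently with asyncio.gather. Below is a minimal, self-contained sketch of that fan-out pattern; FakeReplica is a hypothetical stand-in for the real replica handles, which expose the same async start_profile/stop_profile interface.

```python
import asyncio


class FakeReplica:
    """Hypothetical stand-in for a rollout replica handle."""

    def __init__(self, name: str):
        self.name = name

    async def start_profile(self, **kwargs):
        print(f"{self.name}: profiler started with {kwargs}")

    async def stop_profile(self):
        print(f"{self.name}: profiler stopped")


async def main():
    replicas = [FakeReplica(f"replica-{i}") for i in range(2)]
    # Fan out to every replica concurrently, mirroring the new methods above.
    await asyncio.gather(*[r.start_profile(role="e2e") for r in replicas])
    # ... the rollout would run here ...
    await asyncio.gather(*[r.stop_profile() for r in replicas])


if __name__ == "__main__":
    asyncio.run(main())
```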

verl/experimental/agent_loop/agent_loop.py

Lines changed: 24 additions & 0 deletions
@@ -115,6 +115,14 @@ async def generate(
         )
         return output
 
+    async def start_profiler(self, **kwargs):
+        for server in self.server_handles:
+            await server.start_profiler.remote(**kwargs)
+
+    async def stop_profiler(self):
+        for server in self.server_handles:
+            await server.stop_profiler.remote()
+
 
 class AgentLoopMetrics(BaseModel):
     """Agent loop performance metrics."""
@@ -580,6 +588,12 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
             extra_fields=output.extra_fields,
         )
 
+    async def start_profiler(self, **kwargs):
+        await self.server_manager.start_profiler(**kwargs)
+
+    async def stop_profiler(self):
+        await self.server_manager.stop_profiler()
+
     def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
         """Process the padded outputs from _run_agent_loop and combine them into a batch."""
         # Convert lists back to tensors and stack them to create a batch.
@@ -807,6 +821,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
 
         # Fix for Issue #4147: Always call wake_up() to ensure weight sync
         # The wake_up()/sleep() methods internally check free_cache_engine
+        ray.get([worker.start_profiler.remote() for worker in self.agent_loop_workers])
         self.wake_up()
         if self.reward_model_manager:
             self.reward_model_manager.wake_up()
@@ -829,6 +844,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         timing = self._performance_metrics(metrics, output)
 
         output.meta_info = {"timing": timing, **outputs[0].meta_info}
+        ray.get([worker.stop_profiler.remote() for worker in self.agent_loop_workers])
         return output
 
     def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -865,6 +881,14 @@ def clear_kv_cache(self):
         """Clear all rollout kv cache, but don't sleep."""
         self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
 
+    def start_profile(self, **kwargs):
+        """Start profiling on all rollout replicas."""
+        self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    def stop_profile(self):
+        """Stop profiling on all rollout replicas."""
+        self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
+
     def _run_all(self, tasks: list[asyncio.Task]):
         async def run_all():
             await asyncio.gather(*tasks)
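
Taken together, the hunks in this file wire profiling through the agent loop at three levels: the server manager forwards start_profiler/stop_profiler to its rollout server handles, the worker exposes the same pair of methods, and generate_sequences brackets the rollout by fanning the calls out to every agent-loop worker through Ray before wake_up() and after the output is assembled. A hedged sketch of that Ray fan-out follows; ToyWorker is a hypothetical stand-in for the real AgentLoopWorker actors.

```python
import ray


@ray.remote
class ToyWorker:
    """Hypothetical stand-in for an AgentLoopWorker actor."""

    def start_profiler(self, **kwargs):
        return f"profiler started ({kwargs})"

    def stop_profiler(self):
        return "profiler stopped"


ray.init(ignore_reinit_error=True)
workers = [ToyWorker.remote() for _ in range(2)]

# Mirrors the calls added to generate_sequences above.
print(ray.get([w.start_profiler.remote() for w in workers]))
# ... the rollout would run here ...
print(ray.get([w.stop_profiler.remote() for w in workers]))

ray.shutdown()
```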

verl/trainer/ppo/ray_trainer.py

Lines changed: 18 additions & 8 deletions
@@ -990,24 +990,34 @@ def _load_checkpoint(self):
     def _start_profiling(self, do_profile: bool) -> None:
         """Start profiling for all worker groups if profiling is enabled."""
         if do_profile:
-            self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
+            self.async_rollout_manager.enable_profiler_this_step()
+            self.actor_rollout_wg.enable_profiler_this_step()
+            self.actor_rollout_wg.start_profiler(role="e2e", profile_step=self.global_steps)
             if self.use_reference_policy:
-                self.ref_policy_wg.start_profile(profile_step=self.global_steps)
+                self.ref_policy_wg.enable_profiler_this_step()
+                self.ref_policy_wg.start_profiler(profile_step=self.global_steps)
             if self.use_critic:
-                self.critic_wg.start_profile(profile_step=self.global_steps)
+                self.critic_wg.enable_profiler_this_step()
+                self.critic_wg.start_profiler(profile_step=self.global_steps)
             if self.use_rm and not self.use_reward_loop:
-                self.rm_wg.start_profile(profile_step=self.global_steps)
+                self.rm_wg.enable_profiler_this_step()
+                self.rm_wg.start_profiler(profile_step=self.global_steps)
 
     def _stop_profiling(self, do_profile: bool) -> None:
         """Stop profiling for all worker groups if profiling is enabled."""
         if do_profile:
-            self.actor_rollout_wg.stop_profile()
+            self.async_rollout_manager.disable_profiler_this_step()
+            self.actor_rollout_wg.disable_profiler_this_step()
+            self.actor_rollout_wg.stop_profiler()
             if self.use_reference_policy:
-                self.ref_policy_wg.stop_profile()
+                self.ref_policy_wg.disable_profiler_this_step()
+                self.ref_policy_wg.stop_profiler()
             if self.use_critic:
-                self.critic_wg.stop_profile()
+                self.critic_wg.disable_profiler_this_step()
+                self.critic_wg.stop_profiler()
             if self.use_rm and not self.use_reward_loop:
-                self.rm_wg.stop_profile()
+                self.rm_wg.disable_profiler_this_step()
+                self.rm_wg.stop_profiler()
 
     def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
         """Reorder the data on single controller such that each dp rank gets similar total tokens"""

verl/utils/profiler/mstx_profile.py

Lines changed: 40 additions & 63 deletions
@@ -173,46 +173,32 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
             config = ProfilerConfig(ranks=[], enable=False)
         if not tool_config:
             assert not config.enable, "tool_config must be set when profiler is enabled"
-        self.enable: bool = config.enable
-        if not config.enable:
-            return
-        self.this_step: bool = False
         self.discrete: bool = tool_config.discrete
-        self.this_rank: bool = False
         self.profile_npu = None
         self.profile_contents = tool_config.contents
         self.profile_level = tool_config.level
         self.profile_save_path = config.save_path
         self.analysis = tool_config.analysis
-        if config.all_ranks:
-            self.this_rank = True
-        elif config.ranks:
-            self.this_rank = rank in config.ranks
-
-    def start(self, **kwargs):
-        role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
-        profile_step = str(profile_step) if profile_step is not None else None
-        if self.enable and self.this_rank:
-            self.this_step = True
-            if not self.discrete and NPUProfiler._define_count == 0:
-                self.profile_npu = get_npu_profiler(
-                    contents=self.profile_contents,
-                    profile_level=self.profile_level,
-                    profile_save_path=self.profile_save_path,
-                    analysis=self.analysis,
-                    role=role,
-                    profile_step=profile_step,
-                )
-                self.profile_npu.start()
-                NPUProfiler._define_count += 1
-
-    def stop(self):
-        if self.enable and self.this_rank:
-            self.this_step = False
-            if not self.discrete and NPUProfiler._define_count == 1:
-                self.profile_npu.step()
-                self.profile_npu.stop()
-                NPUProfiler._define_count -= 1
+
+    def start_profiler(self, **kwargs):
+        role = kwargs.get("role", None)
+        if not self.discrete and NPUProfiler._define_count == 0:
+            self.profile_npu = get_npu_profiler(
+                contents=self.profile_contents,
+                profile_level=self.profile_level,
+                profile_save_path=self.profile_save_path,
+                analysis=self.analysis,
+                role=role,
+            )
+            self.profile_npu.start()
+            NPUProfiler._define_count += 1
+
+    def stop_profiler(self):
+        if not self.discrete and NPUProfiler._define_count == 1:
+            self.profile_npu.step()
+            self.profile_npu.stop()
+            NPUProfiler._define_count -= 1
+
 
     def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
         """Decorate a Worker member function to profile the current rank in the current training step.
@@ -230,42 +216,33 @@ def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **
         def decorator(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs_inner):
-                if not self.enable:
-                    return func(*args, **kwargs_inner)
-
                 profile_name = message or func.__name__
                 discrete_mode = self.discrete
-                profile_enable = self.this_step and self.enable
-
-                if not profile_enable:
-                    return func(*args, **kwargs_inner)
-
-                if profile_enable:
-                    if not discrete_mode:
-                        mark_range = mark_start_range(message=profile_name)
-                    else:
-                        profile_npu = get_npu_profiler(
-                            contents=self.profile_contents,
-                            profile_level=self.profile_level,
-                            profile_save_path=self.profile_save_path,
-                            analysis=self.analysis,
-                            role=role,
-                        )
-                        profile_npu.start()
-                        mark_range = mark_start_range(message=profile_name)
+
+                if not discrete_mode:
+                    mark_range = mark_start_range(message=profile_name)
+                else:
+                    profile_npu = get_npu_profiler(
+                        contents=self.profile_contents,
+                        profile_level=self.profile_level,
+                        profile_save_path=self.profile_save_path,
+                        analysis=self.analysis,
+                        role=role,
+                    )
+                    profile_npu.start()
+                    mark_range = mark_start_range(message=profile_name)
 
                 result = func(*args, **kwargs_inner)
 
-                if profile_enable:
-                    if not discrete_mode:
-                        mark_end_range(mark_range)
-                    else:
-                        mark_end_range(mark_range)
-                        profile_npu.step()
-                        profile_npu.stop()
+                if not discrete_mode:
+                    mark_end_range(mark_range)
+                else:
+                    mark_end_range(mark_range)
+                    profile_npu.step()
+                    profile_npu.stop()
 
                 return result
 
             return wrapper
 
-        return decorator
+        return decorator
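
After this change the NPUProfiler no longer tracks per-rank or per-step state itself (that gating is now driven from the trainer through the enable_profiler_this_step/disable_profiler_this_step calls shown in the ray_trainer.py hunk); what remains is the split between continuous and discrete profiling. The following is a hedged, self-contained sketch of that split; ToySession and ToyProfiler are illustrative stand-ins, not the real mstx/NPU APIs.

```python
import functools


class ToySession:
    """Stand-in for the session object returned by get_npu_profiler."""

    def start(self):
        print("  session start")

    def step(self):
        print("  session step")

    def stop(self):
        print("  session stop")


class ToyProfiler:
    _define_count = 0  # class-level guard, as in NPUProfiler above

    def __init__(self, discrete: bool):
        self.discrete = discrete
        self.session = None

    def start_profiler(self, **kwargs):
        # Continuous mode: open a single whole-step session, exactly once.
        if not self.discrete and ToyProfiler._define_count == 0:
            self.session = ToySession()
            self.session.start()
            ToyProfiler._define_count += 1

    def stop_profiler(self):
        # Continuous mode: close the whole-step session, exactly once.
        if not self.discrete and ToyProfiler._define_count == 1:
            self.session.step()
            self.session.stop()
            ToyProfiler._define_count -= 1

    def annotate(self, message=None):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Discrete mode: every annotated call gets its own session.
                if self.discrete:
                    session = ToySession()
                    session.start()
                result = func(*args, **kwargs)
                if self.discrete:
                    session.step()
                    session.stop()
                return result

            return wrapper

        return decorator


profiler = ToyProfiler(discrete=True)


@profiler.annotate(message="rollout")
def rollout():
    print("  rollout body")


profiler.start_profiler()  # no-op in discrete mode
rollout()                  # opens and closes its own session
profiler.stop_profiler()   # no-op in discrete mode
```

In continuous mode the class-level counter ensures only one whole-step session is opened even if several profiler instances share the process; in discrete mode start_profiler/stop_profiler are no-ops and the annotate decorator does all the work.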

verl/utils/profiler/nvtx_profile.py

Lines changed: 14 additions & 32 deletions
@@ -126,28 +126,15 @@ def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Opt
             config = ProfilerConfig(ranks=[])
         if not tool_config:
             assert not config.enable, "tool_config must be provided when profiler is enabled"
-        self.enable = config.enable
-        if not config.enable:
-            return
-        self.this_step: bool = False
         self.discrete: bool = tool_config.discrete
-        self.this_rank: bool = False
-        if config.all_ranks:
-            self.this_rank = True
-        elif config.ranks:
-            self.this_rank = rank in config.ranks
-
-    def start(self, **kwargs):
-        if self.enable and self.this_rank:
-            self.this_step = True
-            if not self.discrete:
-                torch.cuda.profiler.start()
-
-    def stop(self):
-        if self.enable and self.this_rank:
-            self.this_step = False
-            if not self.discrete:
-                torch.cuda.profiler.stop()
+
+    def start_profiler(self, **kwargs):
+        if not self.discrete:
+            torch.cuda.profiler.start()
+
+    def stop_profiler(self):
+        if not self.discrete:
+            torch.cuda.profiler.stop()
 
     def annotate(
         self,
@@ -176,22 +163,17 @@ def annotate(
         def decorator(func):
            @functools.wraps(func)
             def wrapper(*args, **kwargs_inner):
-                if not self.enable:
-                    return func(*args, **kwargs_inner)
-
                 profile_name = message or func.__name__
 
-                if self.this_step:
-                    if self.discrete:
-                        torch.cuda.profiler.start()
-                    mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
+                if self.discrete:
+                    torch.cuda.profiler.start()
+                mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
 
                 result = func(*args, **kwargs_inner)
 
-                if self.this_step:
-                    mark_end_range(mark_range)
-                    if self.discrete:
-                        torch.cuda.profiler.stop()
+                mark_end_range(mark_range)
+                if self.discrete:
+                    torch.cuda.profiler.stop()
 
                 return result
 