6 changes: 6 additions & 0 deletions recipe/fully_async_policy/agent_loop/agent_loop.py
@@ -325,3 +325,9 @@ async def sleep(self):

    async def clear_kv_cache(self):
        await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])

    async def start_profile(self, **kwargs):
        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])

    async def stop_profile(self):
        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
6 changes: 6 additions & 0 deletions recipe/one_step_off_policy/agent_loop/agent_loop.py
@@ -62,3 +62,9 @@ async def sleep(self):

    async def clear_kv_cache(self):
        await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])

    async def start_profile(self, **kwargs):
        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])

    async def stop_profile(self):
        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
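
Both recipe managers add the same hooks: each call fans out to every rollout replica concurrently via asyncio.gather. A minimal runnable sketch of that fan-out pattern, with hypothetical ReplicaStub and ReplicaManagerSketch classes standing in for verl's rollout replicas and manager:

import asyncio

class ReplicaStub:
    """Hypothetical stand-in for a rollout replica exposing async profiler hooks."""

    def __init__(self, idx: int):
        self.idx = idx

    async def start_profile(self, **kwargs):
        print(f"replica {self.idx}: start_profile({kwargs})")

    async def stop_profile(self):
        print(f"replica {self.idx}: stop_profile()")

class ReplicaManagerSketch:
    def __init__(self, replicas):
        self.rollout_replicas = replicas

    async def start_profile(self, **kwargs):
        # Launch the call on every replica at once and wait for all of them.
        await asyncio.gather(*[r.start_profile(**kwargs) for r in self.rollout_replicas])

    async def stop_profile(self):
        await asyncio.gather(*[r.stop_profile() for r in self.rollout_replicas])

async def main():
    manager = ReplicaManagerSketch([ReplicaStub(i) for i in range(2)])
    await manager.start_profile(role="e2e")
    await manager.stop_profile()

asyncio.run(main())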
24 changes: 24 additions & 0 deletions verl/experimental/agent_loop/agent_loop.py
@@ -115,6 +115,14 @@ async def generate(
        )
        return output

    async def start_profiler(self, **kwargs):
        for server in self.server_handles:
            await server.start_profiler.remote(**kwargs)

    async def stop_profiler(self):
        for server in self.server_handles:
            await server.stop_profiler.remote()


class AgentLoopMetrics(BaseModel):
"""Agent loop performance metrics."""
@@ -580,6 +588,12 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopOutput:
                extra_fields=output.extra_fields,
            )

    async def start_profiler(self, **kwargs):
        await self.server_manager.start_profiler(**kwargs)

    async def stop_profiler(self):
        await self.server_manager.stop_profiler()

    def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
        """Process the padded outputs from _run_agent_loop and combine them into a batch."""
        # Convert lists back to tensors and stack them to create a batch.
@@ -807,6 +821,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:

        # Fix for Issue #4147: Always call wake_up() to ensure weight sync
        # The wake_up()/sleep() methods internally check free_cache_engine
        ray.get([worker.start_profiler.remote() for worker in self.agent_loop_workers])
        self.wake_up()
        if self.reward_model_manager:
            self.reward_model_manager.wake_up()
@@ -829,6 +844,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
        timing = self._performance_metrics(metrics, output)

        output.meta_info = {"timing": timing, **outputs[0].meta_info}
        ray.get([worker.stop_profiler.remote() for worker in self.agent_loop_workers])
        return output

    def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -865,6 +881,14 @@ def clear_kv_cache(self):
"""Clear all rollout kv cache, but don`t sleep."""
self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])

    def start_profile(self, **kwargs):
        """Start profiling on all rollout replicas."""
        self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])

    def stop_profile(self):
        """Stop profiling on all rollout replicas."""
        self._run_all([replica.stop_profile() for replica in self.rollout_replicas])

    def _run_all(self, tasks: list[asyncio.Task]):
        async def run_all():
            await asyncio.gather(*tasks)
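
The manager's start_profile/stop_profile are synchronous but delegate to async replica calls through _run_all, whose body the diff truncates. A hedged sketch of that sync-to-async bridge, assuming a fresh event loop via asyncio.run; the helper and stub names here are illustrative, not verl's API:

import asyncio

def run_all_sketch(coros):
    # Gather all replica coroutines and drive them to completion from
    # synchronous code; assumes no event loop is already running.
    async def run_all():
        await asyncio.gather(*coros)

    asyncio.run(run_all())

async def fake_profile_call(name: str):
    print(f"{name}: profiler toggled")

run_all_sketch([fake_profile_call("replica-0"), fake_profile_call("replica-1")])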
26 changes: 18 additions & 8 deletions verl/trainer/ppo/ray_trainer.py
@@ -990,24 +990,34 @@ def _load_checkpoint(self):
    def _start_profiling(self, do_profile: bool) -> None:
        """Start profiling for all worker groups if profiling is enabled."""
        if do_profile:
            self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
            self.async_rollout_manager.enable_profiler_this_step()
            self.actor_rollout_wg.enable_profiler_this_step()
            self.actor_rollout_wg.start_profiler(role="e2e", profile_step=self.global_steps)
            if self.use_reference_policy:
                self.ref_policy_wg.start_profile(profile_step=self.global_steps)
                self.ref_policy_wg.enable_profiler_this_step()
                self.ref_policy_wg.start_profiler(profile_step=self.global_steps)
            if self.use_critic:
                self.critic_wg.start_profile(profile_step=self.global_steps)
                self.critic_wg.enable_profiler_this_step()
                self.critic_wg.start_profiler(profile_step=self.global_steps)
            if self.use_rm and not self.use_reward_loop:
                self.rm_wg.start_profile(profile_step=self.global_steps)
                self.rm_wg.enable_profiler_this_step()
                self.rm_wg.start_profiler(profile_step=self.global_steps)

    def _stop_profiling(self, do_profile: bool) -> None:
        """Stop profiling for all worker groups if profiling is enabled."""
        if do_profile:
            self.actor_rollout_wg.stop_profile()
            self.async_rollout_manager.disable_profiler_this_step()
            self.actor_rollout_wg.disable_profiler_this_step()
            self.actor_rollout_wg.stop_profiler()
            if self.use_reference_policy:
                self.ref_policy_wg.stop_profile()
                self.ref_policy_wg.disable_profiler_this_step()
                self.ref_policy_wg.stop_profiler()
            if self.use_critic:
                self.critic_wg.stop_profile()
                self.critic_wg.disable_profiler_this_step()
                self.critic_wg.stop_profiler()
            if self.use_rm and not self.use_reward_loop:
                self.rm_wg.stop_profile()
                self.rm_wg.disable_profiler_this_step()
                self.rm_wg.stop_profiler()

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
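
The trainer-side change splits profiling control into two phases: enable_profiler_this_step()/disable_profiler_this_step() arm or disarm a worker group for the current step, while start_profiler()/stop_profiler() actually drive the underlying tool. A minimal sketch of that split with a hypothetical worker-group stub (not verl's real WorkerGroup interface):

class WorkerGroupStub:
    """Hypothetical worker group showing the arm-then-start control flow."""

    def __init__(self, name: str):
        self.name = name
        self.armed = False

    def enable_profiler_this_step(self):
        self.armed = True

    def start_profiler(self, **kwargs):
        # Only an armed group actually starts its profiling tool.
        if self.armed:
            print(f"{self.name}: profiling started ({kwargs})")

    def disable_profiler_this_step(self):
        self.armed = False

    def stop_profiler(self):
        print(f"{self.name}: profiling stopped")

wg = WorkerGroupStub("actor_rollout")
wg.enable_profiler_this_step()
wg.start_profiler(role="e2e", profile_step=3)
wg.disable_profiler_this_step()
wg.stop_profiler()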
103 changes: 40 additions & 63 deletions verl/utils/profiler/mstx_profile.py
@@ -173,46 +173,32 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
            config = ProfilerConfig(ranks=[], enable=False)
        if not tool_config:
            assert not config.enable, "tool_config must be set when profiler is enabled"
        self.enable: bool = config.enable
        if not config.enable:
            return
        self.this_step: bool = False
        self.discrete: bool = tool_config.discrete
        self.this_rank: bool = False
        self.profile_npu = None
        self.profile_contents = tool_config.contents
        self.profile_level = tool_config.level
        self.profile_save_path = config.save_path
        self.analysis = tool_config.analysis
        if config.all_ranks:
            self.this_rank = True
        elif config.ranks:
            self.this_rank = rank in config.ranks

    def start(self, **kwargs):
        role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
        profile_step = str(profile_step) if profile_step is not None else None
        if self.enable and self.this_rank:
            self.this_step = True
            if not self.discrete and NPUProfiler._define_count == 0:
                self.profile_npu = get_npu_profiler(
                    contents=self.profile_contents,
                    profile_level=self.profile_level,
                    profile_save_path=self.profile_save_path,
                    analysis=self.analysis,
                    role=role,
                    profile_step=profile_step,
                )
                self.profile_npu.start()
                NPUProfiler._define_count += 1

    def stop(self):
        if self.enable and self.this_rank:
            self.this_step = False
            if not self.discrete and NPUProfiler._define_count == 1:
                self.profile_npu.step()
                self.profile_npu.stop()
                NPUProfiler._define_count -= 1

    def start_profiler(self, **kwargs):
        role = kwargs.get("role", None)
        if not self.discrete and NPUProfiler._define_count == 0:
            self.profile_npu = get_npu_profiler(
                contents=self.profile_contents,
                profile_level=self.profile_level,
                profile_save_path=self.profile_save_path,
                analysis=self.analysis,
                role=role,
            )
            self.profile_npu.start()
            NPUProfiler._define_count += 1

    def stop_profiler(self):
        if not self.discrete and NPUProfiler._define_count == 1:
            self.profile_npu.step()
            self.profile_npu.stop()
            NPUProfiler._define_count -= 1


    def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
        """Decorate a Worker member function to profile the current rank in the current training step.
@@ -230,42 +216,33 @@ def decorator(func):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs_inner):
                if not self.enable:
                    return func(*args, **kwargs_inner)

                profile_name = message or func.__name__
                discrete_mode = self.discrete
                profile_enable = self.this_step and self.enable

                if not profile_enable:
                    return func(*args, **kwargs_inner)

                if profile_enable:
                    if not discrete_mode:
                        mark_range = mark_start_range(message=profile_name)
                    else:
                        profile_npu = get_npu_profiler(
                            contents=self.profile_contents,
                            profile_level=self.profile_level,
                            profile_save_path=self.profile_save_path,
                            analysis=self.analysis,
                            role=role,
                        )
                        profile_npu.start()
                        mark_range = mark_start_range(message=profile_name)

                if not discrete_mode:
                    mark_range = mark_start_range(message=profile_name)
                else:
                    profile_npu = get_npu_profiler(
                        contents=self.profile_contents,
                        profile_level=self.profile_level,
                        profile_save_path=self.profile_save_path,
                        analysis=self.analysis,
                        role=role,
                    )
                    profile_npu.start()
                    mark_range = mark_start_range(message=profile_name)

                result = func(*args, **kwargs_inner)

                if profile_enable:
                    if not discrete_mode:
                        mark_end_range(mark_range)
                    else:
                        mark_end_range(mark_range)
                        profile_npu.step()
                        profile_npu.stop()
                if not discrete_mode:
                    mark_end_range(mark_range)
                else:
                    mark_end_range(mark_range)
                    profile_npu.step()
                    profile_npu.stop()

                return result

            return wrapper

        return decorator
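
In the non-discrete path above, NPUProfiler guards the shared session with the class-level counter _define_count: only the first start_profiler() call creates and starts a session (count 0 -> 1), and only a stop_profiler() that sees count == 1 tears it down. A runnable sketch of that guard, with FakeSession standing in for the object returned by get_npu_profiler and an extra None check added for safety:

class FakeSession:
    """Stand-in for the session object returned by get_npu_profiler."""

    def start(self):
        print("session started")

    def step(self):
        print("session step")

    def stop(self):
        print("session stopped")

class GuardedProfilerSketch:
    _define_count = 0  # class-level latch shared by all instances

    def __init__(self):
        self.session = None

    def start_profiler(self):
        # Only the first caller actually starts a session (count 0 -> 1).
        if GuardedProfilerSketch._define_count == 0:
            self.session = FakeSession()
            self.session.start()
            GuardedProfilerSketch._define_count += 1

    def stop_profiler(self):
        # Only the owner of the live session tears it down (count 1 -> 0);
        # the None check is an extra safety not present in the diff.
        if GuardedProfilerSketch._define_count == 1 and self.session is not None:
            self.session.step()
            self.session.stop()
            GuardedProfilerSketch._define_count -= 1

p, q = GuardedProfilerSketch(), GuardedProfilerSketch()
p.start_profiler()  # starts the single shared session
q.start_profiler()  # no-op: a session is already live
q.stop_profiler()   # no-op: q never created a session
p.stop_profiler()   # owner stops the session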
46 changes: 14 additions & 32 deletions verl/utils/profiler/nvtx_profile.py
@@ -126,28 +126,15 @@ def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Opt
            config = ProfilerConfig(ranks=[])
        if not tool_config:
            assert not config.enable, "tool_config must be provided when profiler is enabled"
        self.enable = config.enable
        if not config.enable:
            return
        self.this_step: bool = False
        self.discrete: bool = tool_config.discrete
        self.this_rank: bool = False
        if config.all_ranks:
            self.this_rank = True
        elif config.ranks:
            self.this_rank = rank in config.ranks

    def start(self, **kwargs):
        if self.enable and self.this_rank:
            self.this_step = True
            if not self.discrete:
                torch.cuda.profiler.start()

    def stop(self):
        if self.enable and self.this_rank:
            self.this_step = False
            if not self.discrete:
                torch.cuda.profiler.stop()

    def start_profiler(self, **kwargs):
        if not self.discrete:
            torch.cuda.profiler.start()

    def stop_profiler(self):
        if not self.discrete:
            torch.cuda.profiler.stop()

    def annotate(
        self,
@@ -176,22 +163,17 @@ def annotate(
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs_inner):
                if not self.enable:
                    return func(*args, **kwargs_inner)

                profile_name = message or func.__name__

                if self.this_step:
                    if self.discrete:
                        torch.cuda.profiler.start()
                    mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
                if self.discrete:
                    torch.cuda.profiler.start()
                mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)

                result = func(*args, **kwargs_inner)

                if self.this_step:
                    mark_end_range(mark_range)
                    if self.discrete:
                        torch.cuda.profiler.stop()
                mark_end_range(mark_range)
                if self.discrete:
                    torch.cuda.profiler.stop()

                return result

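
After the refactor, the NVTX wrapper keeps only the discrete/continuous split: continuous mode wraps the call in a single NVTX range, while discrete mode additionally opens and closes a CUDA profiler session around each annotated call. A self-contained sketch of that final wrapper shape; the mark/profiler functions are stubs standing in for verl's mark_start_range/mark_end_range and torch.cuda.profiler:

import functools

def mark_start_range(message):
    print(f"range start: {message}")
    return message

def mark_end_range(mark):
    print(f"range end: {mark}")

def profiler_start():
    print("cuda profiler start")

def profiler_stop():
    print("cuda profiler stop")

def annotate(message=None, discrete=False):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            profile_name = message or func.__name__
            if discrete:
                profiler_start()  # one profiler session per annotated call
            mark_range = mark_start_range(profile_name)

            result = func(*args, **kwargs)

            mark_end_range(mark_range)
            if discrete:
                profiler_stop()
            return result

        return wrapper

    return decorator

@annotate(message="train_step", discrete=True)
def train_step():
    return "ok"

train_step()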