diff --git a/recipe/fully_async_policy/agent_loop/agent_loop.py b/recipe/fully_async_policy/agent_loop/agent_loop.py
index d486579596f..a489f33e669 100644
--- a/recipe/fully_async_policy/agent_loop/agent_loop.py
+++ b/recipe/fully_async_policy/agent_loop/agent_loop.py
@@ -325,3 +325,9 @@ async def sleep(self):
 
     async def clear_kv_cache(self):
         await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
+
+    async def start_profile(self, **kwargs):
+        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    async def stop_profile(self):
+        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
diff --git a/recipe/one_step_off_policy/agent_loop/agent_loop.py b/recipe/one_step_off_policy/agent_loop/agent_loop.py
index 85455d655b2..c0e9d181d3e 100644
--- a/recipe/one_step_off_policy/agent_loop/agent_loop.py
+++ b/recipe/one_step_off_policy/agent_loop/agent_loop.py
@@ -62,3 +62,9 @@ async def sleep(self):
 
     async def clear_kv_cache(self):
         await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
+
+    async def start_profile(self, **kwargs):
+        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    async def stop_profile(self):
+        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
\ No newline at end of file
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 9d197ad7eaf..a1850afa390 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -115,6 +115,14 @@ async def generate(
         )
         return output
 
+    async def start_profiler(self, **kwargs):
+        for server in self.server_handles:
+            await server.start_profiler.remote(**kwargs)
+
+    async def stop_profiler(self):
+        for server in self.server_handles:
+            await server.stop_profiler.remote()
+
 
 class AgentLoopMetrics(BaseModel):
     """Agent loop performance metrics."""
@@ -580,6 +588,12 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
             extra_fields=output.extra_fields,
         )
 
+    async def start_profiler(self, **kwargs):
+        await self.server_manager.start_profiler(**kwargs)
+
+    async def stop_profiler(self):
+        await self.server_manager.stop_profiler()
+
     def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
        """Process the padded outputs from _run_agent_loop and combine them into a batch."""
        # Convert lists back to tensors and stack them to create a batch.
@@ -807,6 +821,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
 
         # Fix for Issue #4147: Always call wake_up() to ensure weight sync
         # The wake_up()/sleep() methods internally check free_cache_engine
+        ray.get([worker.start_profiler.remote() for worker in self.agent_loop_workers])
         self.wake_up()
         if self.reward_model_manager:
             self.reward_model_manager.wake_up()
@@ -829,6 +844,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
 
         timing = self._performance_metrics(metrics, output)
         output.meta_info = {"timing": timing, **outputs[0].meta_info}
+        ray.get([worker.stop_profiler.remote() for worker in self.agent_loop_workers])
         return output
 
     def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
@@ -865,6 +881,14 @@ def clear_kv_cache(self):
         """Clear all rollout kv cache, but don`t sleep."""
         self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
 
+    def start_profile(self, **kwargs):
+        """Start profiling on all rollout replicas."""
+        self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+
+    def stop_profile(self):
+        """Stop profiling on all rollout replicas."""
+        self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
+
     def _run_all(self, tasks: list[asyncio.Task]):
         async def run_all():
             await asyncio.gather(*tasks)
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 88bb6f241c0..0cf8dd84053 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -990,24 +990,34 @@ def _load_checkpoint(self):
     def _start_profiling(self, do_profile: bool) -> None:
         """Start profiling for all worker groups if profiling is enabled."""
         if do_profile:
-            self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
+            self.async_rollout_manager.enable_profiler_this_step()
+            self.actor_rollout_wg.enable_profiler_this_step()
+            self.actor_rollout_wg.start_profiler(role="e2e", profile_step=self.global_steps)
             if self.use_reference_policy:
-                self.ref_policy_wg.start_profile(profile_step=self.global_steps)
+                self.ref_policy_wg.enable_profiler_this_step()
+                self.ref_policy_wg.start_profiler(profile_step=self.global_steps)
             if self.use_critic:
-                self.critic_wg.start_profile(profile_step=self.global_steps)
+                self.critic_wg.enable_profiler_this_step()
+                self.critic_wg.start_profiler(profile_step=self.global_steps)
             if self.use_rm and not self.use_reward_loop:
-                self.rm_wg.start_profile(profile_step=self.global_steps)
+                self.rm_wg.enable_profiler_this_step()
+                self.rm_wg.start_profiler(profile_step=self.global_steps)
 
     def _stop_profiling(self, do_profile: bool) -> None:
         """Stop profiling for all worker groups if profiling is enabled."""
         if do_profile:
-            self.actor_rollout_wg.stop_profile()
+            self.actor_rollout_wg.stop_profiler()
+            self.actor_rollout_wg.disable_profiler_this_step()
+            self.async_rollout_manager.disable_profiler_this_step()
             if self.use_reference_policy:
-                self.ref_policy_wg.stop_profile()
+                self.ref_policy_wg.stop_profiler()
+                self.ref_policy_wg.disable_profiler_this_step()
             if self.use_critic:
-                self.critic_wg.stop_profile()
+                self.critic_wg.stop_profiler()
+                self.critic_wg.disable_profiler_this_step()
             if self.use_rm and not self.use_reward_loop:
-                self.rm_wg.stop_profile()
+                self.rm_wg.stop_profiler()
+                self.rm_wg.disable_profiler_this_step()
 
     def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
         """Reorder the data on single controller such that each dp rank gets similar total tokens"""
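For orientation, the per-step call pattern the trainer now follows, reduced to a sketch. The driver function and run_one_step are illustrative and not part of this change; the worker-group method names are the ones registered in profile.py below.

def run_step_with_profiling(worker_group, do_profile, step, run_one_step):
    # Illustrative driver only: arm the per-step gate, start the profiler,
    # run the step, then stop *before* disarming so the stop call is not gated off.
    if do_profile:
        worker_group.enable_profiler_this_step()
        worker_group.start_profiler(role="e2e", profile_step=step)
    result = run_one_step()
    if do_profile:
        worker_group.stop_profiler()
        worker_group.disable_profiler_this_step()
    return result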
diff --git a/verl/utils/profiler/mstx_profile.py b/verl/utils/profiler/mstx_profile.py
index b9576714248..1bef26e18a8 100644
--- a/verl/utils/profiler/mstx_profile.py
+++ b/verl/utils/profiler/mstx_profile.py
@@ -173,46 +173,32 @@ def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig
             config = ProfilerConfig(ranks=[], enable=False)
         if not tool_config:
             assert not config.enable, "tool_config must be set when profiler is enabled"
-        self.enable: bool = config.enable
-        if not config.enable:
-            return
-        self.this_step: bool = False
         self.discrete: bool = tool_config.discrete
-        self.this_rank: bool = False
         self.profile_npu = None
         self.profile_contents = tool_config.contents
         self.profile_level = tool_config.level
         self.profile_save_path = config.save_path
         self.analysis = tool_config.analysis
-        if config.all_ranks:
-            self.this_rank = True
-        elif config.ranks:
-            self.this_rank = rank in config.ranks
-
-    def start(self, **kwargs):
-        role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None)
-        profile_step = str(profile_step) if profile_step is not None else None
-        if self.enable and self.this_rank:
-            self.this_step = True
-            if not self.discrete and NPUProfiler._define_count == 0:
-                self.profile_npu = get_npu_profiler(
-                    contents=self.profile_contents,
-                    profile_level=self.profile_level,
-                    profile_save_path=self.profile_save_path,
-                    analysis=self.analysis,
-                    role=role,
-                    profile_step=profile_step,
-                )
-                self.profile_npu.start()
-                NPUProfiler._define_count += 1
-
-    def stop(self):
-        if self.enable and self.this_rank:
-            self.this_step = False
-            if not self.discrete and NPUProfiler._define_count == 1:
-                self.profile_npu.step()
-                self.profile_npu.stop()
-                NPUProfiler._define_count -= 1
+
+    def start_profiler(self, **kwargs):
+        role = kwargs.get("role", None)
+        if not self.discrete and NPUProfiler._define_count == 0:
+            self.profile_npu = get_npu_profiler(
+                contents=self.profile_contents,
+                profile_level=self.profile_level,
+                profile_save_path=self.profile_save_path,
+                analysis=self.analysis,
+                role=role,
+            )
+            self.profile_npu.start()
+            NPUProfiler._define_count += 1
+
+    def stop_profiler(self):
+        if not self.discrete and NPUProfiler._define_count == 1:
+            self.profile_npu.step()
+            self.profile_npu.stop()
+            NPUProfiler._define_count -= 1
+
     def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
         """Decorate a Worker member function to profile the current rank in the current training step.
@@ -230,42 +216,33 @@ def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **
         def decorator(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs_inner):
-                if not self.enable:
-                    return func(*args, **kwargs_inner)
-
                 profile_name = message or func.__name__
                 discrete_mode = self.discrete
-                profile_enable = self.this_step and self.enable
-
-                if not profile_enable:
-                    return func(*args, **kwargs_inner)
-
-                if profile_enable:
-                    if not discrete_mode:
-                        mark_range = mark_start_range(message=profile_name)
-                    else:
-                        profile_npu = get_npu_profiler(
-                            contents=self.profile_contents,
-                            profile_level=self.profile_level,
-                            profile_save_path=self.profile_save_path,
-                            analysis=self.analysis,
-                            role=role,
-                        )
-                        profile_npu.start()
-                        mark_range = mark_start_range(message=profile_name)
+
+                if not discrete_mode:
+                    mark_range = mark_start_range(message=profile_name)
+                else:
+                    profile_npu = get_npu_profiler(
+                        contents=self.profile_contents,
+                        profile_level=self.profile_level,
+                        profile_save_path=self.profile_save_path,
+                        analysis=self.analysis,
+                        role=role,
+                    )
+                    profile_npu.start()
+                    mark_range = mark_start_range(message=profile_name)
 
                 result = func(*args, **kwargs_inner)
 
-                if profile_enable:
-                    if not discrete_mode:
-                        mark_end_range(mark_range)
-                    else:
-                        mark_end_range(mark_range)
-                        profile_npu.step()
-                        profile_npu.stop()
+                if not discrete_mode:
+                    mark_end_range(mark_range)
+                else:
+                    mark_end_range(mark_range)
+                    profile_npu.step()
+                    profile_npu.stop()
                 return result
 
             return wrapper
 
-        return decorator
+        return decorator
\ No newline at end of file
diff --git a/verl/utils/profiler/nvtx_profile.py b/verl/utils/profiler/nvtx_profile.py
index 35857498c03..da5fe9fdd81 100644
--- a/verl/utils/profiler/nvtx_profile.py
+++ b/verl/utils/profiler/nvtx_profile.py
@@ -126,28 +126,15 @@ def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Opt
             config = ProfilerConfig(ranks=[])
         if not tool_config:
             assert not config.enable, "tool_config must be provided when profiler is enabled"
-        self.enable = config.enable
-        if not config.enable:
-            return
-        self.this_step: bool = False
         self.discrete: bool = tool_config.discrete
-        self.this_rank: bool = False
-        if config.all_ranks:
-            self.this_rank = True
-        elif config.ranks:
-            self.this_rank = rank in config.ranks
-
-    def start(self, **kwargs):
-        if self.enable and self.this_rank:
-            self.this_step = True
-            if not self.discrete:
-                torch.cuda.profiler.start()
-
-    def stop(self):
-        if self.enable and self.this_rank:
-            self.this_step = False
-            if not self.discrete:
-                torch.cuda.profiler.stop()
+
+    def start_profiler(self, **kwargs):
+        if not self.discrete:
+            torch.cuda.profiler.start()
+
+    def stop_profiler(self):
+        if not self.discrete:
+            torch.cuda.profiler.stop()
 
     def annotate(
         self,
@@ -176,22 +163,17 @@ def annotate(
         def decorator(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs_inner):
-                if not self.enable:
-                    return func(*args, **kwargs_inner)
-
                 profile_name = message or func.__name__
 
-                if self.this_step:
-                    if self.discrete:
-                        torch.cuda.profiler.start()
-                    mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
+                if self.discrete:
+                    torch.cuda.profiler.start()
+                mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)
 
                 result = func(*args, **kwargs_inner)
 
-                if self.this_step:
-                    mark_end_range(mark_range)
-                    if self.discrete:
-                        torch.cuda.profiler.stop()
+                mark_end_range(mark_range)
+                if self.discrete:
+                    torch.cuda.profiler.stop()
 
                 return result
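With the rank/step checks stripped out of the tool-level profilers, the simplified NVTX wrapper above reduces to roughly the following standalone sketch. Torch's nvtx range_push/range_pop stand in for verl's mark_start_range/mark_end_range helpers, and the enable/rank/step gating is assumed to have already happened in the DistProfiler wrapper.

import torch

def profiled_call(func, discrete, *args, **kwargs):
    # Approximate equivalent of the rewritten wrapper body (not verl code):
    # in discrete mode each decorated call gets its own capture window,
    # otherwise only NVTX ranges are emitted.
    if discrete:
        torch.cuda.profiler.start()
    torch.cuda.nvtx.range_push(func.__name__)
    try:
        return func(*args, **kwargs)
    finally:
        torch.cuda.nvtx.range_pop()
        if discrete:
            torch.cuda.profiler.stop()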
diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py
index a5aabcbc8ef..c14c4df78a0 100644
--- a/verl/utils/profiler/profile.py
+++ b/verl/utils/profiler/profile.py
@@ -192,6 +192,7 @@ def __init__(
 
         self._impl = None
         self._tool = getattr(config, "tool", None)
+        self._enable = config.enable
 
         # Normalize rank selection
         self._this_rank = False
@@ -201,7 +202,9 @@ def __init__(
             self._this_rank = rank in config.ranks
         else:
             # default rank 0 if enabled but ranks unspecified
-            self._this_rank = (rank == 0) if config.enable else False
+            self._this_rank = (rank == 0) if self._enable else False
+
+        self._discrete = getattr(tool_config, "discrete", False)
         # Profiler and TorchMemoryProfiler currently do not support discrete mode.
         # Lazy import to avoid circular deps
         if self._tool == "nsys":
@@ -221,11 +224,25 @@ def __init__(
             # Fallback to a no-op impl
             self._impl = _NoOpProfiler()
 
-    def start(self, **kwargs):
-        return getattr(self._impl, "start", lambda **_: None)(**kwargs)
+    def enable_profiler_this_step(self):
+        self._this_step = True
 
-    def stop(self):
-        return getattr(self._impl, "stop", lambda: None)()
+    def disable_profiler_this_step(self):
+        self._this_step = False
+
+    def check_enable(self):
+        return self._enable and self._this_rank and getattr(self, "_this_step", False)
+
+    def is_discrete_mode(self):
+        return self._discrete
+
+    def start_profiler(self, **kwargs):
+        if self.check_enable():
+            return getattr(self._impl, "start_profiler", lambda **_: None)(**kwargs)
+
+    def stop_profiler(self):
+        if self.check_enable():
+            return getattr(self._impl, "stop_profiler", lambda: None)()
 
     @classmethod
     def annotate(
@@ -240,7 +257,7 @@ def decorator(func):
             @functools.wraps(func)
             def wrapper(self_instance, *args, **kwargs_inner):
                 profiler = getattr(self_instance, "profiler", None)
-                if not profiler:
+                if not profiler or not profiler.check_enable():
                     return func(self_instance, *args, **kwargs_inner)
 
                 impl = profiler._impl
@@ -361,11 +378,21 @@ def __init__(self, profiler: DistProfiler):
     from verl.single_controller.base.decorator import Dispatch, register
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def start_profile(self, **kwargs) -> None:
+    def start_profiler(self, **kwargs) -> None:
         """Start profiling for the current rank in the current training step."""
-        self.profiler.start(**kwargs)
+        self.profiler.start_profiler(**kwargs)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def stop_profile(self) -> None:
+    def stop_profiler(self) -> None:
         """Stop profiling for the current rank in the current training step."""
-        self.profiler.stop()
+        self.profiler.stop_profiler()
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def enable_profiler_this_step(self) -> None:
+        """Enable profiling for the current rank in the current training step."""
+        self.profiler.enable_profiler_this_step()
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def disable_profiler_this_step(self) -> None:
+        """Disable profiling for the current rank in the current training step."""
+        self.profiler.disable_profiler_this_step()
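The division of labor after the profile.py changes, shown as a minimal self-contained toy (these classes are illustrative, not verl code): the wrapper decides whether to profile via check_enable(), the tool-level profiler only decides how.

class ToyToolProfiler:
    # Stand-in for the nsys/NPU tool profilers after this change: unconditional.
    def start_profiler(self, **kwargs):
        print("tool start", kwargs)

    def stop_profiler(self):
        print("tool stop")


class ToyDistProfiler:
    # Mirrors the gating added to DistProfiler: enable AND this_rank AND this_step.
    def __init__(self, enable, this_rank):
        self._enable, self._this_rank, self._this_step = enable, this_rank, False
        self._impl = ToyToolProfiler()

    def enable_profiler_this_step(self):
        self._this_step = True

    def disable_profiler_this_step(self):
        self._this_step = False

    def check_enable(self):
        return self._enable and self._this_rank and self._this_step

    def start_profiler(self, **kwargs):
        if self.check_enable():
            self._impl.start_profiler(**kwargs)

    def stop_profiler(self):
        if self.check_enable():
            self._impl.stop_profiler()


p = ToyDistProfiler(enable=True, this_rank=True)
p.start_profiler(role="e2e")   # no-op: the step is not armed yet
p.enable_profiler_this_step()
p.start_profiler(role="e2e")   # tool start {'role': 'e2e'}
p.stop_profiler()              # tool stop
p.disable_profiler_this_step()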
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index b7d89134d72..eaff8bf01a5 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -1110,16 +1110,6 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(self.actor_optimizer)
 
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def start_profile(self, **kwargs) -> None:
-        """Start profiling for the current rank in the current training step."""
-        self.profiler.start(**kwargs)
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def stop_profile(self) -> None:
-        """Stop profiling for the current rank in the current training step."""
-        self.profiler.stop()
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None:
         """Manually trigger a CUDA memory snapshot dump on all ranks."""
@@ -1943,6 +1933,16 @@ async def sleep(self):
         await self.trainer_mode()
         return True
 
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    async def start_async_rollout_profile(self, **kwargs):
+        """Start an async rollout profiling segment."""
+        await self.rollout.start_async_rollout_profile(**kwargs)
+
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    async def stop_async_rollout_profile(self):
+        """Stop the async rollout profiling segment."""
+        await self.rollout.stop_async_rollout_profile()
+
     # ============================ vLLM related ============================
 
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index db2e3fb1b97..602d9aa0714 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -919,16 +919,6 @@ def async_calls_finalize_fn_exec(self, blocking=False):
 
         async_calls.maybe_finalize_async_calls(blocking=blocking)
 
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def start_profile(self, **kwargs) -> None:
-        """Start profiling for the current rank in the current training step."""
-        self.profiler.start(**kwargs)
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def stop_profile(self) -> None:
-        """Stop profiling for the current rank in the current training step."""
-        self.profiler.stop()
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None:
         """Manually trigger a CUDA memory snapshot dump on all ranks."""
@@ -956,6 +946,16 @@ async def sleep(self):
         await self.trainer_mode()
         return True
 
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    async def start_async_rollout_profile(self, **kwargs):
+        """Start an async rollout profiling segment."""
+        await self.rollout.start_async_rollout_profile(**kwargs)
+
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    async def stop_async_rollout_profile(self):
+        """Stop the async rollout profiling segment."""
+        await self.rollout.stop_async_rollout_profile()
+
     # ============================ vLLM related ============================
 
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index f8f6f1084ad..51084b961b9 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -207,6 +207,14 @@ async def clear_kv_cache(self):
         """reset kv cache in each rollout server."""
         await asyncio.gather(*[server.clear_kv_cache.remote() for server in self.servers])
 
+    async def start_profile(self, **kwargs):
+        """Start profiling on all workers."""
+        await asyncio.gather(*[worker.start_async_rollout_profile.remote(**kwargs) for worker in self.workers])
+
+    async def stop_profile(self):
+        """Stop profiling on all workers."""
+        await asyncio.gather(*[worker.stop_async_rollout_profile.remote() for worker in self.workers])
+
 
 class RolloutReplicaRegistry:
     """Factory for managing rollout replica implementations."""
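Taken together, the replica-level start_profile/stop_profile and the worker-level start/stop_async_rollout_profile let a caller bracket a single rollout segment. A usage sketch follows; replica construction and the generation coroutine are placeholders, the method names are the ones added above.

import asyncio

async def profile_rollout_segment(rollout_replicas, run_generation):
    # Fan the profiler start out to every rollout replica, run one generation
    # pass, and stop the profilers even if generation raises.
    await asyncio.gather(*[replica.start_profile() for replica in rollout_replicas])
    try:
        return await run_generation()
    finally:
        await asyncio.gather(*[replica.stop_profile() for replica in rollout_replicas])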
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index e78700d9f7a..e96120836c5 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -287,6 +287,11 @@ async def generate(
             log_probs = None
         return TokenOutput(token_ids=token_ids, log_probs=log_probs)
 
+    async def start_profiler(self, **kwargs):
+        await self.tokenizer_manager.start_profiler(**kwargs)
+
+    async def stop_profiler(self):
+        await self.tokenizer_manager.stop_profiler()
 
 
 _rollout_worker_actor_cls = ray.remote(ServerAdapter)
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index e0c292b8397..b3a12d00656 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -518,6 +518,12 @@ async def sleep(self):
         elif self.rollout_mode == RolloutMode.STANDALONE:
             logger.info("skip sleep in standalone mode")
 
+    async def start_profiler(self, **kwargs):
+        await self.engine.start_profiler(**kwargs)
+
+    async def stop_profiler(self):
+        await self.engine.stop_profiler()
+
     async def clear_kv_cache(self):
         if self.node_rank == 0:
             await self.engine.reset_prefix_cache()
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 42a1cd96885..9f6cd3cb468 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -57,6 +57,7 @@
 from verl.utils.device import is_npu_available
 from verl.utils.distributed import initialize_global_process_group_ray
 from verl.utils.ray_utils import ray_noset_visible_devices
+from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge
 from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights
 from verl.workers.config import HFModelConfig, RolloutConfig
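For orientation, the call chains these hunks wire up, sketched as comments. Method names are taken from the hunks above; class names are inferred from the files touched, and the backend hooks are whatever the SGLang tokenizer_manager and vLLM engine wrapper expose.

# Trainer side, gated per rank/step by DistProfiler.check_enable():
#   ray_trainer._start_profiling / _stop_profiling
#     -> worker_group.enable_profiler_this_step() / start_profiler() / stop_profiler()
#       -> DistProfiler -> nsys or NPU (mstx) tool profiler
#
# Rollout side, bracketing generate_sequences():
#   AgentLoopManager.generate_sequences
#     -> ray.get(worker.start_profiler.remote())        # agent loop workers
#       -> server_manager.start_profiler()
#         -> server.start_profiler.remote()             # SGLang / vLLM server adapter
#           -> tokenizer_manager.start_profiler / engine.start_profiler
#
# Replica side (recipes and AgentLoopManager.start_profile / stop_profile):
#   replica.start_profile -> worker.start_async_rollout_profile.remote -> rollout backend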