From 72a85255a5d39900851b6af3b2af5e66350c8f9b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 29 Oct 2025 22:58:29 +0800 Subject: [PATCH 1/3] cp --- slime/backends/fsdp_utils/actor.py | 52 ++++++++++++++------------ slime/backends/megatron_utils/actor.py | 13 ++----- slime/utils/context_utils.py | 15 ++++++++ slime/utils/timer.py | 20 ++++++++-- 4 files changed, 63 insertions(+), 37 deletions(-) create mode 100644 slime/utils/context_utils.py diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index ad28d5462..8e4dd81ac 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -14,12 +14,13 @@ from slime.ray.train_actor import TrainRayActor from slime.utils import profile_utils +from slime.utils.context_utils import with_defer from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data from slime.utils.distributed_utils import get_gloo_group from slime.utils.memory_utils import clear_memory, print_memory from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss from slime.utils.ray_utils import Box -from slime.utils.timer import Timer, timer +from slime.utils.timer import Timer, inverse_timer, timer from slime.utils.wandb_utils import init_wandb_secondary from .data_packing import pack_sequences, unpack_sequences @@ -40,6 +41,7 @@ class FSDPTrainRayActor(TrainRayActor): * For small models this is fine; for larger models consider sharded state_dict type. """ + @with_defer(lambda: Timer().start("train_wait")) def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int: # type: ignore[override] super().init(args, role, wandb_run_id, with_ref) @@ -142,11 +144,11 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F if self.args.offload_train: self.sleep() - Timer().start("train_wait") self.global_step = 0 self.micro_step = 0 return 0 + @timer def sleep(self) -> None: """Pause CUDA memory for all tracked tensors.""" if not self.args.offload_train: @@ -175,6 +177,7 @@ def sleep(self) -> None: dist.barrier(group=get_gloo_group()) print_memory("after offload model") + @timer def wake_up(self) -> None: """Resume CUDA memory for all tracked tensors.""" if not self.args.offload_train: @@ -349,11 +352,20 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: `rollout_log_probs`, etc.). It will be fetched and partitioned by `process_rollout_data` based on data-parallel rank/size. """ - Timer().end("train_wait") - if self.args.offload_train: self.wake_up() + with inverse_timer("train_wait"), timer("train"): + self._train_core(rollout_id=rollout_id, rollout_data_ref=rollout_data_ref) + + if ( + self.args.record_memory_history + and ((s := self.args.memory_snapshot_num_steps) is not None) + and (rollout_id == s - 1) + ): + profile_utils.dump_snapshot_and_stop(profile_utils.get_memory_snapshot_full_path(self.args)) + + def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None: world_size = dist.get_world_size() rank = dist.get_rank() @@ -408,16 +420,17 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: ) wandb.log(log_dict) - reported_accum: dict[str, list[torch.Tensor]] = {} - self.optimizer.zero_grad(set_to_none=True) - for mbs_id, packed_batch in enumerate(packed_batches): - self._train_step( - packed_batch=packed_batch, - world_size=world_size, - reported_accum=reported_accum, - mbs_id=mbs_id, - grad_accum=grad_accum, - ) + with timer("actor_train"): + reported_accum: dict[str, list[torch.Tensor]] = {} + self.optimizer.zero_grad(set_to_none=True) + for mbs_id, packed_batch in enumerate(packed_batches): + self._train_step( + packed_batch=packed_batch, + world_size=world_size, + reported_accum=reported_accum, + mbs_id=mbs_id, + grad_accum=grad_accum, + ) self.update_cpu_params_dict(self.weights["actor"]) @@ -431,16 +444,6 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: print(f"Updating ref model at rollout_id {rollout_id}") self.update_cpu_params_dict(self.weights["ref"]) - if ( - self.args.record_memory_history - and ((s := self.args.memory_snapshot_num_steps) is not None) - and (rollout_id == s - 1) - ): - profile_utils.dump_snapshot_and_stop(profile_utils.get_memory_snapshot_full_path(self.args)) - - Timer().start("train_wait") - return - def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_accum): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): logits = self.model( @@ -600,6 +603,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc wandb.log(log_dict) self.global_step += 1 + @timer def update_weights(self) -> None: # type: ignore[override] """Synchronize actor weights to rollout engines. diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index cacfc2d60..ddfc008d1 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -15,13 +15,14 @@ from transformers import AutoConfig, AutoTokenizer from slime.ray.train_actor import TrainRayActor +from slime.utils.context_utils import with_defer from slime.utils.data import process_rollout_data from slime.utils.distributed_utils import get_gloo_group, init_process_group from slime.utils.memory_utils import clear_memory, print_memory from slime.utils.ray_utils import Box from slime.utils.reloadable_process_group import destroy_process_groups, monkey_patch_torch_dist, reload_process_groups from slime.utils.routing_replay import RoutingReplay -from slime.utils.timer import Timer, timer +from slime.utils.timer import Timer, inverse_timer, timer from slime.utils.types import RolloutBatch from slime.utils.wandb_utils import init_wandb_secondary @@ -35,6 +36,7 @@ class MegatronTrainRayActor(TrainRayActor): + @with_defer(lambda: Timer().start("train_wait")) def init( self, args: Namespace, @@ -59,7 +61,6 @@ def init( dist.barrier(group=get_gloo_group()) if self.args.debug_rollout_only: - Timer().start("train_wait") return 0 if role == "critic": @@ -75,7 +76,6 @@ def init( if role == "critic": if self.args.offload_train: self.sleep() - Timer().start("train_wait") return start_rollout_id = loaded_rollout_id + 1 @@ -235,8 +235,6 @@ def compute_log_prob( ) def train(self, rollout_id: int, rollout_data_ref: Box) -> None: - Timer().end("train_wait") - if self.args.offload_train: self.wake_up() @@ -244,7 +242,6 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: rollout_data = self._get_rollout_data(rollout_data_ref) if self.args.debug_rollout_only: log_rollout_data(rollout_id, self.args, rollout_data) - Timer().start("train_wait") return if self.role == "critic": @@ -279,13 +276,12 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None: data_iterator, num_microbatches, ) - Timer().start("train_wait") def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: # Create data iterator for log_probs and train. data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data) - with timer("train"): + with inverse_timer("train_wait"), timer("train"): if self.args.compute_advantages_and_returns: if "ref" in self.weights: if self.args.use_routing_replay: @@ -389,7 +385,6 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: self.update_cpu_params_dict(self.weights["ref"]) log_perf_data(rollout_id, self.args) - Timer().start("train_wait") def save_model(self, iteration: int) -> None: if self.args.debug_rollout_only: diff --git a/slime/utils/context_utils.py b/slime/utils/context_utils.py new file mode 100644 index 000000000..02307c40d --- /dev/null +++ b/slime/utils/context_utils.py @@ -0,0 +1,15 @@ +from functools import wraps + + +def with_defer(deferred_func): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + finally: + deferred_func() + + return wrapper + + return decorator diff --git a/slime/utils/timer.py b/slime/utils/timer.py index 2be3dc45c..03b7f6dd7 100644 --- a/slime/utils/timer.py +++ b/slime/utils/timer.py @@ -2,6 +2,8 @@ from functools import wraps from time import time +import torch.distributed + from .misc import SingletonMeta __all__ = ["Timer", "timer"] @@ -15,12 +17,16 @@ def __init__(self): def start(self, name): assert name not in self.start_time, f"Timer {name} already started." self.start_time[name] = time() + if torch.distributed.get_rank() == 0: + print(f"Timer {name} start") def end(self, name): assert name in self.start_time, f"Timer {name} not started." elapsed_time = time() - self.start_time[name] self.add(name, elapsed_time) del self.start_time[name] + if torch.distributed.get_rank() == 0: + print(f"Timer {name} end (elapsed: {elapsed_time:.1f}s)") def reset(self, name=None): if name is None: @@ -29,10 +35,7 @@ def reset(self, name=None): del self.timers[name] def add(self, name, elapsed_time): - if name not in self.timers: - self.timers[name] = elapsed_time - else: - self.timers[name] += elapsed_time + self.timers[name] = self.timers.get(name, 0) + elapsed_time def log_dict(self): return self.timers @@ -72,3 +75,12 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + + +@contextmanager +def inverse_timer(name): + Timer().end(name) + try: + yield + finally: + Timer().start(name) From 16d5a02eb6ee8d65a246d566e1eca294a915c944 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:14:40 +0800 Subject: [PATCH 2/3] Update actor.py --- slime/backends/fsdp_utils/actor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 945bf8f2d..439d6166e 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -156,7 +156,6 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F if self.args.offload_train: self.sleep() - Timer().start("train_wait") return int(getattr(self.args, "start_rollout_id", 0)) @timer From d196e03e4c3cddbe662eb958d4aac7f084883c09 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:21:18 +0800 Subject: [PATCH 3/3] Update actor.py --- slime/backends/megatron_utils/actor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 80ef2e46d..1d8f660b2 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -122,7 +122,6 @@ def init( self.prof = TrainProfiler(args) - Timer().start("train_wait") return start_rollout_id @torch.no_grad()