Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

from slime.ray.train_actor import TrainRayActor
from slime.utils import profile_utils
from slime.utils.context_utils import with_defer
from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data
from slime.utils.distributed_utils import get_gloo_group
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss
from slime.utils.ray_utils import Box
from slime.utils.timer import Timer, timer
from slime.utils.timer import Timer, inverse_timer, timer
from slime.utils.wandb_utils import init_wandb_secondary

from . import checkpoint
Expand All @@ -41,6 +42,7 @@ class FSDPTrainRayActor(TrainRayActor):
* For small models this is fine; for larger models consider sharded state_dict type.
"""

@with_defer(lambda: Timer().start("train_wait"))
def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int: # type: ignore[override]
super().init(args, role, wandb_run_id, with_ref)

Expand Down Expand Up @@ -154,9 +156,9 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
if self.args.offload_train:
self.sleep()

Timer().start("train_wait")
return int(getattr(self.args, "start_rollout_id", 0))

@timer
def sleep(self) -> None:
"""Pause CUDA memory for all tracked tensors."""
if not self.args.offload_train:
Expand Down Expand Up @@ -185,6 +187,7 @@ def sleep(self) -> None:
dist.barrier(group=get_gloo_group())
print_memory("after offload model")

@timer
def wake_up(self) -> None:
"""Resume CUDA memory for all tracked tensors."""
if not self.args.offload_train:
Expand Down Expand Up @@ -354,11 +357,20 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
`rollout_log_probs`, etc.). It will be fetched and partitioned
by `process_rollout_data` based on data-parallel rank/size.
"""
Timer().end("train_wait")

if self.args.offload_train:
self.wake_up()

with inverse_timer("train_wait"), timer("train"):
self._train_core(rollout_id=rollout_id, rollout_data_ref=rollout_data_ref)

if (
self.args.record_memory_history
and ((s := self.args.memory_snapshot_num_steps) is not None)
and (rollout_id == s - 1)
):
profile_utils.dump_snapshot_and_stop(profile_utils.get_memory_snapshot_full_path(self.args))

def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None:
world_size = dist.get_world_size()
rank = dist.get_rank()

Expand Down Expand Up @@ -413,16 +425,17 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
)
wandb.log(log_dict)

reported_accum: dict[str, list[torch.Tensor]] = {}
self.optimizer.zero_grad(set_to_none=True)
for mbs_id, packed_batch in enumerate(packed_batches):
self._train_step(
packed_batch=packed_batch,
world_size=world_size,
reported_accum=reported_accum,
mbs_id=mbs_id,
grad_accum=grad_accum,
)
with timer("actor_train"):
reported_accum: dict[str, list[torch.Tensor]] = {}
self.optimizer.zero_grad(set_to_none=True)
for mbs_id, packed_batch in enumerate(packed_batches):
self._train_step(
packed_batch=packed_batch,
world_size=world_size,
reported_accum=reported_accum,
mbs_id=mbs_id,
grad_accum=grad_accum,
)

self.update_cpu_params_dict(self.weights["actor"])

Expand All @@ -436,16 +449,6 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
print(f"Updating ref model at rollout_id {rollout_id}")
self.update_cpu_params_dict(self.weights["ref"])

if (
self.args.record_memory_history
and ((s := self.args.memory_snapshot_num_steps) is not None)
and (rollout_id == s - 1)
):
profile_utils.dump_snapshot_and_stop(profile_utils.get_memory_snapshot_full_path(self.args))

Timer().start("train_wait")
return

def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_accum):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = self.model(
Expand Down Expand Up @@ -605,6 +608,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
wandb.log(log_dict)
self.global_step += 1

@timer
def update_weights(self) -> None: # type: ignore[override]
"""Synchronize actor weights to rollout engines.

Expand Down
14 changes: 4 additions & 10 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from transformers import AutoConfig, AutoTokenizer

from slime.ray.train_actor import TrainRayActor
from slime.utils.context_utils import with_defer
from slime.utils.data import process_rollout_data
from slime.utils.distributed_utils import get_gloo_group, init_process_group
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.ray_utils import Box
from slime.utils.reloadable_process_group import destroy_process_groups, monkey_patch_torch_dist, reload_process_groups
from slime.utils.routing_replay import RoutingReplay
from slime.utils.timer import Timer, timer
from slime.utils.timer import Timer, inverse_timer, timer
from slime.utils.types import RolloutBatch
from slime.utils.wandb_utils import init_wandb_secondary

Expand All @@ -36,6 +37,7 @@


class MegatronTrainRayActor(TrainRayActor):
@with_defer(lambda: Timer().start("train_wait"))
def init(
self,
args: Namespace,
Expand All @@ -60,7 +62,6 @@ def init(
dist.barrier(group=get_gloo_group())

if self.args.debug_rollout_only:
Timer().start("train_wait")
return 0

if role == "critic":
Expand All @@ -76,7 +77,6 @@ def init(
if role == "critic":
if self.args.offload_train:
self.sleep()
Timer().start("train_wait")
return

start_rollout_id = loaded_rollout_id + 1
Expand Down Expand Up @@ -122,7 +122,6 @@ def init(

self.prof = TrainProfiler(args)

Timer().start("train_wait")
return start_rollout_id

@torch.no_grad()
Expand Down Expand Up @@ -221,16 +220,13 @@ def compute_log_prob(
)

def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
Timer().end("train_wait")

if self.args.offload_train:
self.wake_up()

with timer("data_preprocess"):
rollout_data = self._get_rollout_data(rollout_data_ref)
if self.args.debug_rollout_only:
log_rollout_data(rollout_id, self.args, rollout_data)
Timer().start("train_wait")
return

if self.role == "critic":
Expand Down Expand Up @@ -265,13 +261,12 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None:
data_iterator,
num_microbatches,
)
Timer().start("train_wait")

def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None:
# Create data iterator for log_probs and train.
data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)

with timer("train"):
with inverse_timer("train_wait"), timer("train"):
if self.args.compute_advantages_and_returns:
if "ref" in self.weights:
if self.args.use_routing_replay:
Expand Down Expand Up @@ -366,7 +361,6 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None:
self.update_cpu_params_dict(self.weights["ref"])

log_perf_data(rollout_id, self.args)
Timer().start("train_wait")

def save_model(self, iteration: int) -> None:
if self.args.debug_rollout_only:
Expand Down
15 changes: 15 additions & 0 deletions slime/utils/context_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from functools import wraps


def with_defer(deferred_func):
    """Decorator factory: run ``deferred_func()`` after every call to the
    wrapped function, whether it returned normally or raised.

    Analogous to Go's ``defer`` — the deferred callable executes on every
    exit path, and any exception from the wrapped function still propagates.
    """

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                result = fn(*args, **kwargs)
            finally:
                # Runs on both the return path and the exception path.
                deferred_func()
            return result

        return wrapper

    return decorator
20 changes: 16 additions & 4 deletions slime/utils/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import wraps
from time import time

import torch.distributed

from .misc import SingletonMeta

__all__ = ["Timer", "timer"]
Expand All @@ -15,12 +17,16 @@ def __init__(self):
def start(self, name):
    """Begin timing the section *name*; asserts it is not already running."""
    assert name not in self.start_time, f"Timer {name} already started."
    self.start_time[name] = time()
    # Log only on rank 0 so the marker appears once, not N times across workers.
    # Guard with is_initialized(): get_rank() raises RuntimeError when the
    # default process group is not set up (e.g. single-process runs/tests).
    if (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0:
        print(f"Timer {name} start")

def end(self, name):
    """Stop timing the section *name* and accumulate the elapsed seconds."""
    assert name in self.start_time, f"Timer {name} not started."
    elapsed_time = time() - self.start_time[name]
    self.add(name, elapsed_time)
    del self.start_time[name]
    # Log only on rank 0 so the marker appears once, not N times across workers.
    # Guard with is_initialized(): get_rank() raises RuntimeError when the
    # default process group is not set up (e.g. single-process runs/tests).
    if (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0:
        print(f"Timer {name} end (elapsed: {elapsed_time:.1f}s)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove those two prints? I'm afraid this will produce too much output in the log.

Copy link
Collaborator Author

@fzyzcjy fzyzcjy Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally find it pretty useful since it "splits" the log into sections, and it only introduces ~10 lines of log while one step may have ~100 (or ~1000) lines of log from the sglang rollout — but if you'd like, I can remove it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, make sense.


def reset(self, name=None):
if name is None:
Expand All @@ -29,10 +35,7 @@ def reset(self, name=None):
del self.timers[name]

def add(self, name, elapsed_time):
if name not in self.timers:
self.timers[name] = elapsed_time
else:
self.timers[name] += elapsed_time
self.timers[name] = self.timers.get(name, 0) + elapsed_time

def log_dict(self):
    """Return the accumulated per-section totals (name -> seconds).

    Returns the internal dict itself, not a copy — callers must not mutate it.
    """
    return self.timers
Expand Down Expand Up @@ -72,3 +75,12 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


@contextmanager
def inverse_timer(name):
    """Pause the running timer *name* for the duration of the ``with`` block.

    The inverse of ``timer``: on entry it ends (accumulates) the currently
    running section *name*, and on exit — even if the block raises — it
    starts *name* again, so only time spent *outside* the block is
    attributed to *name*.
    """
    Timer().end(name)
    try:
        yield
    finally:
        Timer().start(name)
Loading