6 changes: 3 additions & 3 deletions slime/backends/fsdp_utils/actor.py
@@ -129,7 +129,7 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
         # Initialize data packing parameters
         self.max_tokens_per_gpu = args.max_tokens_per_gpu  # From main arguments
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.sleep(("model"))
 
         Timer().start("train_wait")
@@ -318,7 +318,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         """
         Timer().end("train_wait")
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         world_size = dist.get_world_size()
@@ -559,7 +559,7 @@ def update_weights(self) -> None:  # type: ignore[override]
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train and not torch.version.hip else nullcontext():
             self.weight_updater.update_weights()
 
     @torch.no_grad()
16 changes: 8 additions & 8 deletions slime/backends/megatron_utils/actor.py
@@ -73,7 +73,7 @@ def init(
         )
 
         if role == "critic":
-            if self.args.offload:
+            if self.args.offload_train:
                 self.sleep(("model"))
             Timer().start("train_wait")
             return
@@ -106,7 +106,7 @@ def init(
         # empty cache after initialization
         clear_memory()
 
-        if self.args.offload:
+        if self.args.offload_train:
             # recover to actor in the end.
             self.update_gpu_params_dict(self.weights["actor"])
             self.sleep(("model"))
@@ -156,7 +156,7 @@ def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
 
     @timer
     def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
         assert "model" in tags
         if isinstance(tags, str):
             tags = (tags,)
@@ -171,7 +171,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
 
     @timer
     def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
 
         # there are weird times when sglang is not offloaded immediately, so we wait here.
         mem_fraction_static = self.args.sglang_mem_fraction_static or 0.8
@@ -243,7 +243,7 @@ def compute_log_prob(
     def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         Timer().end("train_wait")
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         with timer("data_preprocess"):
@@ -408,7 +408,7 @@ def update_weights(self) -> None:
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return
 
-        if self.args.offload:
+        if self.args.offload_train:
             reload_process_groups()
 
         rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
@@ -418,7 +418,7 @@ def update_weights(self) -> None:
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train else nullcontext():
             print_memory("before update_weights")
             self.weight_updater.update_weights()
             print_memory("after update_weights")
@@ -435,7 +435,7 @@ def update_weights(self) -> None:
         else:
             self.update_cpu_params_dict(self.weights["old_actor"])
 
-        if self.args.offload:
+        if self.args.offload_train:
             destroy_process_groups()
 
     def load_other_checkpoint(self, model_tag: str, path: str) -> None:
2 changes: 1 addition & 1 deletion slime/backends/sglang_utils/sglang_engine.py
@@ -334,7 +334,7 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port):
         "trust_remote_code": True,
         "random_seed": args.seed + rank,
         # memory
-        "enable_memory_saver": args.offload,
+        "enable_memory_saver": args.offload_rollout,
         # distributed
         "host": host,
         "port": port,
2 changes: 1 addition & 1 deletion slime/ray/actor_group.py
@@ -61,7 +61,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
             **{name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST},
         }
 
-        if self.args.offload:
+        if self.args.offload_train:
             import torch_memory_saver
 
             dynlib_path = os.path.join(
2 changes: 1 addition & 1 deletion slime/ray/placement_group.py
@@ -172,7 +172,7 @@ def create_rollout_manager(args, pg, wandb_run_id):
     args.num_rollout = num_rollout_per_epoch * args.num_epoch
     assert args.num_rollout > 0
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.offload.remote())
 
     return rollout_manager, num_rollout_per_epoch
26 changes: 23 additions & 3 deletions slime/utils/arguments.py
@@ -79,8 +79,23 @@ def add_cluster_arguments(parser):
         "--offload",
         action="store_true",
         default=False,
+        help=("Equivalent to --offload-train + --offload-rollout. "),
+    )
+    parser.add_argument(
+        "--offload-train",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to offload the training actor to CPU during training. "
+            "This will always be true when --colocate is set."
+        ),
+    )
+    parser.add_argument(
+        "--offload-rollout",
+        action="store_true",
+        default=False,
         help=(
-            "Whether to offload the rollout generator and training actor to CPU during training. "
+            "Whether to offload the rollout generator to CPU during training. "
            "This will always be true when --colocate is set."
        ),
    )
@@ -1190,22 +1205,27 @@ def slime_validate_args(args):
     if args.critic_lr is None:
         args.critic_lr = args.lr
 
+    if args.offload:
+        args.offload_train = True
+        args.offload_rollout = True
+        del args.offload
+
     if args.debug_rollout_only:
         if args.colocate and args.rollout_num_gpus is None:
             args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes
         else:
             args.actor_num_gpus_per_node = min(8, args.rollout_num_gpus)
             args.actor_num_nodes = args.rollout_num_gpus // args.actor_num_gpus_per_node
         args.colocate = False
-        args.offload = False
+        args.offload_train = args.offload_rollout = False
 
     assert not (args.debug_rollout_only and args.debug_train_only), (
         "debug_rollout_only and debug_train_only cannot be set at the same time, " "please set only one of them."
     )
 
     # always true on offload for colocate at the moment.
     if args.colocate:
-        args.offload = True
+        args.offload_train = args.offload_rollout = True
         if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes:
             print(
                 f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} "
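Note on the new flags: --offload is kept as a shorthand that fans out into --offload-train (training actor) and --offload-rollout (rollout engines), and --colocate still forces both on. The snippet below is a minimal, standalone sketch of that resolution logic using only the flags touched by this diff; build_parser and resolve_offload are illustrative helpers written for this note, not functions defined in slime.

import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--offload", action="store_true", default=False)
    parser.add_argument("--offload-train", action="store_true", default=False)
    parser.add_argument("--offload-rollout", action="store_true", default=False)
    parser.add_argument("--colocate", action="store_true", default=False)
    return parser


def resolve_offload(args: argparse.Namespace) -> argparse.Namespace:
    # --offload is pure sugar: expand it into the two fine-grained flags and drop it,
    # mirroring the slime_validate_args change above.
    if args.offload:
        args.offload_train = True
        args.offload_rollout = True
        del args.offload
    # Colocated actor + rollout currently always requires both offloads.
    if args.colocate:
        args.offload_train = args.offload_rollout = True
    return args


args = resolve_offload(build_parser().parse_args(["--offload"]))
print(args.offload_train, args.offload_rollout)  # -> True True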
10 changes: 5 additions & 5 deletions train.py
@@ -23,13 +23,13 @@ def train(args):
     # create the actor and critic models
     actor_model, critic_model = create_training_models(args, pgs, rollout_manager, wandb_run_id=wandb_run_id)
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
 
     # always update weight first so that sglang has the loaded weights from training.
     actor_model.update_weights()
 
-    if args.offload:
+    if args.offload_rollout:
         if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
@@ -47,7 +47,7 @@ def train(args):
 
         rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_rollout:
             ray.get(rollout_manager.offload.remote())
 
         if args.use_critic:
@@ -69,7 +69,7 @@ def train(args):
         if args.rollout_global_dataset:
             ray.get(rollout_manager.save.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_train:
             if args.use_critic:
                 critic_model.offload()
             if rollout_id >= args.num_critic_only_steps:
@@ -81,7 +81,7 @@ def train(args):
 
         actor_model.update_weights()
 
-        if args.offload:
+        if args.offload_rollout:
             if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
                 ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
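Taken together, the train.py changes mean --offload-rollout now governs when the rollout engines occupy GPU memory (onloaded before generation and weight sync, offloaded right after generation), while --offload-train governs when the training actor and critic release theirs. The toy mock below of a single loop iteration makes that split explicit; RolloutManager, ActorModel and run_step are stand-ins written for this sketch, not slime classes, and the real loop additionally handles CUDA graph / KV cache tags, the critic schedule, evaluation and saving.

class RolloutManager:
    def generate(self, rollout_id):
        print(f"generate rollout {rollout_id} on GPU")
        return {"rollout_id": rollout_id}

    def offload(self):
        print("rollout engines -> CPU (frees GPU memory for training)")

    def onload(self):
        print("rollout engines -> GPU (ready for the next generation)")


class ActorModel:
    def train(self, rollout_id, rollout_data):
        print(f"train on rollout {rollout_id}")

    def offload(self):
        print("training actor -> CPU (frees GPU memory before weight sync)")

    def update_weights(self):
        print("push updated weights to the rollout engines")


def run_step(offload_train: bool, offload_rollout: bool, rollout_id: int = 0) -> None:
    rollout, actor = RolloutManager(), ActorModel()
    rollout_data = rollout.generate(rollout_id)
    if offload_rollout:
        rollout.offload()
    actor.train(rollout_id, rollout_data)
    if offload_train:
        actor.offload()
    actor.update_weights()
    if offload_rollout:
        rollout.onload()


# The two flags can now be toggled independently, e.g. page out only the training actor:
run_step(offload_train=True, offload_rollout=False)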