diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
index 6040d68d3..d7a37aa4c 100644
--- a/slime/backends/fsdp_utils/actor.py
+++ b/slime/backends/fsdp_utils/actor.py
@@ -129,7 +129,7 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
         # Initialize data packing parameters
         self.max_tokens_per_gpu = args.max_tokens_per_gpu  # From main arguments
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.sleep(("model"))
 
         Timer().start("train_wait")
@@ -318,7 +318,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         """
         Timer().end("train_wait")
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         world_size = dist.get_world_size()
@@ -559,7 +559,7 @@ def update_weights(self) -> None:  # type: ignore[override]
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train and not torch.version.hip else nullcontext():
             self.weight_updater.update_weights()
 
     @torch.no_grad()
diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py
index 675d5bc61..c1c05af4a 100644
--- a/slime/backends/megatron_utils/actor.py
+++ b/slime/backends/megatron_utils/actor.py
@@ -73,7 +73,7 @@ def init(
         )
 
         if role == "critic":
-            if self.args.offload:
+            if self.args.offload_train:
                 self.sleep(("model"))
             Timer().start("train_wait")
             return
@@ -106,7 +106,7 @@ def init(
         # empty cache after initialization
         clear_memory()
 
-        if self.args.offload:
+        if self.args.offload_train:
             # recover to actor in the end.
             self.update_gpu_params_dict(self.weights["actor"])
             self.sleep(("model"))
@@ -156,7 +156,7 @@ def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
 
     @timer
     def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
         assert "model" in tags
         if isinstance(tags, str):
             tags = (tags,)
@@ -171,7 +171,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
 
     @timer
     def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
         # there are weird times when sglang is not offloaded immediately, so we wait here.
         mem_fraction_static = self.args.sglang_mem_fraction_static or 0.8
 
@@ -243,7 +243,7 @@ def compute_log_prob(
 
     def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         Timer().end("train_wait")
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         with timer("data_preprocess"):
@@ -408,7 +408,7 @@ def update_weights(self) -> None:
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return
 
-        if self.args.offload:
+        if self.args.offload_train:
             reload_process_groups()
 
         rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
@@ -418,7 +418,7 @@ def update_weights(self) -> None:
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train else nullcontext():
             print_memory("before update_weights")
             self.weight_updater.update_weights()
             print_memory("after update_weights")
@@ -435,7 +435,7 @@ def update_weights(self) -> None:
         else:
             self.update_cpu_params_dict(self.weights["old_actor"])
 
-        if self.args.offload:
+        if self.args.offload_train:
             destroy_process_groups()
 
     def load_other_checkpoint(self, model_tag: str, path: str) -> None:
diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py
index ed728f70f..b79640250 100644
--- a/slime/backends/sglang_utils/sglang_engine.py
+++ b/slime/backends/sglang_utils/sglang_engine.py
@@ -334,7 +334,7 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port):
         "trust_remote_code": True,
         "random_seed": args.seed + rank,
         # memory
-        "enable_memory_saver": args.offload,
+        "enable_memory_saver": args.offload_rollout,
         # distributed
         "host": host,
         "port": port,
diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py
index d056f0e24..76f67d399 100644
--- a/slime/ray/actor_group.py
+++ b/slime/ray/actor_group.py
@@ -61,7 +61,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
             **{name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST},
         }
 
-        if self.args.offload:
+        if self.args.offload_train:
             import torch_memory_saver
 
             dynlib_path = os.path.join(
diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py
index 0317644eb..beae77930 100644
--- a/slime/ray/placement_group.py
+++ b/slime/ray/placement_group.py
@@ -172,7 +172,7 @@ def create_rollout_manager(args, pg, wandb_run_id):
         args.num_rollout = num_rollout_per_epoch * args.num_epoch
     assert args.num_rollout > 0
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.offload.remote())
 
     return rollout_manager, num_rollout_per_epoch
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index b65fcdc15..d76ef75ac 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -79,8 +79,23 @@ def add_cluster_arguments(parser):
         "--offload",
         action="store_true",
         default=False,
+        help=("Equivalent to --offload-train + --offload-rollout."),
+    )
+    parser.add_argument(
+        "--offload-train",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to offload the training actor to CPU during training. "
+            "This will always be true when --colocate is set."
+        ),
+    )
+    parser.add_argument(
+        "--offload-rollout",
+        action="store_true",
+        default=False,
         help=(
-            "Whether to offload the rollout generator and training actor to CPU during training. "
+            "Whether to offload the rollout generator to CPU during training. "
             "This will always be true when --colocate is set."
         ),
     )
@@ -1190,6 +1205,11 @@ def slime_validate_args(args):
     if args.critic_lr is None:
         args.critic_lr = args.lr
 
+    if args.offload:
+        args.offload_train = True
+        args.offload_rollout = True
+    del args.offload
+
     if args.debug_rollout_only:
         if args.colocate and args.rollout_num_gpus is None:
             args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes
@@ -1197,7 +1217,7 @@ def slime_validate_args(args):
         args.actor_num_gpus_per_node = min(8, args.rollout_num_gpus)
         args.actor_num_nodes = args.rollout_num_gpus // args.actor_num_gpus_per_node
         args.colocate = False
-        args.offload = False
+        args.offload_train = args.offload_rollout = False
     assert not (args.debug_rollout_only and args.debug_train_only), (
         "debug_rollout_only and debug_train_only cannot be set at the same time, "
         "please set only one of them."
@@ -1205,7 +1225,7 @@ def slime_validate_args(args):
 
     # always true on offload for colocate at the moment.
     if args.colocate:
-        args.offload = True
+        args.offload_train = args.offload_rollout = True
         if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes:
             print(
                 f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} "
diff --git a/train.py b/train.py
index bbbb9fcde..3439234ce 100644
--- a/train.py
+++ b/train.py
@@ -23,13 +23,13 @@ def train(args):
     # create the actor and critic models
     actor_model, critic_model = create_training_models(args, pgs, rollout_manager, wandb_run_id=wandb_run_id)
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
 
     # always update weight first so that sglang has the loaded weights from training.
     actor_model.update_weights()
 
-    if args.offload:
+    if args.offload_rollout:
         if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
@@ -47,7 +47,7 @@ def train(args):
 
         rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_rollout:
             ray.get(rollout_manager.offload.remote())
 
         if args.use_critic:
@@ -69,7 +69,7 @@ def train(args):
         if args.rollout_global_dataset:
             ray.get(rollout_manager.save.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_train:
             if args.use_critic:
                 critic_model.offload()
             if rollout_id >= args.num_critic_only_steps:
@@ -81,7 +81,7 @@ def train(args):
 
         actor_model.update_weights()
 
-        if args.offload:
+        if args.offload_rollout:
             if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
                 ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))