Commit 8c5f5c0

Split args.offload to train and rollout (#569)
1 parent fb87a0e commit 8c5f5c0

7 files changed (+42, -22 lines)

slime/backends/fsdp_utils/actor.py

Lines changed: 3 additions & 3 deletions

@@ -142,7 +142,7 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
         # Initialize data packing parameters
         self.max_tokens_per_gpu = args.max_tokens_per_gpu  # From main arguments
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.sleep(("model"))
 
         Timer().start("train_wait")
@@ -331,7 +331,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         """
         Timer().end("train_wait")
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         world_size = dist.get_world_size()
@@ -578,7 +578,7 @@ def update_weights(self) -> None:  # type: ignore[override]
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train and not torch.version.hip else nullcontext():
             self.weight_updater.update_weights()
 
     @torch.no_grad()
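The conditional context manager in the last hunk is worth a note: `with torch_memory_saver.disable() if ... else nullcontext():` picks the guard at runtime, disabling the memory saver only when train-side offload is on and the build is not ROCm (`torch.version.hip` is `None` on CUDA builds). A minimal sketch of the same pattern, with `weight_update_guard` as a hypothetical helper name:

    from contextlib import nullcontext

    import torch

    def weight_update_guard(args, memory_saver):
        # Mirror of the inline expression above: bypass the memory saver
        # only when train-side offload is active on a CUDA (non-HIP) build.
        if args.offload_train and not torch.version.hip:
            return memory_saver.disable()
        return nullcontext()

    # usage: with weight_update_guard(args, torch_memory_saver): ...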

slime/backends/megatron_utils/actor.py

Lines changed: 8 additions & 8 deletions

@@ -73,7 +73,7 @@ def init(
         )
 
         if role == "critic":
-            if self.args.offload:
+            if self.args.offload_train:
                 self.sleep(("model"))
             Timer().start("train_wait")
             return
@@ -106,7 +106,7 @@ def init(
         # empty cache after initialization
         clear_memory()
 
-        if self.args.offload:
+        if self.args.offload_train:
             # recover to actor in the end.
             self.update_gpu_params_dict(self.weights["actor"])
             self.sleep(("model"))
@@ -156,7 +156,7 @@ def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
 
     @timer
     def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
         assert "model" in tags
         if isinstance(tags, str):
             tags = (tags,)
@@ -171,7 +171,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
 
     @timer
     def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
-        assert self.args.offload
+        assert self.args.offload_train
 
         # there are weird times when sglang is not offloaded immediately, so we wait here.
         mem_fraction_static = self.args.sglang_mem_fraction_static or 0.8
@@ -243,7 +243,7 @@ def compute_log_prob(
     def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         Timer().end("train_wait")
 
-        if self.args.offload:
+        if self.args.offload_train:
             self.wake_up(("model"))
 
         with timer("data_preprocess"):
@@ -408,7 +408,7 @@ def update_weights(self) -> None:
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return
 
-        if self.args.offload:
+        if self.args.offload_train:
             reload_process_groups()
 
         rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
@@ -418,7 +418,7 @@ def update_weights(self) -> None:
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-        with torch_memory_saver.disable() if self.args.offload else nullcontext():
+        with torch_memory_saver.disable() if self.args.offload_train else nullcontext():
             print_memory("before update_weights")
             self.weight_updater.update_weights()
             print_memory("after update_weights")
@@ -435,7 +435,7 @@ def update_weights(self) -> None:
         else:
             self.update_cpu_params_dict(self.weights["old_actor"])
 
-        if self.args.offload:
+        if self.args.offload_train:
             destroy_process_groups()
 
     def load_other_checkpoint(self, model_tag: str, path: str) -> None:
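Across this file, every train-side lifecycle point now checks `offload_train`: sleeping after init, waking before a training step, and reloading/destroying process groups around a weight update. A simplified, runnable sketch of that gating (assumed class shape, not the real actor):

    class ActorSketch:
        """Schematic of the offload_train gating in the hunks above."""

        def __init__(self, args):
            self.args = args
            self.on_gpu = True

        def sleep(self, tags):
            # The real method asserts the flag: sleeping is only legal with offload on.
            assert self.args.offload_train
            self.on_gpu = False  # stand-in for releasing GPU memory

        def wake_up(self, tags):
            assert self.args.offload_train
            self.on_gpu = True  # stand-in for restoring GPU memory

        def train(self, rollout_id):
            if self.args.offload_train:
                self.wake_up(("model",))
            # ... forward/backward/optimizer step would run here ...
            if self.args.offload_train:
                self.sleep(("model",))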

slime/backends/sglang_utils/sglang_engine.py

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port):
         "trust_remote_code": True,
         "random_seed": args.seed + rank,
         # memory
-        "enable_memory_saver": args.offload,
+        "enable_memory_saver": args.offload_rollout,
         # distributed
         "host": host,
         "port": port,

slime/ray/actor_group.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
             **{name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST},
         }
 
-        if self.args.offload:
+        if self.args.offload_train:
             import torch_memory_saver
 
             dynlib_path = os.path.join(

slime/ray/placement_group.py

Lines changed: 1 addition & 1 deletion

@@ -172,7 +172,7 @@ def create_rollout_manager(args, pg, wandb_run_id):
         args.num_rollout = num_rollout_per_epoch * args.num_epoch
     assert args.num_rollout > 0
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.offload.remote())
 
     return rollout_manager, num_rollout_per_epoch

slime/utils/arguments.py

Lines changed: 23 additions & 3 deletions

@@ -79,8 +79,23 @@ def add_cluster_arguments(parser):
         "--offload",
         action="store_true",
         default=False,
+        help=("Equivalent to --offload-train + --offload-rollout. "),
+    )
+    parser.add_argument(
+        "--offload-train",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to offload the training actor to CPU during training. "
+            "This will always be true when --colocate is set."
+        ),
+    )
+    parser.add_argument(
+        "--offload-rollout",
+        action="store_true",
+        default=False,
         help=(
-            "Whether to offload the rollout generator and training actor to CPU during training. "
+            "Whether to offload the rollout generator to CPU during training. "
             "This will always be true when --colocate is set."
         ),
     )
@@ -1202,22 +1217,27 @@ def slime_validate_args(args):
     if args.critic_lr is None:
         args.critic_lr = args.lr
 
+    if args.offload:
+        args.offload_train = True
+        args.offload_rollout = True
+        del args.offload
+
     if args.debug_rollout_only:
         if args.colocate and args.rollout_num_gpus is None:
             args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes
         else:
             args.actor_num_gpus_per_node = min(8, args.rollout_num_gpus)
             args.actor_num_nodes = args.rollout_num_gpus // args.actor_num_gpus_per_node
         args.colocate = False
-        args.offload = False
+        args.offload_train = args.offload_rollout = False
 
     assert not (args.debug_rollout_only and args.debug_train_only), (
         "debug_rollout_only and debug_train_only cannot be set at the same time, " "please set only one of them."
     )
 
     # always true on offload for colocate at the moment.
     if args.colocate:
-        args.offload = True
+        args.offload_train = args.offload_rollout = True
         if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes:
             print(
                 f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} "

train.py

Lines changed: 5 additions & 5 deletions

@@ -23,13 +23,13 @@ def train(args):
     # create the actor and critic models
     actor_model, critic_model = create_training_models(args, pgs, rollout_manager, wandb_run_id=wandb_run_id)
 
-    if args.offload:
+    if args.offload_rollout:
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
 
     # always update weight first so that sglang has the loaded weights from training.
     actor_model.update_weights()
 
-    if args.offload:
+    if args.offload_rollout:
         if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
         ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
@@ -47,7 +47,7 @@ def train(args):
 
         rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_rollout:
             ray.get(rollout_manager.offload.remote())
 
         if args.use_critic:
@@ -69,7 +69,7 @@ def train(args):
         if args.rollout_global_dataset:
             ray.get(rollout_manager.save.remote(rollout_id))
 
-        if args.offload:
+        if args.offload_train:
            if args.use_critic:
                critic_model.offload()
            if rollout_id >= args.num_critic_only_steps:
@@ -81,7 +81,7 @@ def train(args):
 
         actor_model.update_weights()
 
-        if args.offload:
+        if args.offload_rollout:
             if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
                 ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
             ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
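The driver loop now draws a clean line between the two sides: rollout-engine onload/offload follows `offload_rollout`, while actor/critic offloading follows `offload_train`, so each can be toggled independently. A schematic of the loop using the names from the hunks above (simplified; `start_rollout_id` is an assumed variable):

    for rollout_id in range(start_rollout_id, args.num_rollout):
        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))

        if args.offload_rollout:
            ray.get(rollout_manager.offload.remote())  # free GPUs for the training step

        actor_model.train(rollout_id, rollout_data_ref)

        if args.offload_train:
            actor_model.offload()  # park trainer weights before rollout resumes

        actor_model.update_weights()

        if args.offload_rollout:
            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))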
