Skip to content

Commit d4c6dcc

Browse files
authored
Support zero host or device memory waste for weight update (THUDM#973)
1 parent 1b4c94e commit d4c6dcc

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def init(
9191

9292
self.weights_backuper = TensorBackuper.create(
9393
source_getter=lambda: named_params_and_buffers(
94-
self.args, self.model, convert_to_global_name=args.megatron_to_hf_mode == "raw"
94+
self.args,
95+
self.model,
96+
convert_to_global_name=args.megatron_to_hf_mode == "raw",
97+
translate_gpu_to_cpu=not self.args.enable_weights_backuper,
9598
),
9699
single_tag=None if args.enable_weights_backuper else "actor",
97100
)

slime/backends/megatron_utils/update_weight/common.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def named_params_and_buffers(
    args: Namespace,
    model: Sequence[torch.nn.Module],
    convert_to_global_name: bool = True,
    translate_gpu_to_cpu: bool = False,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Iterate over ``(name, tensor)`` pairs for every parameter and buffer of *model*.

    Args:
        args: Parsed launcher arguments forwarded to the global-name walker.
        model: Sequence of model chunks to iterate over.
        convert_to_global_name: When True, names are converted to their
            global form via the ``_global`` walker; otherwise the vanilla
            per-module names are yielded unchanged.
        translate_gpu_to_cpu: When True, each GPU tensor is swapped for its
            CPU backup (when one exists) so callers can read weights without
            touching device memory.

    Returns:
        An iterator of ``(name, tensor)`` pairs; lazily wrapped in a
        generator when CPU translation is requested.
    """
    if convert_to_global_name:
        pairs = _named_params_and_buffers_global(args, model)
    else:
        pairs = _named_params_and_buffers_vanilla(model)

    if not translate_gpu_to_cpu:
        return pairs

    # Lazily substitute each tensor with its pinned-CPU backup when available.
    return ((name, _maybe_get_cpu_backup(tensor)) for name, tensor in pairs)
def _maybe_get_cpu_backup(x: torch.Tensor):
    """Return the CPU backup of *x* registered with torch_memory_saver, or *x* itself.

    Imported lazily so that the dependency is only required when the
    zero-copy weight-update path is actually exercised.
    """
    from torch_memory_saver import torch_memory_saver

    cpu_tensor = torch_memory_saver.get_cpu_backup(x)
    # Fall back to the original (GPU) tensor when no backup was recorded.
    return x if cpu_tensor is None else cpu_tensor
125140

126141

127142
def _named_params_and_buffers_vanilla(model: Sequence[torch.nn.Module]) -> Iterator[tuple[str, torch.Tensor]]:

train.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,8 @@ def train(args):
2828
if args.offload_rollout:
2929
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
3030

31-
if args.offload_train and not args.enable_weights_backuper:
32-
actor_model.onload()
3331
# always update weight first so that sglang has the loaded weights from training.
3432
actor_model.update_weights()
35-
if args.offload_train and not args.enable_weights_backuper:
36-
actor_model.offload()
3733

3834
if args.check_weight_update_equal:
3935
ray.get(rollout_manager.check_weights.remote(action="compare"))
@@ -93,15 +89,9 @@ def onload_rollout():
9389
if args.rollout_global_dataset:
9490
ray.get(rollout_manager.save.remote(rollout_id))
9591

96-
if args.enable_weights_backuper:
97-
offload_train()
98-
onload_rollout()
99-
actor_model.update_weights()
100-
else:
101-
actor_model.clear_memory()
102-
onload_rollout()
103-
actor_model.update_weights()
104-
offload_train()
92+
offload_train()
93+
onload_rollout()
94+
actor_model.update_weights()
10595

10696
if args.offload_rollout:
10797
if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:

0 commit comments

Comments
 (0)