Skip to content

Commit e992ce9

Browse files
authored
Fix NCCL out-of-memory error even when there is memory (THUDM#669)
1 parent 669dbb2 commit e992ce9

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

slime/ray/train_actor.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import ray
77
import torch
88
import torch.distributed as dist
9+
from torch_memory_saver import torch_memory_saver
910

1011
import slime.utils.eval_config
1112
from slime.ray.ray_actor import RayActor
@@ -46,6 +47,11 @@ def init(self, args, role, wandb_run_id, with_ref=False):
4647
self.role = role
4748
self.with_ref = with_ref
4849

50+
if (x := args.train_memory_margin_bytes) > 0:
51+
print(f"Set torch_memory_saver.memory_margin_bytes to {x}")
52+
assert args.offload_train
53+
torch_memory_saver.memory_margin_bytes = x
54+
4955
torch.serialization.add_safe_globals([slime.utils.eval_config.EvalDatasetConfig])
5056

5157
local_rank = int(os.environ.get("LOCAL_RANK", 0))

slime/utils/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,12 @@ def add_train_arguments(parser):
130130
default="{}",
131131
help="Extra environment variables for training process, e.g. PyTorch memory management ones.",
132132
)
133+
parser.add_argument(
134+
"--train-memory-margin-bytes",
135+
type=int,
136+
default=0,
137+
help="Add margin for train memory allocation.",
138+
)
133139

134140
return parser
135141

0 commit comments

Comments
 (0)