Skip to content

Commit a133f93

Browse files
authored
Tiny extract and enhance oom dumper (#568)
1 parent 8e2a8a3 commit a133f93

File tree

5 files changed

+45
-20
lines changed

5 files changed

+45
-20
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import wandb
2323

2424
from slime.ray.train_actor import TrainRayActor
25+
from slime.utils import profile_utils
2526
from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data
2627
from slime.utils.distributed_utils import get_gloo_group
2728
from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss
@@ -60,6 +61,12 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
6061
self.args = args
6162
torch.manual_seed(args.seed)
6263

64+
if args.record_memory_history:
65+
profile_utils.attach_oom_dump_memory_history(
66+
memory_snapshot_dir=args.memory_snapshot_dir,
67+
memory_snapshot_path=args.memory_snapshot_path,
68+
)
69+
6370
for i in range(dist.get_world_size()):
6471
if i == dist.get_rank():
6572
self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)

slime/backends/fsdp_utils/arguments.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class FSDPArgs:
3232
# FSDP configuration
3333
fsdp_full_params: bool = False # If True, use full_tensor; if False, use shard_tensor
3434

35+
# Profile
36+
record_memory_history: bool = False
37+
memory_snapshot_path: str = "snapshot.pickle"
38+
3539
# YAML bookkeeping
3640
config: str | None = None
3741

slime/backends/megatron_utils/model_provider.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from megatron.core.transformer.spec_utils import import_module
1616
from megatron.core.transformer.transformer_config import TransformerConfig
1717
from megatron.training.arguments import core_transformer_config_from_args
18+
from slime.utils import profile_utils
1819

1920

2021
# Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82
@@ -70,29 +71,13 @@ def model_provider(
7071
"""
7172
use_te = args.transformer_impl == "transformer_engine"
7273

74+
# TODO maybe move this to other parts
7375
if args.record_memory_history:
74-
torch.cuda.memory._record_memory_history(
75-
# True,
76-
# keep 100,000 alloc/free events from before the snapshot
77-
max_entries=100000,
78-
# record stack information for the trace events
79-
# trace_alloc_record_context=True,
80-
stacks="all",
76+
profile_utils.attach_oom_dump_memory_history(
77+
memory_snapshot_dir=args.memory_snapshot_dir,
78+
memory_snapshot_path=args.memory_snapshot_path,
8179
)
8280

83-
def oom_observer(device, alloc, device_alloc, device_free):
84-
# snapshot right after an OOM happened
85-
print("saving allocated state during OOM")
86-
snapshot = torch.cuda.memory._snapshot()
87-
from pickle import dump
88-
89-
dump(
90-
snapshot,
91-
open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", "wb"),
92-
)
93-
94-
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
95-
9681
# Experimental loading arguments from yaml
9782
config: TransformerConfig = core_transformer_config_from_args(args)
9883

slime/utils/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,12 @@ def add_debug_arguments(parser):
863863
default=None,
864864
help=("Dump all details of training for post-hoc analysis and visualization."),
865865
)
866+
# use together with --record-memory-history and --memory-snapshot-path (defined in Megatron)
867+
parser.add_argument(
868+
"--memory-snapshot-dir",
869+
type=str,
870+
default=".",
871+
)
866872
return parser
867873

868874
def add_network_arguments(parser):

slime/utils/profile_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
from pathlib import Path
from pickle import dump

import torch
4+
5+
6+
# The memory_snapshot_path is not a full path, but we name like this to be compatible with megatron
7+
def attach_oom_dump_memory_history(memory_snapshot_dir, memory_snapshot_path):
8+
torch.cuda.memory._record_memory_history(
9+
max_entries=100000,
10+
# record stack information for the trace events
11+
# trace_alloc_record_context=True,
12+
stacks="all",
13+
)
14+
15+
def oom_observer(device, alloc, device_alloc, device_free):
16+
path_dump = memory_snapshot_dir / f"oom_rank-{torch.distributed.get_rank()}_{memory_snapshot_path}"
17+
print(f"Observe OOM, will dump snapshot to {path_dump}. ({device=} {alloc=} {device_alloc=} {device_free=})")
18+
19+
# TODO use `_dump_snapshot` instead?
20+
snapshot = torch.cuda.memory._snapshot()
21+
dump(snapshot, open(path_dump, "wb"))
22+
23+
torch._C._cuda_attach_out_of_memory_observer(oom_observer)

0 commit comments

Comments
 (0)