|
15 | 15 | from megatron.core.transformer.spec_utils import import_module |
16 | 16 | from megatron.core.transformer.transformer_config import TransformerConfig |
17 | 17 | from megatron.training.arguments import core_transformer_config_from_args |
| 18 | +from slime.utils import profile_utils |
18 | 19 |
|
19 | 20 |
|
20 | 21 | # Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82 |
@@ -70,29 +71,13 @@ def model_provider( |
70 | 71 | """ |
71 | 72 | use_te = args.transformer_impl == "transformer_engine" |
72 | 73 |
|
| 74 | + # TODO maybe move this to other parts |
73 | 75 | if args.record_memory_history: |
74 | | - torch.cuda.memory._record_memory_history( |
75 | | - # True, |
76 | | - # keep 100,000 alloc/free events from before the snapshot |
77 | | - max_entries=100000, |
78 | | - # record stack information for the trace events |
79 | | - # trace_alloc_record_context=True, |
80 | | - stacks="all", |
| 76 | + profile_utils.attach_oom_dump_memory_history( |
| 77 | + memory_snapshot_dir=args.memory_snapshot_dir, |
| 78 | + memory_snapshot_path=args.memory_snapshot_path, |
81 | 79 | ) |
82 | 80 |
|
83 | | - def oom_observer(device, alloc, device_alloc, device_free): |
84 | | - # snapshot right after an OOM happened |
85 | | - print("saving allocated state during OOM") |
86 | | - snapshot = torch.cuda.memory._snapshot() |
87 | | - from pickle import dump |
88 | | - |
89 | | - dump( |
90 | | - snapshot, |
91 | | - open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", "wb"), |
92 | | - ) |
93 | | - |
94 | | - torch._C._cuda_attach_out_of_memory_observer(oom_observer) |
95 | | - |
96 | 81 | # Experimental loading arguments from yaml |
97 | 82 | config: TransformerConfig = core_transformer_config_from_args(args) |
98 | 83 |
|
|
0 commit comments