Skip to content

Commit 1baa042

Browse files
Xiaoming-AMDXiaoming-AMD
andauthored
fix: Disable double DDP construction inside build_model() via runtime patch (#264)
Co-authored-by: Xiaoming-AMD <[email protected]>
1 parent 03a36b1 commit 1baa042

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

primus/pretrain.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,45 @@
1616
# Lazy backend loader
1717
def load_backend_trainer(framework: str):
1818
if framework == "megatron":
19+
import megatron.training.training as training
20+
import torch
21+
22+
_original_build_model = training.get_model
23+
24+
def _patched_get_model(*args, **kwargs):
25+
"""
26+
Monkey-patched version of build_model that removes the second
27+
DDP construction inside torch.cuda.stream() block.
28+
"""
29+
import inspect
30+
31+
from megatron.training import training as tr
32+
33+
inspect.getsource(tr.get_model)
34+
print("[PrimusPatch] Overriding build_model to disable second DDP construction...")
35+
36+
_orig_stream_ctx = torch.cuda.stream
37+
38+
def _noop_stream(*args, **kwargs):
39+
class DummyCtx:
40+
def __enter__(self):
41+
return None
42+
43+
def __exit__(self, *a):
44+
return False
45+
46+
return DummyCtx()
47+
48+
torch.cuda.stream = _noop_stream
49+
50+
try:
51+
return _original_build_model(*args, **kwargs)
52+
finally:
53+
torch.cuda.stream = _orig_stream_ctx
54+
55+
training.get_model = _patched_get_model
56+
print("[PrimusPatch] Applied Megatron build_model monkey-patch to disable second DDP.")
57+
1958
from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer
2059

2160
return MegatronPretrainTrainer

0 commit comments

Comments
 (0)