File tree Expand file tree Collapse file tree 1 file changed +39
-0
lines changed Expand file tree Collapse file tree 1 file changed +39
-0
lines changed Original file line number Diff line number Diff line change 1616# Lazy backend loader
def _patch_megatron_get_model():
    """Wrap ``megatron.training.training.get_model`` so CUDA stream contexts are no-ops inside it.

    While the wrapped ``get_model`` runs, ``torch.cuda.stream`` is temporarily
    replaced by a factory that returns an empty context manager, so any
    ``with torch.cuda.stream(...)`` block entered through that attribute
    executes without switching streams. The real ``torch.cuda.stream`` is
    restored afterwards, even if ``get_model`` raises.

    The patch is installed at most once: calling this function again after the
    wrapper is in place is a no-op, so wrappers never stack.

    NOTE(review): the log messages describe this as disabling a "second DDP
    construction" in ``build_model``; the code here only neutralises the
    stream context around ``get_model`` -- confirm the intended effect against
    the Megatron version in use.
    NOTE(review): the swap mutates a process-wide attribute, so other threads
    calling ``torch.cuda.stream`` during ``get_model`` would also see the
    no-op -- presumably model construction is single-threaded; verify.
    """
    import contextlib
    import functools

    import megatron.training.training as training
    import torch

    original_get_model = training.get_model

    # Already wrapped by an earlier call: do not wrap the wrapper again.
    if getattr(original_get_model, "_primus_stream_patch", False):
        return

    def _noop_stream(*args, **kwargs):
        # Accept whatever torch.cuda.stream would be given and do nothing;
        # ``with`` on the result binds None and never suppresses exceptions.
        return contextlib.nullcontext()

    @functools.wraps(original_get_model)
    def _patched_get_model(*args, **kwargs):
        print("[PrimusPatch] Overriding build_model to disable second DDP construction...")

        # Save whatever is installed right now (not at patch time) so the
        # exact previous value is put back when the call finishes.
        saved_stream = torch.cuda.stream
        torch.cuda.stream = _noop_stream
        try:
            return original_get_model(*args, **kwargs)
        finally:
            torch.cuda.stream = saved_stream

    # Marker used by the idempotence guard above.
    _patched_get_model._primus_stream_patch = True

    training.get_model = _patched_get_model
    print("[PrimusPatch] Applied Megatron build_model monkey-patch to disable second DDP.")


def load_backend_trainer(framework: str):
    """Return the trainer class for ``framework``, importing the backend lazily.

    For ``"megatron"`` this first installs the ``get_model`` monkey-patch (see
    ``_patch_megatron_get_model``) and then imports and returns
    ``MegatronPretrainTrainer``. Backend imports happen inside the function so
    that Megatron and torch are only loaded when actually requested.

    Args:
        framework: Backend identifier; ``"megatron"`` is the only value handled
            in this portion of the function.

    Returns:
        The ``MegatronPretrainTrainer`` class when ``framework == "megatron"``.
    """
    if framework == "megatron":
        # Patch before importing the trainer, matching the original ordering.
        _patch_megatron_get_model()

        from primus.modules.trainer.megatron.pre_trainer import MegatronPretrainTrainer

        return MegatronPretrainTrainer
You can’t perform that action at this time.
0 commit comments