@@ -18,6 +18,29 @@ def __init__(self, *args, **kwargs):
         # important: make sure patch torchtitan logger first
         self.patch_torchtitan_logger()

+        # ensure the checkpoint patch is applied before importing torchtitan
+        # background: consolidate_safetensors_files_on_every_rank is a new DCP
+        # utility introduced in newer torch versions. our current build does not
+        # include it yet. this patch safely skips safetensors consolidation and
+        # issues a warning so Titan checkpoints can still work normally.
+        self.patch_torch_dcp_consolidate()
+
+        # ensure ScheduleDualPipeV is available
+        # background: ScheduleDualPipeV is a newer pipeline schedule recently
+        # introduced in torch.distributed; our current torch build does not
+        # include it yet. this patch injects a temporary alias that falls back to
+        # Schedule1F1B or ScheduleGPipe so Titan imports can succeed.
+        self.patch_torch_pipelining_schedules()
+
+        # ensure AuxOutput exists in flex_attention for model imports
+        # background: AuxOutput is a newly introduced optional return type in
+        # torch.nn.attention.flex_attention, used for debug or profiling data
+        # (e.g., attention probabilities or mask stats). our current torch build
+        # does not yet include it. this patch injects a lightweight stub class
+        # so model imports succeed. Titan does not rely on AuxOutput in its
+        # attention or training logic, so this patch does not affect behavior.
+        self.patch_torch_flex_attention_auxoutput()
+
         from torchtitan.config.job_config import JobConfig
         from torchtitan.train import Trainer

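Note: the ordering above matters. The torchtitan imports at the end of this hunk eagerly resolve torchtitan's own "from torch... import ..." statements, so any symbol missing from the installed torch build raises ImportError unless a shim is already registered. A minimal sketch of the sys.modules shim pattern the patches below rely on; the module and symbol names here are hypothetical placeholders, not real torch paths:

import sys
import types

# Register a stub module under the dotted name that a later import will request.
shim = types.ModuleType("example_missing_submodule")
shim.example_missing_symbol = lambda *args, **kwargs: None
sys.modules["example_missing_submodule"] = shim

# A later "from example_missing_submodule import example_missing_symbol"
# now resolves to the stub instead of raising ModuleNotFoundError.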
@@ -59,6 +82,139 @@ def patch_torchtitan_logger(self):
         titan_logging.logger = primus_logger
         titan_logging.init_logger = lambda: None

+    def patch_torch_dcp_consolidate(self):
+        """
+        Monkey patch for torch.distributed.checkpoint._consolidate_hf_safetensors
+        when the current torch build does not export consolidate_safetensors_files_on_every_rank.
+        This avoids an ImportError in TorchTitan when last_save_in_hf=True.
+        """
+        import sys
+        import types
+        import warnings
+
+        from primus.core.utils.logger import _logger as primus_logger
+
+        mod_name = "torch.distributed.checkpoint._consolidate_hf_safetensors"
+        func_name = "consolidate_safetensors_files_on_every_rank"
+
+        try:
+            mod = __import__(mod_name, fromlist=["*"])
+            if hasattr(mod, func_name):
+                primus_logger.info("[PrimusPatch][DCP] consolidate available, no patch needed.")
+                return  # OK, torch build supports it
+        except Exception:
+            pass
+
+        # Patch in a stub module/function for the missing symbol
+        dummy_mod = types.ModuleType(mod_name)
+
+        def _warn_and_skip(*args, **kwargs):
+            warnings.warn(
+                "[PrimusPatch][DCP] Current PyTorch build does not support "
+                f"{mod_name}.{func_name}; safetensors export will be skipped.",
+                UserWarning,
+            )
+            return None
+
+        setattr(dummy_mod, func_name, _warn_and_skip)
+        sys.modules[mod_name] = dummy_mod
+
+        primus_logger.warning(
+            f"[PrimusPatch][DCP] Installed fallback for missing {mod_name}.{func_name}; "
+            "HuggingFace safetensors export will be disabled."
+        )
+
+    def patch_torch_pipelining_schedules(self):
+        """
+        Ensure torch.distributed.pipelining.schedules.ScheduleDualPipeV exists.
+
+        If this class is missing in the current PyTorch build (common in ROCm 7.0 / torch 2.9),
+        create a fallback alias that inherits from Schedule1F1B or ScheduleGPipe.
+        This prevents an ImportError in TorchTitan pipeline modules.
+        """
+        from primus.core.utils.logger import _logger as primus_logger
+
+        try:
+            import torch.distributed.pipelining.schedules as sched
+        except Exception as e:
+            primus_logger.warning(f"[PrimusPatch][Pipe] failed to import schedules: {e}")
+            return
+
+        # Check if DualPipeV is already provided
+        if hasattr(sched, "ScheduleDualPipeV"):
+            primus_logger.info("[PrimusPatch][Pipe] ScheduleDualPipeV available, no patch needed.")
+            return  # No patch needed
+
+        # Pick a safe fallback
+        fallback = getattr(sched, "Schedule1F1B", None) or getattr(sched, "ScheduleGPipe", None)
+
+        if fallback is None:
+            primus_logger.warning(
+                "[PrimusPatch][Pipe] No pipeline schedule available; pipeline parallel may be unsupported."
+            )
+            return
+
+        # Define the fallback class with an identical signature
+        class ScheduleDualPipeV(fallback):  # type: ignore[misc]
+            def __init__(self, *args, **kwargs):
+                primus_logger.warning(
+                    f"[PrimusPatch][Pipe] ScheduleDualPipeV not found, using fallback {fallback.__name__}. "
+                    "This is a temporary compatibility patch; functionality may differ from the official DualPipeV."
+                )
+                super().__init__(*args, **kwargs)
+
+        # Inject into the torch namespace
+        setattr(sched, "ScheduleDualPipeV", ScheduleDualPipeV)
+        primus_logger.warning(
+            f"[PrimusPatch][Pipe] Installed fallback: ScheduleDualPipeV -> {fallback.__name__}"
+        )
+
+    def patch_torch_flex_attention_auxoutput(self):
+        """
+        Ensure torch.nn.attention.flex_attention has an AuxOutput symbol.
+        Some PyTorch builds (e.g., certain torch 2.9 ROCm dev builds) rename or drop it.
+        Provide a safe alias so Titan's imports do not fail.
+        """
+        from primus.core.utils.logger import _logger as primus_logger
+
+        try:
+            import torch.nn.attention.flex_attention as flex_mod
+        except Exception as e:
+            primus_logger.warning(f"[PrimusPatch][FlexAttn] flex_attention import failed: {e}")
+            return
+
+        # If AuxOutput already exists, nothing to do.
+        if hasattr(flex_mod, "AuxOutput"):
+            primus_logger.info("[PrimusPatch][FlexAttn] AuxOutput available, no patch needed.")
+            return
+
+        primus_logger.warning(
+            "[PrimusPatch][FlexAttn] AuxOutput not found. "
+            "This torch build predates the new debug/profiling return type. "
+            "Injecting a lightweight stub so Titan model imports can succeed."
+        )
+
+        import torch
+
+        class _AuxOutput:
+            # Lightweight stand-in for flex_attention.AuxOutput: exposes the
+            # expected fields with harmless defaults and accepts arbitrary
+            # keyword arguments so constructing one never fails.
+            attn_probs: torch.Tensor = torch.empty(0)
+            block_mask: torch.Tensor | None = None
+            stats: dict | None = None
+            extra: dict | None = None
+
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        setattr(flex_mod, "AuxOutput", _AuxOutput)
+        primus_logger.info(
+            "[PrimusPatch][FlexAttn] Injected fallback AuxOutput stub (Titan does not rely on this)."
+        )
+
     def enable_primus_turbo_extension(self):
         """
         Enable Primus-Turbo features and extensions.
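
A quick sanity check for the three patches: a sketch, assuming the patches above have already run in the current process (e.g., via the trainer __init__ shown in the first hunk) and that flex_attention imports on this build.

import importlib

import torch.distributed.pipelining.schedules as sched
import torch.nn.attention.flex_attention as flex_mod

# Each symbol should now resolve, either natively or via the installed fallback.
dcp_mod = importlib.import_module("torch.distributed.checkpoint._consolidate_hf_safetensors")
assert hasattr(dcp_mod, "consolidate_safetensors_files_on_every_rank")
assert hasattr(sched, "ScheduleDualPipeV")
assert hasattr(flex_mod, "AuxOutput")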