
Commit 4c0c214

fix(compat): add torch API compatibility patches for Titan imports
- add patch_torch_dcp_consolidate(): skip safetensors consolidation if the helper is missing
- add patch_torch_pipelining_schedules(): inject a fallback for the missing ScheduleDualPipeV
- add patch_torch_flex_attention_auxoutput(): inject an AuxOutput stub for flex_attention
- add detailed background comments and structured log messages
- ensure all patches are applied before importing TorchTitan
1 parent f8e35c9 commit 4c0c214

3 files changed: 163 additions, 4 deletions


examples/torchtitan/prepare.py

Lines changed: 2 additions & 2 deletions
@@ -184,8 +184,8 @@ def install_torch_for_rocm(nightly=True):
 
 
 if __name__ == "__main__":
-    log_info("========== Prepare torch for Torchtitan ==========")
-    install_torch_for_rocm(nightly=True)
+    # log_info("========== Prepare torch for Torchtitan ==========")
+    # install_torch_for_rocm(nightly=True)
 
     log_info("========== Prepare Torchtitan dataset ==========")
     main()

primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py

Lines changed: 5 additions & 2 deletions
@@ -7,7 +7,10 @@
 import torch
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.models.attention import FlexAttention, ScaledDotProductAttention
+from torchtitan.models.attention import (
+    FlexAttentionWrapper,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -18,7 +21,7 @@ def replace_turbo_attention_modules(model: torch.nn.Module, backend_type: str, u
     from primus_turbo.pytorch.modules import TurboAttention  # TODO: import Check
 
     for name, module in model.named_children():
-        if isinstance(module, (FlexAttention, ScaledDotProductAttention)):
+        if isinstance(module, (FlexAttentionWrapper, ScaledDotProductAttentionWrapper)):
             setattr(
                 model,
                 name,

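Note on the rename above: upstream TorchTitan renamed FlexAttention and ScaledDotProductAttention to FlexAttentionWrapper and ScaledDotProductAttentionWrapper. A minimal sketch of a version-tolerant import that works on both sides of the rename (an illustrative pattern, not part of this commit):

    # Try the new wrapper names first; fall back to the pre-rename names
    # on older TorchTitan versions.
    try:
        from torchtitan.models.attention import (
            FlexAttentionWrapper,
            ScaledDotProductAttentionWrapper,
        )
    except ImportError:  # older TorchTitan builds
        from torchtitan.models.attention import (
            FlexAttention as FlexAttentionWrapper,
            ScaledDotProductAttention as ScaledDotProductAttentionWrapper,
        )

    ATTENTION_CLASSES = (FlexAttentionWrapper, ScaledDotProductAttentionWrapper)
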
primus/modules/trainer/torchtitan/pre_trainer.py

Lines changed: 156 additions & 0 deletions
@@ -18,6 +18,29 @@ def __init__(self, *args, **kwargs):
         # important: make sure we patch the torchtitan logger first
         self.patch_torchtitan_logger()
 
+        # ensure the checkpoint patch is applied before importing torchtitan.
+        # background: consolidate_safetensors_files_on_every_rank is a new DCP
+        # utility introduced in newer torch versions. our current build does
+        # not include it yet. this patch safely skips safetensors
+        # consolidation and issues a warning so Titan checkpoints still work
+        # normally.
+        self.patch_torch_dcp_consolidate()
+
+        # ensure ScheduleDualPipeV is available.
+        # background: ScheduleDualPipeV is a newer pipeline schedule recently
+        # introduced in torch.distributed; our current torch build does not
+        # include it yet. this patch injects a temporary alias that falls
+        # back to Schedule1F1B or ScheduleGPipe so Titan imports can succeed.
+        self.patch_torch_pipelining_schedules()
+
+        # ensure AuxOutput exists in flex_attention for model imports.
+        # background: AuxOutput is a newly introduced optional return type in
+        # torch.nn.attention.flex_attention, used for debug or profiling data
+        # (e.g., attention probabilities or mask stats). our current torch
+        # build does not yet include it. this patch injects a lightweight
+        # stub class so model imports succeed. Titan does not rely on
+        # AuxOutput in its attention or training logic, so the patch does not
+        # affect behavior.
+        self.patch_torch_flex_attention_auxoutput()
+
         from torchtitan.config.job_config import JobConfig
         from torchtitan.train import Trainer
 

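All three patches run before the torchtitan imports above because Python caches modules in sys.modules on first import; once TorchTitan has been imported, a later patch cannot change what it already bound. A minimal standalone sketch of the mechanism (hypothetical module name, not part of this commit):

    import sys
    import types

    # Register a stub module up front; a later "import missing_dep" is
    # then served from sys.modules and raises no ImportError.
    stub = types.ModuleType("missing_dep")
    stub.some_symbol = lambda *args, **kwargs: None
    sys.modules["missing_dep"] = stub

    import missing_dep  # resolved from the module cache

    assert missing_dep.some_symbol() is None
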
@@ -59,6 +82,139 @@ def patch_torchtitan_logger(self):
         titan_logging.logger = primus_logger
         titan_logging.init_logger = lambda: None
 
+    def patch_torch_dcp_consolidate(self):
+        """
+        Monkey patch for torch.distributed.checkpoint._consolidate_hf_safetensors
+        when the current torch build does not export
+        consolidate_safetensors_files_on_every_rank. This avoids an
+        ImportError in TorchTitan when last_save_in_hf=True.
+        """
+        import sys
+        import types
+        import warnings
+
+        # bind the logger before its first use inside the try block below
+        from primus.core.utils.logger import _logger as primus_logger
+
+        mod_name = "torch.distributed.checkpoint._consolidate_hf_safetensors"
+        func_name = "consolidate_safetensors_files_on_every_rank"
+
+        try:
+            mod = __import__(mod_name, fromlist=["*"])
+            if hasattr(mod, func_name):
+                primus_logger.info("[PrimusPatch][DCP] consolidate available, no patch needed.")
+                return  # OK, this torch build supports it
+        except Exception:
+            pass
+
+        # Patch the missing module/function with a warn-and-skip stub.
+        dummy_mod = types.ModuleType(mod_name)
+
+        def _warn_and_skip(*args, **kwargs):
+            warnings.warn(
+                "[PrimusPatch][DCP] Current PyTorch build does not support "
+                f"{mod_name}.{func_name}; safetensors export will be skipped.",
+                UserWarning,
+            )
+            return None
+
+        setattr(dummy_mod, func_name, _warn_and_skip)
+        sys.modules[mod_name] = dummy_mod
+
+        primus_logger.warning(
+            f"[PrimusPatch][DCP] Installed fallback for missing {mod_name}.{func_name}; "
+            "HuggingFace safetensors export will be disabled."
+        )
+
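
Illustration only (not from the commit): once the stub module is registered, TorchTitan's import of the consolidate helper succeeds and calling it degrades to a warning. Assumes patch_torch_dcp_consolidate() has already run on a torch build that lacks the real helper:

    from torch.distributed.checkpoint._consolidate_hf_safetensors import (
        consolidate_safetensors_files_on_every_rank,
    )

    result = consolidate_safetensors_files_on_every_rank()  # emits a UserWarning
    assert result is None
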
+    def patch_torch_pipelining_schedules(self):
+        """
+        Ensure torch.distributed.pipelining.schedules.ScheduleDualPipeV exists.
+
+        If this class is missing in the current PyTorch build (common in
+        ROCm 7.0 / torch 2.9 builds), create a fallback alias that inherits
+        from Schedule1F1B or ScheduleGPipe. This prevents an ImportError in
+        TorchTitan pipeline modules.
+        """
+        from primus.core.utils.logger import _logger as primus_logger
+
+        try:
+            import torch.distributed.pipelining.schedules as sched
+        except Exception as e:
+            primus_logger.warning(f"[PrimusPatch][Pipe] failed to import schedules: {e}")
+            return
+
+        # Check whether DualPipeV is already provided.
+        if hasattr(sched, "ScheduleDualPipeV"):
+            primus_logger.info("[PrimusPatch][Pipe] ScheduleDualPipeV available, no patch needed.")
+            return
+
+        # Pick a safe fallback schedule.
+        fallback = getattr(sched, "Schedule1F1B", None) or getattr(sched, "ScheduleGPipe", None)
+        if fallback is None:
+            primus_logger.warning(
+                "[PrimusPatch][Pipe] No pipeline schedule available; pipeline parallelism may be unsupported."
+            )
+            return
+
+        # Define the fallback class with an identical constructor signature.
+        class ScheduleDualPipeV(fallback):  # type: ignore[misc]
+            def __init__(self, *args, **kwargs):
+                primus_logger.warning(
+                    f"[PrimusPatch][Pipe] ScheduleDualPipeV not found, using fallback {fallback.__name__}. "
+                    "This is a temporary compatibility patch; behavior may differ from the official DualPipeV."
+                )
+                super().__init__(*args, **kwargs)
+
+        # Inject the alias into the torch namespace.
+        setattr(sched, "ScheduleDualPipeV", ScheduleDualPipeV)
+        primus_logger.warning(
+            f"[PrimusPatch][Pipe] Installed fallback: ScheduleDualPipeV -> {fallback.__name__}"
+        )
+
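
Illustration only (not from the commit): after the patch, code that looks up the alias gets the fallback schedule under the familiar name. Assumes patch_torch_pipelining_schedules() has run on a build without the real DualPipeV:

    import torch.distributed.pipelining.schedules as sched

    DualPipeV = sched.ScheduleDualPipeV
    print(DualPipeV.__bases__[0].__name__)  # e.g. "Schedule1F1B"
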
+    def patch_torch_flex_attention_auxoutput(self):
+        """
+        Ensure torch.nn.attention.flex_attention has an AuxOutput symbol.
+        Some PyTorch builds (e.g., certain torch 2.9 dev builds for ROCm)
+        rename or drop it. Provide a safe alias so Titan's imports won't fail.
+        """
+        from primus.core.utils.logger import _logger as primus_logger
+
+        try:
+            import torch.nn.attention.flex_attention as flex_mod
+        except Exception as e:
+            primus_logger.warning(f"[PrimusPatch][FlexAttn] flex_attention import failed: {e}")
+            return
+
+        # If AuxOutput already exists, there is nothing to do.
+        if hasattr(flex_mod, "AuxOutput"):
+            primus_logger.info("[PrimusPatch][FlexAttn] AuxOutput available, no patch needed.")
+            return
+
+        primus_logger.warning(
+            "[PrimusPatch][FlexAttn] AuxOutput not found. "
+            "This torch build predates the new debug/profiling return type. "
+            "Injecting a lightweight stub so Titan model imports can succeed."
+        )
+
+        from dataclasses import dataclass
+
+        import torch
+
+        @dataclass
+        class _AuxOutput:
+            attn_probs: torch.Tensor = torch.empty(0)
+            block_mask: torch.Tensor | None = None
+            stats: dict | None = None
+            extra: dict | None = None
+
+            # Explicit __init__ (it takes precedence over the dataclass-
+            # generated one) so arbitrary, version-specific keyword fields
+            # are tolerated.
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        setattr(flex_mod, "AuxOutput", _AuxOutput)
+        primus_logger.info(
+            "[PrimusPatch][FlexAttn] Injected fallback AuxOutput stub (Titan does not rely on this)."
+        )
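
Illustration only (not from the commit): the stub accepts arbitrary keyword fields, so any version-specific attribute access keeps working. The field value below is hypothetical:

    from torch.nn.attention.flex_attention import AuxOutput

    aux = AuxOutput(stats={"num_blocks": 4})  # kwargs are stored as attributes
    print(aux.stats)  # {'num_blocks': 4}
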
 
     def enable_primus_turbo_extension(self):
         """
         Enable Primus-Turbo features and extensions.
