
Commit e539dc7

fix: fix torchtitan training issue in TurboAttention (#253)
1 parent: 9ee2c51

3 files changed: 8 additions, 4 deletions

primus/backends/torchtitan/models/llama3/model/__init__.py

Whitespace-only changes.

primus/backends/torchtitan/models/llama3/model.py renamed to primus/backends/torchtitan/models/llama3/model/model.py

Lines changed: 5 additions & 1 deletion

@@ -5,15 +5,19 @@
 ###############################################################################
 
 import torch
+from torch.nn.attention.flex_attention import BlockMask
 from torchtitan.models.llama3.model.model import Attention as TTAttention
 from torchtitan.models.llama3.model.model import apply_rotary_emb
 
+AttentionMasksType = dict[str, BlockMask] | BlockMask
+
 
 class Attention(TTAttention):
     def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -31,7 +35,7 @@ def forward(
         # xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
         # xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
 
-        output = self.sdpa(xq, xk, xv)
+        output = self.inner_attention(xq, xk, xv)
 
         output = output.view(bs, seqlen, -1)
         return self.wo(output)
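
The functional change here is the added attention_masks argument and the switch from self.sdpa to self.inner_attention, which brings the override back in line with the current torchtitan Attention interface. A minimal sketch of a sanity check follows, assuming both torchtitan and Primus are importable in the environment; this check is illustrative and not part of the commit:

# Illustrative sanity check: the Primus override should accept the same forward
# arguments as the torchtitan class it monkey-patches over.
import inspect

from torchtitan.models.llama3.model.model import Attention as TTAttention
from primus.backends.torchtitan.models.llama3.model.model import Attention

tt_params = list(inspect.signature(TTAttention.forward).parameters)
primus_params = list(inspect.signature(Attention.forward).parameters)

# Per the diff above, the override exposes (self, x, freqs_cis, attention_masks);
# a mismatch with the base class would point to another interface drift.
print("torchtitan :", tt_params)
print("primus     :", primus_params)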

primus/modules/trainer/torchtitan/pre_trainer.py

Lines changed: 3 additions & 3 deletions

@@ -239,11 +239,11 @@ def enable_primus_turbo_extension(self):
 
         if self.titan_config.primus_turbo.use_turbo_attention:
             # ******* llama3 Attention Model *******
-            import torchtitan.models.llama3.model
+            import torchtitan.models.llama3.model.model
 
-            from primus.backends.torchtitan.models.llama3.model import Attention
+            from primus.backends.torchtitan.models.llama3.model.model import Attention
 
-            torchtitan.models.llama3.model.Attention = Attention
+            torchtitan.models.llama3.model.model.Attention = Attention
             logger.warning(f"TorchtitanPretrainTrainer: Patch Turbo Attention")
 
         if self.titan_config.primus_turbo.use_turbo_mx_linear:
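
Because the class now lives in the leaf module model.model on both the torchtitan and Primus sides, the trainer has to rebind the attribute on that leaf module; patching only the parent package would generally leave torchtitan's own references to the class untouched. A toy sketch of this module-attribute patching pattern is below, with illustrative names that are not part of either project:

# Toy illustration of module-level monkey-patching; module and class names here
# are stand-ins, not torchtitan or Primus APIs.
import types

leaf = types.ModuleType("toy.model.model")


class BaseAttention:                  # stands in for torchtitan's Attention
    def kind(self):
        return "original"


class TurboAttention(BaseAttention):  # stands in for the Primus override
    def kind(self):
        return "turbo"


leaf.Attention = BaseAttention

# A name bound before the patch keeps pointing at the original class...
early_binding = leaf.Attention

# ...while lookups that go through the module after the patch see the override,
# which is why the trainer applies the patch before the model is constructed.
leaf.Attention = TurboAttention

assert early_binding().kind() == "original"
assert leaf.Attention().kind() == "turbo"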
