
Commit 1e2e1b1

Xiaoming-AMD, zhenhuang12, and xiaobochen-amd authored

refactor(torchtitan): rollback Titan to 99c0cb2 (20250907) and stabilize trainer UTs (#262)

Co-authored-by: zhenhuang12 <[email protected]>
Co-authored-by: xiaobochen-amd <[email protected]>
1 parent 700fa2e commit 1e2e1b1

File tree: 18 files changed (+331 / -144 lines)


.github/workflows/docker/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Base image
 # FROM docker.io/rocm/megatron-lm:v25.9_gfx942
-FROM docker.io/rocm/pyt-megatron-lm-jax-nightly-private:pytorch_rocm7.0_20251024
+FROM docker.io/rocm/primus:v25.9_gfx942
 
 # Specify the commit of Primus-Turbo when building: docker build --build-arg PRIMUS_TURBO_COMMIT=xxx .)
 ARG PRIMUS_TURBO_COMMIT

examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml

Lines changed: 4 additions & 0 deletions

@@ -71,6 +71,10 @@ modules:
   enable: true
   components: ["loss"] # ["model", "loss"]
 
+  primus_turbo:
+    enable_primus_turbo: true
+    enable_attention_float8: false
+
   # quantize:
   #   linear:
   #     float8:
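The new primus_turbo block gates the Primus-Turbo path for this model. A minimal sketch of how such flags can be consumed, assuming PyYAML and the nesting implied by the hunk header above (the actual Primus config loader is not part of this diff):

```python
# Hedged sketch: reads the flags added in the hunk above. Assumes PyYAML is
# installed and that `primus_turbo` sits directly under `modules:`; the real
# Primus config-loading code is not shown in this commit.
import yaml

with open("examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml") as f:
    cfg = yaml.safe_load(f)

turbo = cfg.get("modules", {}).get("primus_turbo", {})
if turbo.get("enable_primus_turbo", False):
    # When enabled, the converter in primus_turbo_converter.py (later in this
    # commit) can swap eligible attention modules for TurboAttention.
    fp8_attention = turbo.get("enable_attention_float8", False)
    print(f"Primus-Turbo enabled, float8 attention: {fp8_attention}")
```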

examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ modules:
   enable_async_tensor_parallel: false
   pipeline_parallel_degree: 1
   pipeline_parallel_schedule: "Interleaved1F1B"
-  expert_parallel_degree: 1
+  expert_parallel_degree: 8
   expert_tensor_parallel_degree: 1
 
   checkpoint:
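Raising expert_parallel_degree from 1 to 8 shards the routed MoE experts across eight EP ranks. A rough sanity check of the usual constraint, with an illustrative expert count that is not taken from the 671B config:

```python
# Illustrative only: expert parallelism splits the routed experts across EP
# ranks, so the degree must divide the expert count. The count below is a
# placeholder, not read from the config.
num_routed_experts = 256
expert_parallel_degree = 8  # value set by this diff

assert num_routed_experts % expert_parallel_degree == 0, "EP degree must divide expert count"
print(num_routed_experts // expert_parallel_degree, "experts per EP rank")
```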

primus/backends/torchtitan/models/deepseek_v3/__init__.py

Whitespace-only changes.

primus/backends/torchtitan/models/deepseek_v3/model/__init__.py

Whitespace-only changes.

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+import torch
+from torchtitan.models.deepseek_v3.model.model import Attention as TTAttention
+from torchtitan.models.deepseek_v3.model.model import apply_rotary_emb
+
+
+class Attention(TTAttention):
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+    ):
+        """
+        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+
+        Returns:
+            torch.Tensor: Output tensor with the same shape as the input.
+        """
+        bsz, seqlen, _ = x.size()
+
+        # Query projection
+        if self.q_lora_rank == 0:
+            q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
+        else:
+            q = self.wq_a(x)
+            q = self.wq_b(self.q_norm(q))
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of q and kv as TP may have sharded them after
+        # the above linear ops.
+        q = q.view(bsz, seqlen, -1, self.qk_head_dim)
+        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_pe = apply_rotary_emb(q_pe, freqs_cis)
+        q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)
+
+        # Key-value projection
+        kv = self.wkv_a(x)  # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
+        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)  # (bsz, seqlen, 1, qk_rope_head_dim)
+
+        kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
+        kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        k = torch.cat(
+            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
+        )  # (bsz, seqlen, n_heads, qk_head_dim)
+
+        q = q.view(bsz, seqlen, -1, self.qk_head_dim)
+        k = k.view(bsz, seqlen, -1, self.qk_head_dim)
+        v = v.view(bsz, seqlen, -1, self.v_head_dim)
+
+        output = self.sdpa(q, k, v)
+        output = output.view(bsz, seqlen, -1)
+        return self.wo(output)
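The forward above is largely shape bookkeeping around the MLA split of the query into a non-RoPE part and a RoPE part. A self-contained shape check of that split/concat pattern, using toy dimensions that are illustrative rather than DeepSeek-V3's real head sizes:

```python
# Standalone shape check for the MLA query split used above. All sizes are toy
# values; only the split/concat pattern mirrors the forward pass in this file.
import torch

bsz, seqlen, n_heads = 2, 4, 8
qk_nope_head_dim, qk_rope_head_dim = 32, 16
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

q = torch.randn(bsz, seqlen, n_heads * qk_head_dim)
q = q.view(bsz, seqlen, -1, qk_head_dim)  # infer local heads with -1
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
# (rotary embedding would be applied to q_pe here)
q = torch.cat([q_nope, q_pe], dim=-1)
assert q.shape == (bsz, seqlen, n_heads, qk_head_dim)
```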

primus/backends/torchtitan/models/llama3/model/model.py

Lines changed: 6 additions & 9 deletions

@@ -5,20 +5,16 @@
 ###############################################################################
 
 import torch
-from torch.nn.attention.flex_attention import BlockMask
+
+# from torch.nn.attention.flex_attention import BlockMask
 from torchtitan.models.llama3.model.model import Attention as TTAttention
 from torchtitan.models.llama3.model.model import apply_rotary_emb
 
-AttentionMasksType = dict[str, BlockMask] | BlockMask
+# AttentionMasksType = dict[str, BlockMask] | BlockMask
 
 
 class Attention(TTAttention):
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-        attention_masks: AttentionMasksType | None,
-    ):
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
@@ -35,7 +31,8 @@ def forward(
         # xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         # xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
 
-        output = self.inner_attention(xq, xk, xv)
+        # output = self.inner_attention(xq, xk, xv)
+        output = self.sdpa(xq, xk, xv)
 
         output = output.view(bs, seqlen, -1)
         return self.wo(output)
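The rolled-back forward routes through self.sdpa instead of inner_attention. Assuming that attribute wraps torch.nn.functional.scaled_dot_product_attention (the wrapper itself is not shown in this diff), a minimal sketch of what such a wrapper has to do with the (bs, seqlen, n_heads, head_dim) layout used here:

```python
# Hedged sketch of an SDPA-style wrapper; the real self.sdpa module comes from
# torchtitan / Primus-Turbo and is not reproduced in this commit.
import torch
import torch.nn.functional as F


class SDPAWrapper(torch.nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Inputs arrive as (bs, seqlen, n_heads, head_dim); SDPA expects
        # (bs, n_heads, seqlen, head_dim), so transpose in and back out.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.transpose(1, 2)  # back to (bs, seqlen, n_heads, head_dim)
```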

primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py

Lines changed: 2 additions & 5 deletions

@@ -7,10 +7,7 @@
 import torch
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.models.attention import (
-    FlexAttentionWrapper,
-    ScaledDotProductAttentionWrapper,
-)
+from torchtitan.models.attention import FlexAttention, ScaledDotProductAttention
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -21,7 +18,7 @@ def replace_turbo_attention_modules(model: torch.nn.Module, backend_type: str, u
     from primus_turbo.pytorch.modules import TurboAttention  # TODO: import Check
 
     for name, module in model.named_children():
-        if isinstance(module, (FlexAttentionWrapper, ScaledDotProductAttentionWrapper)):
+        if isinstance(module, (FlexAttention, ScaledDotProductAttention)):
             setattr(
                 model,
                 name,
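The converter walks the model and swaps matching attention modules in place via setattr. A self-contained sketch of that named_children + setattr replacement pattern, using stand-in modules since primus_turbo may not be installed (DummyAttention and Replacement are placeholders, not Primus-Turbo or torchtitan classes):

```python
# Hedged sketch of the replacement pattern used by replace_turbo_attention_modules.
import torch


class DummyAttention(torch.nn.Module):
    def forward(self, x):
        return x


class Replacement(torch.nn.Module):
    def forward(self, x):
        return x


def replace_modules(model: torch.nn.Module, target=DummyAttention, new_cls=Replacement) -> None:
    for name, module in model.named_children():
        if isinstance(module, target):
            setattr(model, name, new_cls())  # swap the child module in place
        else:
            replace_modules(module, target, new_cls)  # recurse into children


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = DummyAttention()
        self.mlp = torch.nn.Linear(8, 8)


block = Block()
replace_modules(block)
print(block)  # attn is now a Replacement module
```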

primus/configs/models/torchtitan/llama3.1_70B-fp8.yaml

Lines changed: 1 addition & 2 deletions

@@ -7,5 +7,4 @@ model:
   flavor: "70B"
   hf_assets_path: "meta-llama/Llama-3.1-8B"
   converters:
-    - quantize.linear.float8
-    - quantize.grouped_mm.float8
+    - "float8"

primus/configs/models/torchtitan/llama3.1_8B-fp8.yaml

Lines changed: 1 addition & 2 deletions

@@ -7,5 +7,4 @@ model:
   flavor: "8B"
   hf_assets_path: "meta-llama/Llama-3.1-8B"
   converters:
-    - quantize.linear.float8
-    - quantize.grouped_mm.float8
+    - "float8"
