Commit 75699ad
feat: implement non-absorbed MLAAttention layer (#911)
1 parent 7807050 · 3 files changed: 460 additions & 0 deletions

New file (196 additions & 0 deletions):
"""Multi-head Latent Attention (MLA) layer."""

from typing import Any

import jax
import jax.numpy as jnp
from flax import nnx

from sgl_jax.srt.layers.embeddings import get_rope
from sgl_jax.srt.layers.layernorm import RMSNorm
from sgl_jax.srt.layers.linear import LinearBase
from sgl_jax.srt.layers.radix_attention import RadixAttention
from sgl_jax.srt.mem_cache.memory_pool import KVCache
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch


class MLAAttention(nnx.Module):
    """Multi-head Latent Attention (non-absorbed mode).

    Decompresses the latent state into full per-head Q/K/V during the forward
    pass and reuses RadixAttention + MHATokenToKVPool. This uses ~43x more KV
    cache than absorbed mode; it will be replaced once the MLA Pallas kernel
    is production-ready.

    Data flow:
        Q path:   hidden -> q_a_proj -> norm -> q_b_proj -> split(q_nope, q_rope)
        KV path:  hidden -> kv_a_proj -> split(compressed, k_rope)
                  compressed -> norm -> kv_b_proj -> split(k_nope, v)
        RoPE:     applied only to q_rope and k_rope
        Assembly: Q = concat(q_nope, q_rope'), K = concat(k_nope, k_rope')
    """
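
    # Shape walkthrough for T tokens, assuming DeepSeek-V3-style dims purely
    # for illustration (num_heads=128, qk_nope_head_dim=128, qk_rope_head_dim=64,
    # v_head_dim=128, q_lora_rank=1536, kv_lora_rank=512):
    #   q_b_proj out:  (T, 128*192) -> (T, 128, 192), split 128 nope / 64 rope
    #   kv_a_proj out: (T, 512+64)  -> compressed (T, 512), k_rope (T, 64)
    #   kv_b_proj out: (T, 128*256) -> k_nope (T, 128, 128), v (T, 128, 128)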
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        mesh: jax.sharding.Mesh,
        layer_id: int = 0,
        rope_theta: float = 10000.0,
        rope_scaling: dict[str, Any] | None = None,
        rope_interleave: bool = True,
        max_position_embeddings: int = 163840,
        dtype: jnp.dtype = jnp.bfloat16,
    ):
        super().__init__()

        self.mesh = mesh
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
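
        # Q path: either one dense projection (q_lora_rank is None) or a
        # low-rank factorization q_a -> norm -> q_b, which cuts Q parameters
        # from hidden_size * num_heads * qk_head_dim down to roughly
        # q_lora_rank * (hidden_size + num_heads * qk_head_dim).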
        if q_lora_rank is None:
            self.q_proj = LinearBase(
                hidden_size,
                num_heads * self.qk_head_dim,
                mesh,
                use_bias=False,
                params_dtype=dtype,
                kernel_axes=(None, "tensor"),
                scope_name="q_proj",
            )
        else:
            self.q_a_proj = LinearBase(
                hidden_size,
                q_lora_rank,
                mesh,
                use_bias=False,
                params_dtype=dtype,
                kernel_axes=(None, None),
                scope_name="q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(q_lora_rank, param_dtype=jnp.float32)
            self.q_b_proj = LinearBase(
                q_lora_rank,
                num_heads * self.qk_head_dim,
                mesh,
                use_bias=False,
                params_dtype=dtype,
                kernel_axes=(None, "tensor"),
                scope_name="q_b_proj",
            )
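
        # KV path: kv_a_proj jointly emits the compressed KV latent
        # (kv_lora_rank) and the shared decoupled RoPE key (qk_rope_head_dim);
        # kv_b_proj later decompresses the latent into per-head k_nope and v.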
        self.kv_a_proj = LinearBase(
            hidden_size,
            kv_lora_rank + qk_rope_head_dim,
            mesh,
            use_bias=False,
            params_dtype=dtype,
            kernel_axes=(None, None),
            scope_name="kv_a_proj",
        )
        self.kv_a_layernorm = RMSNorm(kv_lora_rank, param_dtype=jnp.float32)
        self.kv_b_proj = LinearBase(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            mesh,
            use_bias=False,
            params_dtype=dtype,
            kernel_axes=(None, "tensor"),
            scope_name="kv_b_proj",
        )
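
        # The per-head projections above are column-parallel (kernel_axes
        # (None, "tensor")), so heads shard over the "tensor" mesh axis;
        # o_proj is row-parallel (("tensor", None)) and maps the sharded
        # head outputs back to the full hidden size.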
        self.o_proj = LinearBase(
            num_heads * v_head_dim,
            hidden_size,
            mesh,
            use_bias=False,
            params_dtype=dtype,
            kernel_axes=("tensor", None),
            scope_name="o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=int(rope_theta),
            is_neox_style=not rope_interleave,
            rope_scaling=rope_scaling,
            dtype=dtype,
        )
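
        # RadixAttention runs this as plain MHA: each head carries its own
        # decompressed K/V (num_kv_heads == num_heads) at head_dim ==
        # qk_head_dim; V is zero-padded to that width in __call__.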
        self.attn = RadixAttention(
            num_heads=num_heads,
            head_dim=self.qk_head_dim,
            scaling=self.qk_head_dim**-0.5,
            num_kv_heads=num_heads,
            layer_id=layer_id,
        )

    def __call__(
        self,
        positions: jax.Array,
        hidden_states: jax.Array,
        forward_batch: ForwardBatch,
        token_to_kv_pool: KVCache,
    ) -> tuple[jax.Array, jax.Array]:
        if self.q_lora_rank is None:
            q, _ = self.q_proj(hidden_states)
        else:
            q_compressed, _ = self.q_a_proj(hidden_states)
            q_compressed = self.q_a_layernorm(q_compressed)
            q, _ = self.q_b_proj(q_compressed)
        q = q.reshape(-1, self.num_heads, self.qk_head_dim)
        q_nope = q[:, :, : self.qk_nope_head_dim]
        q_rope = q[:, :, self.qk_nope_head_dim :]
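
        # KV path. In absorbed mode only the compressed latent (plus k_rope)
        # would be cached; here we decompress to per-head K/V and cache those,
        # hence the much larger KV pool noted in the class docstring.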
        kv_a_out, _ = self.kv_a_proj(hidden_states)
        compressed = kv_a_out[:, : self.kv_lora_rank]
        k_rope_raw = kv_a_out[:, self.kv_lora_rank :]

        compressed = self.kv_a_layernorm(compressed)
        kv_out, _ = self.kv_b_proj(compressed)
        kv_out = kv_out.reshape(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv_out[:, :, : self.qk_nope_head_dim]
        v = kv_out[:, :, self.qk_nope_head_dim :]

        # Pad V to qk_head_dim to match K, required by fused MHATokenToKVPool.
        v = jnp.pad(v, ((0, 0), (0, 0), (0, self.qk_head_dim - self.v_head_dim)))

        k_rope = k_rope_raw.reshape(-1, 1, self.qk_rope_head_dim)
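
        # RoPE rotates q_rope per head and the single shared k_rope; the
        # rotated k_rope is then replicated across all heads to build full K.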
        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        k_rope = jnp.broadcast_to(
            k_rope,
            (k_rope.shape[0], self.num_heads, self.qk_rope_head_dim),
            out_sharding=jax.sharding.PartitionSpec(None, "tensor", None),
        )

        q = jnp.concatenate([q_nope, q_rope], axis=-1)
        k = jnp.concatenate([k_nope, k_rope], axis=-1)

        attn_output, kv_fused = self.attn(
            q,
            k,
            v,
            forward_batch=forward_batch,
            token_to_kv_pool=token_to_kv_pool,
        )
        # Strip V padding: o_proj expects num_heads * v_head_dim.
        attn_output = attn_output.reshape(-1, self.num_heads, self.qk_head_dim)
        attn_output = attn_output[:, :, : self.v_head_dim].reshape(
            -1, self.num_heads * self.v_head_dim
        )

        output, _ = self.o_proj(attn_output)
        return output, kv_fused
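
For intuition, here is a minimal, self-contained sketch of the non-absorbed
data flow above in plain jax.numpy. The dimensions are scaled-down stand-ins
(hidden_size, num_heads, and q_lora_rank are shrunk; the per-head dims follow
DeepSeek-V3), the random weights are placeholders, and RMSNorm, RoPE, the
causal mask, sharding, and the KV cache are all omitted; it demonstrates only
the decompress-then-attend shape bookkeeping, not this layer's API.

import jax
import jax.numpy as jnp

T, H = 4, 1024                        # tokens, hidden_size (scaled down)
n, d_nope, d_rope, d_v = 8, 128, 64, 128
r_q, r_kv = 384, 512                  # q_lora_rank, kv_lora_rank
d_qk = d_nope + d_rope                # qk_head_dim = 192

ks = jax.random.split(jax.random.key(0), 6)
x = jax.random.normal(ks[0], (T, H))
w_qa = 0.02 * jax.random.normal(ks[1], (H, r_q))
w_qb = 0.02 * jax.random.normal(ks[2], (r_q, n * d_qk))
w_kva = 0.02 * jax.random.normal(ks[3], (H, r_kv + d_rope))
w_kvb = 0.02 * jax.random.normal(ks[4], (r_kv, n * (d_nope + d_v)))
w_o = 0.02 * jax.random.normal(ks[5], (n * d_v, H))

# Q path: hidden -> q_a -> q_b -> split(q_nope, q_rope); norms omitted.
q = (x @ w_qa @ w_qb).reshape(T, n, d_qk)
q_nope, q_rope = q[..., :d_nope], q[..., d_nope:]

# KV path: hidden -> kv_a -> split(latent, shared k_rope); kv_b decompresses.
kv_a = x @ w_kva
latent, k_rope = kv_a[:, :r_kv], kv_a[:, r_kv:]
kv = (latent @ w_kvb).reshape(T, n, d_nope + d_v)
k_nope, v = kv[..., :d_nope], kv[..., d_nope:]

# Assemble full Q/K (RoPE would rotate q_rope/k_rope first) and attend.
k_rope = jnp.broadcast_to(k_rope[:, None, :], (T, n, d_rope))
q_full = jnp.concatenate([q_nope, q_rope], axis=-1)
k_full = jnp.concatenate([k_nope, k_rope], axis=-1)
scores = jnp.einsum("qhd,khd->hqk", q_full, k_full) * d_qk**-0.5
out = jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)
print((out.reshape(T, n * d_v) @ w_o).shape)  # (4, 1024) == (T, hidden_size)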
