Skip to content

Commit a78aed8

Browse files
committed
update
1 parent 6548c16 commit a78aed8

2 files changed

Lines changed: 56 additions & 5 deletions

File tree

src/parallax/models/deepseek_v32.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,10 @@ def __call__(
142142
compressed_kv = self.kv_a_proj_with_mqa(x)
143143
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
144144
k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
145-
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
146-
kv = kv.reshape(batch, target_len, self.num_heads, -1)
147-
148-
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
149-
k_nope = k_nope.transpose(0, 2, 1, 3)
145+
kv_latent = self.kv_a_layernorm(compressed_kv)
146+
kv_latent = kv_latent[:, None, :, :]
147+
k_nope = self.embed_q(kv_latent, transpose=False)
148+
values = self.unembed_out(kv_latent).transpose(0, 2, 1, 3)
150149
key_cache_global, value_cache_global = cache.get_cache()
151150
indexer_cache = cache.get_indexer_cache()
152151

tests/test_deepseek_v32.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import sys
2+
3+
import mlx.core as mx
4+
import pytest
5+
6+
from parallax.models.deepseek_v32 import ModelArgs, ParallaxDeepSeekV32Attention
7+
from parallax.server.cache.dsa_cache import DeepSeekSparseCache
8+
9+
# Module-level pytest marker: skip the tests in this file unless the
# interpreter reports sys.platform == "darwin" (reason: MLX needs macOS).
pytestmark = pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS")
10+
11+
12+
def _tiny_args():
    """Return a ``ModelArgs`` with deliberately small dimensions.

    The configuration uses a single hidden layer, two attention heads and a
    hidden size of 16 so the attention module under test is cheap to build.
    """
    tiny_config = {
        # Core transformer sizes.
        "hidden_size": 16,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        # Low-rank projection ranks and per-head dimensions.
        "q_lora_rank": 8,
        "kv_lora_rank": 4,
        "qk_nope_head_dim": 2,
        "qk_rope_head_dim": 2,
        "v_head_dim": 4,
        # Indexer settings.
        "index_head_dim": 4,
        "index_n_heads": 2,
        "index_topk": 4,
        # Model depth and maximum sequence length.
        "num_hidden_layers": 1,
        "max_position_embeddings": 16,
    }
    return ModelArgs(**tiny_config)
28+
29+
30+
def test_attention_decode_forward_uses_glm_style_kv_cache():
    """Smoke-test a single-token forward pass through the attention layer.

    Builds the attention module from the tiny configuration, gives it a
    ``DeepSeekSparseCache`` holding one block of eight slots, feeds an
    all-zero input of shape ``(1, 1, hidden_size)`` and checks that the
    output keeps that shape.
    """
    args = _tiny_args()
    layer = ParallaxDeepSeekV32Attention(args)

    # The cache's key head dimension is the sum of the non-rotary and rotary parts.
    key_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
    cache_options = {
        "num_blocks": 1,
        "block_size": 8,
        "num_kv_heads": args.num_key_value_heads,
        "head_dim": key_head_dim,
        "head_dim_v": args.v_head_dim,
        "dtype": mx.float32,
        "index_head_dim": args.index_head_dim,
        "index_n_heads": args.index_n_heads,
    }
    sparse_cache = DeepSeekSparseCache(**cache_options)

    # One sequence, one token, mapped to the only block in the cache.
    hidden_states = mx.zeros((1, 1, args.hidden_size), dtype=mx.float32)
    block_table = mx.array([[0]], dtype=mx.int32)
    lengths = mx.array([1], dtype=mx.int32)

    result = layer(
        hidden_states,
        cache=sparse_cache,
        block_tables=block_table,
        context_lengths=lengths,
    )
    # Force evaluation so any failure in the computation surfaces here.
    mx.eval(result)

    expected_shape = (1, 1, args.hidden_size)
    assert result.shape == expected_shape

0 commit comments

Comments
 (0)