Skip to content

Commit f90407c

Browse files
committed
pre-commit
1 parent ccdfa5c commit f90407c

1 file changed

Lines changed: 8 additions & 25 deletions

File tree

src/parallax/models/glm4_moe_lite.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66

77
import mlx.core as mx
88
from mlx_lm.models.base import scaled_dot_product_attention
9-
from mlx_lm.models.glm4_moe_lite import (
10-
Glm4MoeLiteAttention as MLXGLM4MoeLiteAttention,
11-
)
12-
from mlx_lm.models.glm4_moe_lite import (
13-
Glm4MoeLiteDecoderLayer as MLXGLM4MoeLiteBlock,
14-
)
9+
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteAttention as MLXGLM4MoeLiteAttention
10+
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer as MLXGLM4MoeLiteBlock
1511
from mlx_lm.models.glm4_moe_lite import ModelArgs
1612

1713
from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache
@@ -64,16 +60,11 @@ def __call__(
6460
else:
6561
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
6662

67-
q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(
68-
0, 2, 1, 3
69-
)
63+
q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
7064
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
7165
compressed_kv = self.kv_a_proj_with_mqa(x)
7266
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
73-
k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(
74-
0, 2, 1, 3
75-
)
76-
67+
k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
7768

7869
kv_latent = self.kv_a_layernorm(compressed_kv)
7970

@@ -140,15 +131,11 @@ def __call__(
140131
output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)
141132
else:
142133
# Prefill phase
143-
has_prefix_cache = prefix_lens is not None and bool(
144-
mx.any(prefix_lens > 0)
145-
)
134+
has_prefix_cache = prefix_lens is not None and bool(mx.any(prefix_lens > 0))
146135

147136
if has_prefix_cache:
148137
k_new = keys # (batch, 1, target_len, key_head_dim)
149-
v_new = values.transpose(
150-
0, 2, 1, 3
151-
) # (batch, 1, target_len, kv_lora_rank)
138+
v_new = values.transpose(0, 2, 1, 3) # (batch, 1, target_len, kv_lora_rank)
152139
output = compute_attention_with_prefix_cache(
153140
queries,
154141
k_new,
@@ -165,9 +152,7 @@ def __call__(
165152
# output: (batch, num_heads, target_len, kv_lora_rank)
166153
output = self.unembed_out(output)
167154
# output: (batch, num_heads, target_len, v_head_dim)
168-
output = output.transpose(0, 2, 1, 3).reshape(
169-
batch, target_len, -1
170-
)
155+
output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)
171156
else:
172157
# No prefix cache, standard self-attention
173158
if mask is not None:
@@ -184,9 +169,7 @@ def __call__(
184169
# output: (batch, num_heads, target_len, kv_lora_rank)
185170
output = self.unembed_out(output)
186171
# output: (batch, num_heads, target_len, v_head_dim)
187-
output = output.transpose(0, 2, 1, 3).reshape(
188-
batch, target_len, -1
189-
)
172+
output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)
190173

191174
return self.o_proj(output)
192175

0 commit comments

Comments
 (0)