66
77import mlx .core as mx
88from mlx_lm .models .base import scaled_dot_product_attention
9- from mlx_lm .models .glm4_moe_lite import (
10- Glm4MoeLiteAttention as MLXGLM4MoeLiteAttention ,
11- )
12- from mlx_lm .models .glm4_moe_lite import (
13- Glm4MoeLiteDecoderLayer as MLXGLM4MoeLiteBlock ,
14- )
9+ from mlx_lm .models .glm4_moe_lite import Glm4MoeLiteAttention as MLXGLM4MoeLiteAttention
10+ from mlx_lm .models .glm4_moe_lite import Glm4MoeLiteDecoderLayer as MLXGLM4MoeLiteBlock
1511from mlx_lm .models .glm4_moe_lite import ModelArgs
1612
1713from parallax .metal .paged_attention .kernel import paged_attention , reshape_and_cache
@@ -64,16 +60,11 @@ def __call__(
6460 else :
6561 q = self .q_b_proj (self .q_a_layernorm (self .q_a_proj (x )))
6662
67- q = q .reshape (batch , target_len , self .num_heads , self .q_head_dim ).transpose (
68- 0 , 2 , 1 , 3
69- )
63+ q = q .reshape (batch , target_len , self .num_heads , self .q_head_dim ).transpose (0 , 2 , 1 , 3 )
7064 q_nope , q_pe = mx .split (q , [self .qk_nope_head_dim ], axis = - 1 )
7165 compressed_kv = self .kv_a_proj_with_mqa (x )
7266 compressed_kv , k_pe = mx .split (compressed_kv , [self .kv_lora_rank ], axis = - 1 )
73- k_pe = k_pe .reshape (batch , target_len , 1 , self .qk_rope_head_dim ).transpose (
74- 0 , 2 , 1 , 3
75- )
76-
67+ k_pe = k_pe .reshape (batch , target_len , 1 , self .qk_rope_head_dim ).transpose (0 , 2 , 1 , 3 )
7768
7869 kv_latent = self .kv_a_layernorm (compressed_kv )
7970
@@ -140,15 +131,11 @@ def __call__(
140131 output = output .transpose (0 , 2 , 1 , 3 ).reshape (batch , target_len , - 1 )
141132 else :
142133 # Prefill phase
143- has_prefix_cache = prefix_lens is not None and bool (
144- mx .any (prefix_lens > 0 )
145- )
134+ has_prefix_cache = prefix_lens is not None and bool (mx .any (prefix_lens > 0 ))
146135
147136 if has_prefix_cache :
148137 k_new = keys # (batch, 1, target_len, key_head_dim)
149- v_new = values .transpose (
150- 0 , 2 , 1 , 3
151- ) # (batch, 1, target_len, kv_lora_rank)
138+ v_new = values .transpose (0 , 2 , 1 , 3 ) # (batch, 1, target_len, kv_lora_rank)
152139 output = compute_attention_with_prefix_cache (
153140 queries ,
154141 k_new ,
@@ -165,9 +152,7 @@ def __call__(
165152 # output: (batch, num_heads, target_len, kv_lora_rank)
166153 output = self .unembed_out (output )
167154 # output: (batch, num_heads, target_len, v_head_dim)
168- output = output .transpose (0 , 2 , 1 , 3 ).reshape (
169- batch , target_len , - 1
170- )
155+ output = output .transpose (0 , 2 , 1 , 3 ).reshape (batch , target_len , - 1 )
171156 else :
172157 # No prefix cache, standard self-attention
173158 if mask is not None :
@@ -184,9 +169,7 @@ def __call__(
184169 # output: (batch, num_heads, target_len, kv_lora_rank)
185170 output = self .unembed_out (output )
186171 # output: (batch, num_heads, target_len, v_head_dim)
187- output = output .transpose (0 , 2 , 1 , 3 ).reshape (
188- batch , target_len , - 1
189- )
172+ output = output .transpose (0 , 2 , 1 , 3 ).reshape (batch , target_len , - 1 )
190173
191174 return self .o_proj (output )
192175
0 commit comments