@@ -1,5 +1,5 @@
 import mlx.core as mx
-from .basics import linear, silu
+from .basics import linear, silu, QuantizedWeights, quantized_linear
 from .attention import scaled_dot_product_attention_grouped
 from .layer_norm import RMSNorm
 from .positional_encoding import RoPE
@@ -15,10 +15,10 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
-        wq: mx.array,
-        wk: mx.array,
-        wv: mx.array,
-        wo: mx.array,
+        wq: QuantizedWeights,
+        wk: QuantizedWeights,
+        wv: QuantizedWeights,
+        wo: QuantizedWeights,
         bq: mx.array,
         bk: mx.array,
         bv: mx.array,
@@ -52,13 +52,13 @@ def __call__(
         cache: TinyKvCache,
     ) -> mx.array:
         B, L, _ = x.shape
-        projection_q = linear(x, self.wq, bias=self.bq).reshape(
+        projection_q = quantized_linear(x, self.wq, bias=self.bq).reshape(
             B, L, self.num_heads, self.head_dim
         )
-        projection_k = linear(x, self.wk, bias=self.bk).reshape(
+        projection_k = quantized_linear(x, self.wk, bias=self.bk).reshape(
             B, L, self.num_kv_heads, self.head_dim
         )
-        projection_v = linear(x, self.wv, bias=self.bv).reshape(
+        projection_v = quantized_linear(x, self.wv, bias=self.bv).reshape(
             B, L, self.num_kv_heads, self.head_dim
         )
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
@@ -76,17 +76,17 @@ def __call__(
             scale=self.scale,
         ).astype(x.dtype)
         x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
-        return linear(x, self.wo)
+        return quantized_linear(x, self.wo)


 class Qwen2MLP:
     def __init__(
         self,
         dim: int,
         hidden_dim: int,
-        w_gate: mx.array,
-        w_up: mx.array,
-        w_down: mx.array,
+        w_gate: QuantizedWeights,
+        w_up: QuantizedWeights,
+        w_down: QuantizedWeights,
     ):
         self.dim = dim
         self.hidden_dim = hidden_dim
@@ -95,7 +95,10 @@ def __init__(
         self.w_down = w_down

     def __call__(self, x: mx.array) -> mx.array:
-        return linear(silu(linear(x, self.w_gate)) * linear(x, self.w_up), self.w_down)
+        return quantized_linear(
+            silu(quantized_linear(x, self.w_gate)) * quantized_linear(x, self.w_up),
+            self.w_down,
+        )


 class Qwen2TransformerBlock:
@@ -106,16 +109,16 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         rms_norm_eps: float,
-        wq: mx.array,
-        wk: mx.array,
-        wv: mx.array,
-        wo: mx.array,
+        wq: QuantizedWeights,
+        wk: QuantizedWeights,
+        wv: QuantizedWeights,
+        wo: QuantizedWeights,
         bq: mx.array,
         bk: mx.array,
         bv: mx.array,
-        w_gate: mx.array,
-        w_up: mx.array,
-        w_down: mx.array,
+        w_gate: QuantizedWeights,
+        w_up: QuantizedWeights,
+        w_down: QuantizedWeights,
         w_input_layernorm: mx.array,
         w_post_attention_layernorm: mx.array,
         max_seq_len: int = 32768,
@@ -175,30 +178,44 @@ def __init__(
         self.layers_inner = []

         for i in range(mlx_model.args.num_hidden_layers):
-            wq = dequantize_linear(mlx_model.model.layers[i].self_attn.q_proj)
-            wk = dequantize_linear(mlx_model.model.layers[i].self_attn.k_proj)
-            wv = dequantize_linear(mlx_model.model.layers[i].self_attn.v_proj)
-            wo = dequantize_linear(mlx_model.model.layers[i].self_attn.o_proj)
-            w_gate = dequantize_linear(mlx_model.model.layers[i].mlp.gate_proj)
-            w_up = dequantize_linear(mlx_model.model.layers[i].mlp.up_proj)
-            w_down = dequantize_linear(mlx_model.model.layers[i].mlp.down_proj)
+            wq = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.q_proj
+            )
+            wk = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.k_proj
+            )
+            wv = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.v_proj
+            )
+            wo = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.o_proj
+            )
+            w_gate = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.gate_proj
+            )
+            w_up = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.up_proj
+            )
+            w_down = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.down_proj
+            )

             layer = Qwen2TransformerBlock(
                 num_attention_heads=mlx_model.args.num_attention_heads,
                 num_kv_heads=mlx_model.args.num_key_value_heads,
                 hidden_size=mlx_model.args.hidden_size,
                 intermediate_size=mlx_model.args.intermediate_size,
                 rms_norm_eps=mlx_model.args.rms_norm_eps,
-                wq=wq.astype(precision),
-                wk=wk.astype(precision),
-                wv=wv.astype(precision),
-                wo=wo.astype(precision),
+                wq=wq,
+                wk=wk,
+                wv=wv,
+                wo=wo,
                 bq=mlx_model.model.layers[i].self_attn.q_proj.bias.astype(precision),
                 bk=mlx_model.model.layers[i].self_attn.k_proj.bias.astype(precision),
                 bv=mlx_model.model.layers[i].self_attn.v_proj.bias.astype(precision),
-                w_gate=w_gate.astype(precision),
-                w_up=w_up.astype(precision),
-                w_down=w_down.astype(precision),
+                w_gate=w_gate,
+                w_up=w_up,
+                w_down=w_down,
                 w_input_layernorm=mlx_model.model.layers[
                     i
                 ].input_layernorm.weight.astype(precision),
@@ -214,7 +231,7 @@ def __init__(
             weight=mlx_model.model.norm.weight.astype(precision),
             eps=mlx_model.args.rms_norm_eps,
         )
-        self.w_lm_head = dequantize_linear(mlx_model.lm_head)
+        self.w_lm_head = QuantizedWeights.from_mlx_layer(mlx_model.lm_head)
         self.mlx_model = mlx_model

     def __call__(
@@ -227,4 +244,4 @@ def __call__(
         for layer in range(self.num_hidden_layers):
             h = self.layers_inner[layer](h, offset, cache[layer])
         h = self.norm(h)
-        return linear(h, self.w_lm_head)
+        return quantized_linear(h, self.w_lm_head)
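
This commit only changes the call sites; the new QuantizedWeights container and quantized_linear helper live in .basics, which is not part of the diff. The sketch below is an assumption about what they might look like, based on the attributes an MLX nn.QuantizedLinear layer carries and on mx.quantized_matmul; the field names and the body of from_mlx_layer are guesses, not the repository's code.

# Minimal sketch (assumed, not the repo's actual .basics implementation):
# keep the packed weight plus its quantization metadata together, and run
# the matmul with mx.quantized_matmul instead of dequantizing first.
from dataclasses import dataclass

import mlx.core as mx


@dataclass
class QuantizedWeights:
    weight: mx.array      # packed quantized weight matrix
    scales: mx.array      # per-group scales
    biases: mx.array      # per-group zero points
    group_size: int
    bits: int

    @classmethod
    def from_mlx_layer(cls, layer) -> "QuantizedWeights":
        # An MLX nn.QuantizedLinear layer exposes these attributes directly.
        return cls(
            weight=layer.weight,
            scales=layer.scales,
            biases=layer.biases,
            group_size=layer.group_size,
            bits=layer.bits,
        )


def quantized_linear(
    x: mx.array, w: QuantizedWeights, bias: mx.array | None = None
) -> mx.array:
    # The matmul runs on the still-packed weight, so no full-precision copy
    # of the matrix is materialized.
    y = mx.quantized_matmul(
        x,
        w.weight,
        scales=w.scales,
        biases=w.biases,
        transpose=True,
        group_size=w.group_size,
        bits=w.bits,
    )
    if bias is not None:
        y = y + bias
    return y

Compared with the removed dequantize_linear + astype(precision) path, the weights stay in their packed quantized form for the whole forward pass, which is the point of the change.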