Commit b416219

integrate quantized linear with the model
Signed-off-by: Alex Chi Z <[email protected]>
1 parent 2f2196d commit b416219

8 files changed, +140 −53 lines

README.md  (+6 −3)

@@ -39,9 +39,12 @@ You may join skyzh's Discord server and study with the tiny-llm community.
 | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 |
 | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 |
 | 3.3 | Prefill-Decode Separation | 🚧 | 🚧 | 🚧 |
-| 3.4 | Parallelism | 🚧 | 🚧 | 🚧 |
-| 3.5 | AI Agent | 🚧 | 🚧 | 🚧 |
-| 3.6 | Streaming API Server | 🚧 | 🚧 | 🚧 |
+| 3.4 | Scheduler | 🚧 | 🚧 | 🚧 |
+| 3.5 | Parallelism | 🚧 | 🚧 | 🚧 |
+| 3.6 | AI Agent | 🚧 | 🚧 | 🚧 |
+| 3.7 | Streaming API Server | 🚧 | 🚧 | 🚧 |
+
+Other topics not covered: quantized/compressed kv cache
 
 <!--

main.py  (+16 −8)

@@ -4,30 +4,38 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", type=str, default="Qwen/Qwen2-7B-Instruct-MLX")
-parser.add_argument("--prompt", type=str, default="Give me a short introduction to large language model.")
+parser.add_argument(
+    "--prompt",
+    type=str,
+    default="Give me a short introduction to large language model.",
+)
 parser.add_argument("--solution", type=str, default="tiny_llm")
+parser.add_argument("--device", type=str, default="gpu")
 args = parser.parse_args()
 
 if args.solution == "tiny_llm":
     from tiny_llm import Qwen2Model, simple_generate
+
     print("Using your tiny_llm solution")
 elif args.solution == "tiny_llm_week1_ref" or args.solution == "week1_ref":
     from tiny_llm_week1_ref import Qwen2Model, simple_generate
+
     print("Using tiny_llm_week1_ref solution")
 elif args.solution == "tiny_llm_week2_ref" or args.solution == "week2_ref":
     from tiny_llm_week2_ref import Qwen2Model, simple_generate
+
     print("Using tiny_llm_week2_ref solution")
 else:
     raise ValueError(f"Solution {args.solution} not supported")
 
-with mx.stream(mx.gpu):
-    mlx_model, tokenizer = load(
-        args.model,
-        tokenizer_config={"eos_token": "<|im_end|>"},
-        model_config={"tie_word_embeddings": False, "rope_traditional": True},
-    )
-    tiny_llm_model = Qwen2Model(mlx_model)
+mlx_model, tokenizer = load(
+    args.model,
+    tokenizer_config={"eos_token": "<|im_end|>"},
+    model_config={"tie_word_embeddings": False, "rope_traditional": True},
+)
 
+with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu):
+    tiny_llm_model = Qwen2Model(mlx_model)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": args.prompt},

pyproject.toml  (+1)

@@ -28,6 +28,7 @@ main.cmd = "python main.py"
 test.cmd = "pytest tests"
 test-week1-ref.cmd = "pytest tests_ref_impl_week1"
 test-week2-ref.cmd = "pytest tests_ref_impl_week2"
+format = "ruff format"
 
 [tool.pytest.ini_options]
 addopts = [
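
Note: the new `format` entry simply exposes Ruff's formatter next to the existing `main` and `test` commands. Assuming these entries are resolved by the project's usual task runner (e.g. something like `pdm run format`), it is equivalent to running `ruff format` at the repository root.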

src/extensions_ref/build.py  (+1 −1)

@@ -15,7 +15,7 @@
     cmd.initialize_options()
     cmd.build_temp = Path("build")
     cmd.build_lib = Path("build") / "lib"
-    cmd.inplace = False # we do the copy by ourselves
+    cmd.inplace = False  # we do the copy by ourselves
     cmd.ensure_finalized()
     cmd.run()
     for output in cmd.get_outputs():

src/tiny_llm_week2_ref/basics.py  (+46 −1)

@@ -1,5 +1,6 @@
 import mlx.core as mx
-import math
+from .quantize import quantized_matmul
+from typing import Any
 
 
 def softmax(x: mx.array, axis: int) -> mx.array:
@@ -18,5 +19,49 @@ def linear(
     return mx.matmul(x, w.T)
 
 
+class QuantizedWeights:
+    def __init__(
+        self,
+        scales: mx.array,
+        biases: mx.array,
+        group_size: int,
+        bits: int,
+        weight: mx.array,
+    ):
+        self.scales = scales
+        self.biases = biases
+        self.group_size = group_size
+        self.bits = bits
+        self.weight = weight
+
+    @staticmethod
+    def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights":
+        return QuantizedWeights(
+            scales=mlx_layer.scales,
+            biases=mlx_layer.biases,
+            group_size=mlx_layer.group_size,
+            bits=mlx_layer.bits,
+            weight=mlx_layer.weight,
+        )
+
+
+def quantized_linear(
+    x: mx.array,
+    w: QuantizedWeights,
+    bias: mx.array | None = None,
+) -> mx.array:
+    if bias is not None:
+        return (
+            quantized_matmul(
+                w.scales, w.biases, w.group_size, w.bits, x, w.weight, True
+            )
+            + bias
+        )
+    else:
+        return quantized_matmul(
+            w.scales, w.biases, w.group_size, w.bits, x, w.weight, True
+        )
+
+
 def silu(x: mx.array) -> mx.array:
     return x / (1 + mx.exp(-x))
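
Note: a quick sanity check of the new `quantized_linear` path is possible without loading the full model, assuming the week-2 extension (`tiny_llm_ext_ref`) is built and its kernel consumes MLX's standard quantization packing (which is what the model weights fed to it use). The shapes, group size, bit width, and fp16 precision below are hypothetical, chosen only to mirror the existing tests:

import mlx.core as mx
import numpy as np

from tiny_llm_week2_ref.basics import QuantizedWeights, quantized_linear

group_size, bits = 64, 4
w_fp = mx.array(np.random.randn(128, 64).astype(np.float16))  # (out_features, in_features)
w_q, scales, biases = mx.quantize(w_fp, group_size=group_size, bits=bits)
qw = QuantizedWeights(scales=scales, biases=biases, group_size=group_size, bits=bits, weight=w_q)

x = mx.array(np.random.randn(2, 8, 64).astype(np.float16))    # (batch, seq, in_features)
out = quantized_linear(x, qw)                                  # (2, 8, 128)

# Reference: dequantize and fall back to a plain matmul; the gap should be fp16/quantization noise.
w_ref = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
print(mx.abs(out - mx.matmul(x, w_ref.T)).max())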

src/tiny_llm_week2_ref/quantize.py  (+6 −1)

@@ -14,6 +14,7 @@ def dequantize_linear(mx_layer: Any) -> mx.array:
     )
     return w
 
+
 def quantized_matmul(
     scales: mx.array,
     biases: mx.array,
@@ -23,4 +24,8 @@ def quantized_matmul(
     b: mx.array,
     transpose_b: bool = False,
 ) -> mx.array:
-    return tiny_llm_ext_ref.quantized_matmul(scales, biases, group_size, bits, a, b, transpose_b)
+    *N, D = a.shape
+    a = a.reshape(-1, D)
+    return tiny_llm_ext_ref.quantized_matmul(
+        scales, biases, group_size, bits, a, b, transpose_b
+    ).reshape(*N, -1)
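
Note: the reshape wrapper makes `quantized_matmul` batch-agnostic; presumably the `tiny_llm_ext_ref` kernel only handles a 2-D activation matrix, so all leading dimensions are folded into one row axis and restored afterwards. A shape-only sketch of the same idea, with the kernel call stubbed out by a zero matrix:

import mlx.core as mx

x = mx.zeros((2, 7, 64))                      # (batch, seq, in_features)
*N, D = x.shape                               # N == [2, 7], D == 64
flat = x.reshape(-1, D)                       # (14, 64): what a 2-D kernel sees
kernel_out = mx.zeros((flat.shape[0], 128))   # stand-in for the extension's (rows, out_features) result
restored = kernel_out.reshape(*N, -1)         # back to (2, 7, 128)
print(restored.shape)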

src/tiny_llm_week2_ref/qwen2.py  (+53 −36)

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from .basics import linear, silu
+from .basics import linear, silu, QuantizedWeights, quantized_linear
 from .attention import scaled_dot_product_attention_grouped
 from .layer_norm import RMSNorm
 from .positional_encoding import RoPE
@@ -15,10 +15,10 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
-        wq: mx.array,
-        wk: mx.array,
-        wv: mx.array,
-        wo: mx.array,
+        wq: QuantizedWeights,
+        wk: QuantizedWeights,
+        wv: QuantizedWeights,
+        wo: QuantizedWeights,
         bq: mx.array,
         bk: mx.array,
         bv: mx.array,
@@ -52,13 +52,13 @@ def __call__(
         cache: TinyKvCache,
     ) -> mx.array:
         B, L, _ = x.shape
-        projection_q = linear(x, self.wq, bias=self.bq).reshape(
+        projection_q = quantized_linear(x, self.wq, bias=self.bq).reshape(
             B, L, self.num_heads, self.head_dim
         )
-        projection_k = linear(x, self.wk, bias=self.bk).reshape(
+        projection_k = quantized_linear(x, self.wk, bias=self.bk).reshape(
             B, L, self.num_kv_heads, self.head_dim
         )
-        projection_v = linear(x, self.wv, bias=self.bv).reshape(
+        projection_v = quantized_linear(x, self.wv, bias=self.bv).reshape(
             B, L, self.num_kv_heads, self.head_dim
         )
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
@@ -76,17 +76,17 @@ def __call__(
             scale=self.scale,
         ).astype(x.dtype)
         x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
-        return linear(x, self.wo)
+        return quantized_linear(x, self.wo)
 
 
 class Qwen2MLP:
     def __init__(
         self,
         dim: int,
         hidden_dim: int,
-        w_gate: mx.array,
-        w_up: mx.array,
-        w_down: mx.array,
+        w_gate: QuantizedWeights,
+        w_up: QuantizedWeights,
+        w_down: QuantizedWeights,
     ):
         self.dim = dim
         self.hidden_dim = hidden_dim
@@ -95,7 +95,10 @@ def __init__(
         self.w_down = w_down
 
     def __call__(self, x: mx.array) -> mx.array:
-        return linear(silu(linear(x, self.w_gate)) * linear(x, self.w_up), self.w_down)
+        return quantized_linear(
+            silu(quantized_linear(x, self.w_gate)) * quantized_linear(x, self.w_up),
+            self.w_down,
+        )
 
 
 class Qwen2TransformerBlock:
@@ -106,16 +109,16 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         rms_norm_eps: float,
-        wq: mx.array,
-        wk: mx.array,
-        wv: mx.array,
-        wo: mx.array,
+        wq: QuantizedWeights,
+        wk: QuantizedWeights,
+        wv: QuantizedWeights,
+        wo: QuantizedWeights,
         bq: mx.array,
         bk: mx.array,
         bv: mx.array,
-        w_gate: mx.array,
-        w_up: mx.array,
-        w_down: mx.array,
+        w_gate: QuantizedWeights,
+        w_up: QuantizedWeights,
+        w_down: QuantizedWeights,
         w_input_layernorm: mx.array,
         w_post_attention_layernorm: mx.array,
         max_seq_len: int = 32768,
@@ -175,30 +178,44 @@ def __init__(
         self.layers_inner = []
 
         for i in range(mlx_model.args.num_hidden_layers):
-            wq = dequantize_linear(mlx_model.model.layers[i].self_attn.q_proj)
-            wk = dequantize_linear(mlx_model.model.layers[i].self_attn.k_proj)
-            wv = dequantize_linear(mlx_model.model.layers[i].self_attn.v_proj)
-            wo = dequantize_linear(mlx_model.model.layers[i].self_attn.o_proj)
-            w_gate = dequantize_linear(mlx_model.model.layers[i].mlp.gate_proj)
-            w_up = dequantize_linear(mlx_model.model.layers[i].mlp.up_proj)
-            w_down = dequantize_linear(mlx_model.model.layers[i].mlp.down_proj)
+            wq = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.q_proj
+            )
+            wk = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.k_proj
+            )
+            wv = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.v_proj
+            )
+            wo = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].self_attn.o_proj
+            )
+            w_gate = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.gate_proj
+            )
+            w_up = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.up_proj
+            )
+            w_down = QuantizedWeights.from_mlx_layer(
+                mlx_model.model.layers[i].mlp.down_proj
+            )
 
             layer = Qwen2TransformerBlock(
                 num_attention_heads=mlx_model.args.num_attention_heads,
                 num_kv_heads=mlx_model.args.num_key_value_heads,
                 hidden_size=mlx_model.args.hidden_size,
                 intermediate_size=mlx_model.args.intermediate_size,
                 rms_norm_eps=mlx_model.args.rms_norm_eps,
-                wq=wq.astype(precision),
-                wk=wk.astype(precision),
-                wv=wv.astype(precision),
-                wo=wo.astype(precision),
+                wq=wq,
+                wk=wk,
+                wv=wv,
+                wo=wo,
                 bq=mlx_model.model.layers[i].self_attn.q_proj.bias.astype(precision),
                 bk=mlx_model.model.layers[i].self_attn.k_proj.bias.astype(precision),
                 bv=mlx_model.model.layers[i].self_attn.v_proj.bias.astype(precision),
-                w_gate=w_gate.astype(precision),
-                w_up=w_up.astype(precision),
-                w_down=w_down.astype(precision),
+                w_gate=w_gate,
+                w_up=w_up,
+                w_down=w_down,
                 w_input_layernorm=mlx_model.model.layers[
                     i
                 ].input_layernorm.weight.astype(precision),
@@ -214,7 +231,7 @@ def __init__(
             weight=mlx_model.model.norm.weight.astype(precision),
             eps=mlx_model.args.rms_norm_eps,
         )
-        self.w_lm_head = dequantize_linear(mlx_model.lm_head)
+        self.w_lm_head = QuantizedWeights.from_mlx_layer(mlx_model.lm_head)
         self.mlx_model = mlx_model
 
     def __call__(
@@ -227,4 +244,4 @@ def __call__(
         for layer in range(self.num_hidden_layers):
            h = self.layers_inner[layer](h, offset, cache[layer])
         h = self.norm(h)
-        return linear(h, self.w_lm_head)
+        return quantized_linear(h, self.w_lm_head)
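
Note: after this change only the large projection matrices travel through the model in packed quantized form; biases and layer-norm weights remain plain arrays cast to the working precision, as before. A minimal sketch of one projection, assuming the MLX layers are `mlx.nn.QuantizedLinear` instances (which expose the `weight`/`scales`/`biases`/`group_size`/`bits` attributes that `from_mlx_layer` copies) and that the week-2 extension is built; the layer sizes here are hypothetical:

import mlx.core as mx
import mlx.nn as nn

from tiny_llm_week2_ref.basics import QuantizedWeights, quantized_linear

# Hypothetical stand-in for one q_proj layer of the loaded model.
proj = nn.QuantizedLinear(64, 128, bias=True, group_size=64, bits=4)

wq = QuantizedWeights.from_mlx_layer(proj)  # packed weight plus quantization metadata
bq = proj.bias                              # the bias stays an ordinary mx.array

x = mx.random.normal((1, 8, 64))
y = quantized_linear(x, wq, bias=bq)        # (1, 8, 128), matching proj(x) up to kernel rounding
print(y.shape)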

tests/test_week_2_day_2.py  (+11 −3)

@@ -4,7 +4,10 @@
 import numpy as np
 from .utils import *
 
-def quantized_matmul_helper(stream: mx.Stream, identity_matrix: bool, precision: np.dtype):
+
+def quantized_matmul_helper(
+    stream: mx.Stream, identity_matrix: bool, precision: np.dtype
+):
     with mx.stream(stream):
         if identity_matrix:
             input = mx.array(np.eye(64).astype(precision))
@@ -32,5 +35,10 @@ def quantized_matmul_helper(stream: mx.Stream, identity_matrix: bool, precision:
         )
         assert_allclose(user_out, ref_out, precision)
 
-def test_task_1_quantized_matmul_f16_cpu():
-    quantized_matmul_helper(mx.cpu, True,np.float16)
+
+def test_task_1_quantized_matmul_simple_f16_cpu():
+    quantized_matmul_helper(mx.cpu, True, np.float16)
+
+
+def test_task_1_quantized_matmul_complex_f16_cpu():
+    quantized_matmul_helper(mx.cpu, False, np.float16)
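
Note: the renamed tests now cover both paths through the helper; `simple` feeds the 64×64 identity matrix, while `complex` exercises the non-identity branch. Either can be run in isolation with, for example, `pytest tests/test_week_2_day_2.py -k quantized_matmul`.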
