
Commit 235f8be

implement quantized matmul
Signed-off-by: Alex Chi Z <[email protected]>
1 parent cc36d98 · commit 235f8be

19 files changed: +178, -141 lines changed

.vscode/settings.json (-3)

@@ -1,6 +1,3 @@
 {
-    "clangd.arguments": [
-        "--compile-commands-dir=${workspaceFolder}/src/extensions_ref/build/temp.macosx-15.0-arm64-cpython-312/tiny_llm_ext_ref._ext"
-    ],
     "cmake.ignoreCMakeListsMissing": true
 }

book/src/setup.md (+3 -3)

@@ -36,15 +36,15 @@ pdm install -v # this will automatically create a virtual environment and instal
 ```bash
 pdm run python check.py
 # The reference solution should pass all the tests
-pdm run pytest tests_ref_impl_week1
+pdm run test_ref_impl_week1
 ```
 
 ## Run Unit Tests
 
 Your code is in `src/tiny_llm`. You can run the unit tests with:
 
 ```bash
-pdm run pytest tests
+pdm run test
 ```
 
 ## Download the Model Parameters
@@ -70,7 +70,7 @@ huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX
 Then, you can run:
 
 ```bash
-pdm run python main_ref_impl_week1.py
+pdm run main --solution week1_ref
 ```
 
 It should load the model and print some text.

book/src/week1-01-attention.md (+4 -4)

@@ -25,7 +25,7 @@ we will pass a tensor of the shape `N.. x 1024 x 512` to the attention layer.
 In this task, we will implement the scaled dot product attention function.
 
 ```
-pdm run pytest tests -k week_1_day_1_task_1 -v
+pdm run test -k week_1_day_1_task_1 -v
 ```
 
 
@@ -66,8 +66,8 @@ mask: 1 x H x L x L
 At the end of this task, you should be able to pass the following tests:
 
 ```
-pdm run pytest tests -k test_attention_simple
-pdm run pytest tests -k test_attention_with_mask
+pdm run test -k test_attention_simple
+pdm run test -k test_attention_with_mask
 ```
 
 ## Task 2: Implement `MultiHeadAttention`
@@ -115,7 +115,7 @@ W_o: (H x D) x E
 At the end of the day, you should be able to pass the following tests:
 
 ```
-pdm run pytest tests -k week_1_day_1_task_2 -v
+pdm run test -k week_1_day_1_task_2 -v
 ```
 
 {{#include copyright.md}}

book/src/week1-02-positional-encodings.md (+2 -2)

@@ -52,7 +52,7 @@ You can do this by reshaping `x` to (N, L, H, D // 2, 2) and then applying the a
 You can test your implementation by running the following command:
 
 ```
-pdm run pytest tests -k week_1_day_2_task_1 -v
+pdm run test -k week_1_day_2_task_1 -v
 ```
 
 ## Task 2: Implement `RoPE` in the non-traditional form
@@ -74,7 +74,7 @@ frequencies to each half separately.
 You can test your implementation by running the following command:
 
 ```
-pdm run pytest tests -k week_1_day_2_task_2 -v
+pdm run test -k week_1_day_2_task_2 -v
 ```
 
 **📚 Readings**

book/src/week2-overview.md (+5)

@@ -4,3 +4,8 @@ MLX uses INT4 W4A16
 https://ml-explore.github.io/mlx/build/html/dev/extensions.html
 
 pdm run ./build_ext.sh
+
+speculative decoding
+prefill and decode separation
+quantized kv cache
+Assert return data type

build_ext.sh (-4)

This file was deleted.

main.py (+21 -4)

@@ -1,19 +1,36 @@
 from mlx_lm import load
-from tiny_llm import Qwen2Model, simple_generate
 import mlx.core as mx
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", type=str, default="Qwen/Qwen2-7B-Instruct-MLX")
+parser.add_argument("--prompt", type=str, default="Give me a short introduction to large language model.")
+parser.add_argument("--solution", type=str, default="tiny_llm")
+args = parser.parse_args()
+
+if args.solution == "tiny_llm":
+    from tiny_llm import Qwen2Model, simple_generate
+    print("Using your tiny_llm solution")
+elif args.solution == "tiny_llm_week1_ref" or args.solution == "week1_ref":
+    from tiny_llm_week1_ref import Qwen2Model, simple_generate
+    print("Using tiny_llm_week1_ref solution")
+elif args.solution == "tiny_llm_week2_ref" or args.solution == "week2_ref":
+    from tiny_llm_week2_ref import Qwen2Model, simple_generate
+    print("Using tiny_llm_week2_ref solution")
+else:
+    raise ValueError(f"Solution {args.solution} not supported")
 
 with mx.stream(mx.gpu):
     mlx_model, tokenizer = load(
-        "Qwen/Qwen2-7B-Instruct-MLX",
+        args.model,
         tokenizer_config={"eos_token": "<|im_end|>"},
         model_config={"tie_word_embeddings": False, "rope_traditional": True},
     )
     tiny_llm_model = Qwen2Model(mlx_model)
 
-    prompt = "Give me a short introduction to large language model."
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt},
+        {"role": "user", "content": args.prompt},
     ]
     prompt = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True

main_ref_impl_week1.py (-21)

This file was deleted.

main_ref_impl_week2.py (-21)

This file was deleted.

pyproject.toml (+10 -2)

@@ -1,6 +1,6 @@
 [build-system]
-requires = ["setuptools>=62", "cmake>=3.25", "mlx>=0.25.0", "nanobind==2.4.0"]
-build-backend = "setuptools.build_meta"
+requires = ["pdm-backend"]
+build-backend = "pdm.backend"
 
 [project]
 name = "tiny-llm"
@@ -21,6 +21,14 @@ dependencies = [
     "nanobind==2.4.0"
 ]
 
+[tool.pdm.scripts]
+build-ext-ref.cmd = "python build.py"
+build-ext-ref.working_dir = "src/extensions_ref"
+main.cmd = "python main.py"
+test.cmd = "pytest tests"
+test-week1-ref.cmd = "pytest tests_ref_impl_week1"
+test-week2-ref.cmd = "pytest tests_ref_impl_week2"
+
 [tool.pytest.ini_options]
 addopts = [
     "--import-mode=importlib",

src/extensions_ref/.clangd (+2)

@@ -0,0 +1,2 @@
+CompileFlags:
+  CompilationDatabase: build/tiny_llm_ext_ref._ext

src/extensions_ref/CMakeLists.txt (+1 -1)

@@ -36,7 +36,7 @@ target_sources(
     tiny_llm_ext_ref
     PUBLIC
     ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
-    # ${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.cpp
 )
 
 # Add include headers

src/extensions_ref/bindings.cpp (+18 -20)

@@ -30,24 +30,22 @@ NB_MODULE(_ext, m) {
             array: ``alpha * x + beta * y``
     )");
 
-    // m.def("quantized_linear", &tiny_llm_ext_ref::quantized_linear, "scales"_a, "biases"_a, "group_size"_a, "bits"_a,
-    //       "x"_a, "w"_a, "bias"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(),
-    //       R"(
-    //         Quantized linear layer
-
-    //         Follows numpy style broadcasting between ``x`` and ``w``
-    //         Inputs are upcasted to floats if needed
-
-    //         Args:
-    //             scales (array): Scaling factors for ``x``.
-    //             biases (array): Biases for ``x``.
-    //             group_size (int): Group size for ``x``.
-    //             bits (int): Number of bits for ``x``.
-    //             x (array): Input array.
-    //             w (array): Input array.
-    //             bias (array): Input array.
-
-    //         Returns:
-    //             array: ``x * w + bias``
-    //       )");
+    m.def("quantized_matmul", &tiny_llm_ext_ref::quantized_matmul,
+          "scales"_a, "biases"_a, "group_size"_a, "bits"_a,
+          "a"_a, "b"_a, "transpose_b"_a = false, "stream"_a = nb::none(),
+          R"(
+            Quantized matmul layer
+
+            Args:
+                scales (array): Scaling factors for ``a``.
+                biases (array): Biases for ``a``.
+                group_size (int): Group size for ``a``.
+                bits (int): Number of bits for ``a``.
+                a (array): Input array.
+                b (array): Input array.
+                transpose_b (bool): Whether to transpose ``b`` before multiplication.
+
+            Returns:
+                array: ``a * b``
+    )");
 }
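
A minimal sketch (not part of this commit) of how the new binding might be exercised from Python once the extension is built, e.g. via `pdm run build-ext-ref`. The import path `tiny_llm_ext_ref._ext`, the use of a CPU stream, and passing `transpose_b=True` for MLX's transposed weight layout are assumptions based on this diff, not verified behavior:

```python
import mlx.core as mx
from tiny_llm_ext_ref import _ext  # assumed import path for the built extension

group_size, bits = 64, 4  # the only combination the reference CPU kernel accepts
a = mx.random.normal((4, 128)).astype(mx.float16)   # activations, M x N
w = mx.random.normal((64, 128)).astype(mx.float16)  # weights, K x N
# mx.quantize packs eight 4-bit values per uint32 and returns per-group scales/biases
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

with mx.stream(mx.cpu):  # the diff only shows a CPU kernel, so stay on the CPU stream
    out = _ext.quantized_matmul(scales, biases, group_size, bits, a, w_q, transpose_b=True)
    ref = mx.quantized_matmul(a, w_q, scales, biases, transpose=True,
                              group_size=group_size, bits=bits)
    print(mx.abs(out - ref).max())  # expect only a small fp16-level difference
```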

src/extensions_ref/build.py (+25)

@@ -0,0 +1,25 @@
+from pathlib import Path
+import shutil
+from mlx import extension
+from setuptools import Distribution
+
+if __name__ == "__main__":
+    src_dir = Path(__file__).parent
+    distribution = Distribution(
+        {
+            "name": "tiny_llm_ext_ref",
+            "ext_modules": [extension.CMakeExtension("tiny_llm_ext_ref._ext")],
+        }
+    )
+    cmd = extension.CMakeBuild(distribution)
+    cmd.initialize_options()
+    cmd.build_temp = Path("build")
+    cmd.build_lib = Path("build") / "lib"
+    cmd.inplace = False  # we do the copy by ourselves
+    cmd.ensure_finalized()
+    cmd.run()
+    for output in cmd.get_outputs():
+        output = Path(output)
+        relative_extension = src_dir / output.relative_to(cmd.build_lib)
+        shutil.copyfile(output, relative_extension)
+        print(f"Copied {output} to {relative_extension}")

src/extensions_ref/setup.py (-12)

This file was deleted.

src/extensions_ref/src/quantized_matmul.cpp (+63 -10)

@@ -1,3 +1,6 @@
+#include <arm_fp16.h>
+
+#include <cstdint>
 #include <iostream>
 #include <sstream>
 
@@ -37,6 +40,12 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
     if (b.shape().size() != 2) {
        throw std::runtime_error("quantized_matmul: b must be a 2D array");
     }
+    if (bits != 4) {
+        throw std::runtime_error("quantized_matmul: bits must be 4");
+    }
+    if (group_size != 64) {
+        throw std::runtime_error("quantized_matmul: group_size must be 64");
+    }
     auto out_shape = a.shape();
     if (out_shape.size() != 2) {
         throw std::runtime_error("quantized_matmul: a must be a 2D array");
@@ -64,17 +73,61 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
     encoder.set_input_array(b);
     encoder.set_output_array(out);
 
-    // Launch the CPU kernel
-    encoder.dispatch([a_ptr = a.data<uint32_t>(), a_shape = a.shape(), a_strides = a.strides(),
-                      b_ptr = b.data<float16_t>(), b_shape = b.shape(), b_strides = b.strides(),
-                      out_ptr = out.data<float16_t>(), scales_ptr = scales.data<float16_t>(),
-                      scales_shape = scales.shape(), scales_strides = scales.strides(),
-                      biases_ptr = biases.data<float16_t>(), biases_shape = biases.shape(),
-                      biases_strides = biases.strides(), group_size, bits]() {
-        int M = a_shape[0];
-        int N = a_shape[1];
-        int K = b_shape[0]; // because we transposed b
+    if (scales.shape() != biases.shape()) {
+        throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+    }
+    if (b.shape()[0] != scales.shape()[0]) {
+        throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
+    }
+    if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
+        throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales");
+    }
 
+    // Launch the CPU kernel
+    encoder.dispatch([out_ptr = out.data<float16_t>(), out_shape = out.shape(), out_strides = out.strides(),
+                      a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b),
+                      scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() {
+        int M = a.shape()[0];
+        int N = a.shape()[1];
+        int K = b.shape()[0];
+        const int group_size = 64;
+        const int bits = 4;
+        const int group_per_row = N / group_size;
+        const float16_t *a_ptr = a.data<float16_t>();
+        const uint32_t *b_ptr = b.data<uint32_t>();
+        const float16_t *scales_ptr = scales.data<float16_t>();
+        const float16_t *biases_ptr = biases.data<float16_t>();
+        uint32_t item_mask = (1 << bits) - 1;
+        for (int i = 0; i < M; i++) {
+            for (int k = 0; k < K; k++) {
+                for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
+                    int64_t scales_loc =
+                        mx::elem_to_loc(k * N / group_size + group_idx, scales.shape(), scales.strides());
+                    int64_t biases_loc =
+                        mx::elem_to_loc(k * N / group_size + group_idx, biases.shape(), biases.strides());
+                    float16_t sum = 0;
+                    float16_t scale = scales_ptr[scales_loc];
+                    float16_t bias = biases_ptr[biases_loc];
+                    const int packs_per_item = 32 / bits;
+                    for (int item_idx = 0; item_idx < group_size; item_idx += packs_per_item) {
+                        int64_t b_loc =
+                            mx::elem_to_loc((k * N + group_idx * group_size + item_idx) / 8, b.shape(), b.strides());
+                        uint32_t b_val = b_ptr[b_loc];
+                        uint8_t *b_bytes = reinterpret_cast<uint8_t *>(&b_val);
+                        for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
+                            int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size + item_idx + pack_idx,
                                                            a.shape(), a.strides());
+                            uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
+                            float16_t b = static_cast<float16_t>(item_val) * scale + bias;
+                            float16_t a = a_ptr[a_loc];
+                            sum += a * b;
+                        }
+                    }
+                    int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
+                    out_ptr[out_loc] = sum;
+                }
+            }
+        }
     });
 }
 
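
The kernel above dequantizes `b` on the fly: each `uint32` packs eight 4-bit values, every group of 64 values along a row shares one scale/bias pair, and a weight is recovered as `q * scale + bias`. The sketch below (not part of the commit) mirrors that unpacking in plain Python and checks one row against `mx.dequantize`; it assumes MLX's default low-bits-first packing for 4-bit, group-size-64 quantization:

```python
import mlx.core as mx

group_size, bits = 64, 4
w = mx.random.normal((1, 128)).astype(mx.float16)   # one row, two groups of 64
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

packed = w_q[0].tolist()    # 128 / 8 = 16 uint32 words, eight 4-bit items each
scale = scales[0].tolist()  # one scale per group of 64 values
bias = biases[0].tolist()   # one bias per group of 64 values
mask = (1 << bits) - 1

recovered = []
for n in range(128):
    q = (packed[n // 8] >> ((n % 8) * bits)) & mask  # unpack the n-th 4-bit item
    g = n // group_size                              # group index along the row
    recovered.append(q * scale[g] + bias[g])         # same formula as the kernel

ref = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)[0].tolist()
print(max(abs(r - x) for r, x in zip(recovered, ref)))  # should be ~0 (fp16 rounding)
```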