Commit a8eaa52 (1 parent: b416219)

implement matrix kernel; wait for mlx bug fix

Signed-off-by: Alex Chi <[email protected]>

File tree: 12 files changed, +181 -26 lines

book/src/week2-overview.md (+2)

@@ -2,6 +2,8 @@ https://github.com/ml-explore/mlx/blob/main/mlx/backend/cpu/quantized.cpp
 https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
 MLX uses INT4 W4A16
 https://ml-explore.github.io/mlx/build/html/dev/extensions.html
+https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-metal/ggml-metal.metal
+https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h#L962

 pdm run ./build_ext.sh

build_ext.sh (+6, new file)

@@ -0,0 +1,6 @@
+#!/bin/bash
+
+set -e
+pdm run build-ext-ref
+cp src/extensions_ref/build/lib/tiny_llm_ext_ref/tiny_llm_ext_ref.metallib .venv/lib/python3.12/site-packages/mlx/lib/
+pdm run test-week2-ref -k 'week_2_day_2'

src/extensions_ref/CMakeLists.txt (+1)

@@ -58,6 +58,7 @@ if(MLX_BUILD_METAL)
     tiny_llm_ext_ref
     SOURCES
     ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
+    ${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.metal
     INCLUDE_DIRS
     ${PROJECT_SOURCE_DIR}
     ${MLX_INCLUDE_DIRS}

src/extensions_ref/bindings.cpp (+2)

@@ -12,6 +12,8 @@ using namespace nb::literals;
 NB_MODULE(_ext, m) {
     m.doc() = "tiny-llm extensions for MLX";

+    m.def("load_library", &tiny_llm_ext_ref::load_library, "device"_a, "path"_a);
+
     m.def("axpby", &tiny_llm_ext_ref::axpby, "x"_a, "y"_a, "alpha"_a, "beta"_a, nb::kw_only(), "stream"_a = nb::none(),
           R"(
         Scale and sum two vectors element-wise

src/extensions_ref/build.py (+5 -6)

@@ -2,24 +2,23 @@
 import shutil
 from mlx import extension
 from setuptools import Distribution
+import inspect
+import mlx
+import os

 if __name__ == "__main__":
     src_dir = Path(__file__).parent
     distribution = Distribution(
         {
             "name": "tiny_llm_ext_ref",
             "ext_modules": [extension.CMakeExtension("tiny_llm_ext_ref._ext")],
+            "package_data": {"tiny_llm_ext_ref": ["*.so", "*.dylib", "*.metallib"]},
         }
     )
     cmd = extension.CMakeBuild(distribution)
     cmd.initialize_options()
     cmd.build_temp = Path("build")
     cmd.build_lib = Path("build") / "lib"
-    cmd.inplace = False  # we do the copy by ourselves
+    cmd.inplace = True
     cmd.ensure_finalized()
     cmd.run()
-    for output in cmd.get_outputs():
-        output = Path(output)
-        relative_extension = src_dir / output.relative_to(cmd.build_lib)
-        shutil.copyfile(output, relative_extension)
-        print(f"Copied {output} to {relative_extension}")

src/extensions_ref/src/quantized_matmul.cpp (+85 -20)

@@ -54,6 +54,17 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
     if (!transpose_b) {
         throw std::runtime_error("quantized_matmul: b must be transposed");
     }
+
+    if (scales.shape() != biases.shape()) {
+        throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+    }
+    if (b.shape()[0] != scales.shape()[0]) {
+        throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
+    }
+    if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
+        throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales");
+    }
+
     return mx::array(
         /* const mx::Shape& shape = */ out_shape,
         /* mx::Dtype dtype = */ mx::float16,
@@ -73,14 +84,11 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
     encoder.set_input_array(b);
     encoder.set_output_array(out);

-    if (scales.shape() != biases.shape()) {
-        throw std::runtime_error("quantized_matmul: scales and biases must have the same shape");
+    if (!a.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: a must be contiguous");
     }
-    if (b.shape()[0] != scales.shape()[0]) {
-        throw std::runtime_error("quantized_matmul: b must have the same number of rows as scales");
-    }
-    if (b.shape()[1] != scales.shape()[1] * group_size / 8) {
-        throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales");
+    if (!b.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: b must be contiguous");
     }

     // Launch the CPU kernel
@@ -100,32 +108,32 @@
     uint32_t item_mask = (1 << bits) - 1;
     for (int i = 0; i < M; i++) {
         for (int k = 0; k < K; k++) {
+            float sum = 0;
             for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
                 int64_t scales_loc =
-                    mx::elem_to_loc(k * N / group_size + group_idx, scales.shape(), scales.strides());
+                    mx::elem_to_loc(k * group_per_row + group_idx, scales.shape(), scales.strides());
                 int64_t biases_loc =
-                    mx::elem_to_loc(k * N / group_size + group_idx, biases.shape(), biases.strides());
-                float16_t sum = 0;
+                    mx::elem_to_loc(k * group_per_row + group_idx, biases.shape(), biases.strides());
                 float16_t scale = scales_ptr[scales_loc];
                 float16_t bias = biases_ptr[biases_loc];
+                int64_t b_loc = mx::elem_to_loc((k * N + group_idx * group_size) / 8, b.shape(), b.strides());
+                int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size, a.shape(), a.strides());
                 const int packs_per_item = 32 / bits;
                 for (int item_idx = 0; item_idx < group_size; item_idx += packs_per_item) {
-                    int64_t b_loc =
-                        mx::elem_to_loc((k * N + group_idx * group_size + item_idx) / 8, b.shape(), b.strides());
                     uint32_t b_val = b_ptr[b_loc];
                     uint8_t *b_bytes = reinterpret_cast<uint8_t *>(&b_val);
                     for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
-                        int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size + item_idx + pack_idx,
-                                                        a.shape(), a.strides());
                         uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
-                        float16_t b = static_cast<float16_t>(item_val) * scale + bias;
-                        float16_t a = a_ptr[a_loc];
+                        float b = static_cast<float>(item_val) * scale + bias;
+                        float a = a_ptr[a_loc];
                         sum += a * b;
+                        a_loc += 1;
                     }
+                    b_loc += 1;
                 }
-                int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
-                out_ptr[out_loc] = sum;
             }
+            int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
+            out_ptr[out_loc] = static_cast<float16_t>(sum);
         }
     }
 });
@@ -142,8 +150,65 @@ void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector
     quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
 }

-void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &out) {
-    throw std::runtime_error("QuantizedMatmul has no GPU implementation.");
+void load_library(mx::Device d, const char* path) {
+    auto &md = mx::metal::device(d);
+    md.register_library("tiny_llm_ext_ref", path);
+}
+
+void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
+    auto &scales = inputs[0];
+    auto &biases = inputs[1];
+    auto &a = inputs[2];
+    auto &b = inputs[3];
+    auto &out = outputs[0];
+
+    auto &s = stream();
+    auto &d = mx::metal::device(s.device);
+    out.set_data(mx::allocator::malloc(out.nbytes()));
+
+    // Make a kernel from this metal library
+    auto kernel = d.get_kernel("quantized_matmul_w4a16_g64", "tiny_llm_ext_ref");
+
+    // Prepare to encode kernel
+    auto &compute_encoder = d.get_command_encoder(s.index);
+    compute_encoder.set_compute_pipeline_state(kernel);
+
+    // Kernel parameters are registered with buffer indices corresponding to
+    // those in the kernel declaration at quantized_matmul.metal
+    int ndim = out.ndim();
+
+    // Encode input arrays to kernel
+    compute_encoder.set_input_array(scales, 0);
+    compute_encoder.set_input_array(biases, 1);
+    compute_encoder.set_input_array(a, 2);
+    compute_encoder.set_input_array(b, 3);
+    // Encode output arrays to kernel
+    compute_encoder.set_output_array(out, 4);
+
+    if (!a.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: a must be contiguous");
+    }
+    if (!b.flags().row_contiguous) {
+        throw std::runtime_error("quantized_matmul: b must be contiguous");
+    }
+
+    int M = a.shape()[0];
+    int N = a.shape()[1];
+    int K = b.shape()[0];
+
+    // Encode matrix parameters
+    compute_encoder.set_bytes(M, 5);
+    compute_encoder.set_bytes(N, 6);
+    compute_encoder.set_bytes(K, 7);
+
+    size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup();
+    MTL::Size num_threadgroups = MTL::Size((M * K + tgp_size - 1) / tgp_size, 1, 1);
+    MTL::Size num_threads_per_group = MTL::Size(tgp_size, 1, 1);
+
+    // Launch the grid with the given number of threads divided among
+    // the given threadgroups
+    compute_encoder.dispatch_threadgroups(num_threadgroups, num_threads_per_group);
 }

 bool QuantizedMatmul::is_equivalent(const Primitive &other) const {
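For reference, the arithmetic that the CPU loop above and the Metal kernel below perform can be sketched in NumPy: each uint32 of b packs eight 4-bit items (lowest bits first), each group of 64 items shares one scale and one bias, and the output is a @ dequant(b).T. This is only a minimal sketch, not from the diff, assuming MLX's affine quantization (w ≈ q * scale + bias) and the same nibble order as the kernel's byte indexing:

import mlx.core as mx
import numpy as np

group_size, bits = 64, 4
a = mx.array(np.random.randn(3, 64).astype(np.float16))
w = mx.array(np.random.randn(5, 64).astype(np.float16))
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

# Unpack eight 4-bit items per uint32, lowest bits first (mirrors the
# b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits) indexing in the kernels).
w_q_np = np.array(w_q, dtype=np.uint32)                       # shape (5, 8)
items = np.stack([(w_q_np >> (bits * j)) & 0xF for j in range(32 // bits)], axis=-1)
items = items.reshape(w_q_np.shape[0], -1)                    # shape (5, 64)

# Dequantize per group and reproduce out = a @ dequant(w).T.
scale = np.repeat(np.array(scales, dtype=np.float32), group_size, axis=1)
bias = np.repeat(np.array(biases, dtype=np.float32), group_size, axis=1)
w_hat = items * scale + bias
ref = np.array(a, dtype=np.float32) @ w_hat.T                 # shape (3, 5)
print(ref.astype(np.float16))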
src/extensions_ref/src/quantized_matmul.metal (+46, new file)

@@ -0,0 +1,46 @@
+[[kernel]] void quantized_matmul_w4a16_g64(
+    device const half* scales [[buffer(0)]],
+    device const half* biases [[buffer(1)]],
+    device const half* a [[buffer(2)]],
+    device const uint32_t* b [[buffer(3)]],
+    device half* out [[buffer(4)]],
+    device const int &M [[buffer(5)]],
+    device const int &N [[buffer(6)]],
+    device const int &K [[buffer(7)]],
+    uint2 groupId [[threadgroup_position_in_grid]],
+    uint2 threadId [[thread_position_in_threadgroup]],
+    uint2 threads_per_threadgroup [[threads_per_threadgroup]]) {
+    const int group_size = 64;
+    const int bits = 4;
+    const int packs_per_item = 32 / bits;
+    const int item_mask = (1 << bits) - 1;
+    const int groups_per_row = N / group_size;
+    // Each threadgroup processes an element in the output matrix
+    const int64_t idx = groupId.x * threads_per_threadgroup.x + threadId.x;
+    const int64_t i = idx / K;
+    const int64_t k = idx % K;
+    float sum = 0;
+    for (int group_idx = 0; group_idx < groups_per_row; group_idx++) {
+        const int64_t scales_biases_loc = k * groups_per_row + group_idx;
+        const float scale = scales[scales_biases_loc];
+        const float bias = biases[scales_biases_loc];
+        int64_t b_loc = (k * N + group_idx * group_size) / 8;
+        int64_t a_loc = i * N + group_idx * group_size;
+        for (int item_idx = 0; item_idx < group_size; item_idx += packs_per_item) {
+            const uint32_t b_val = b[b_loc];
+            thread const uint32_t *b_val_ref = &b_val;
+            thread const uint8_t *b_bytes = reinterpret_cast<thread const uint8_t *>(b_val_ref);
+            for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
+                const uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & item_mask;
+                const float b_val = static_cast<float>(item_val) * scale + bias;
+                const float a_val = a[a_loc];
+                sum += a_val * b_val;
+                a_loc += 1;
+            }
+            b_loc += 1;
+        }
+    }
+    if (i < M && k < K) {
+        out[i * K + k] = sum;
+    }
+}

src/extensions_ref/src/tiny_llm_ext.h (+2)

@@ -7,6 +7,8 @@ namespace mx = mlx::core;

 namespace tiny_llm_ext_ref {

+void load_library(mx::Device d, const char* path);
+
 mx::array quantized_matmul(const mx::array &scales, // Input array scales
                            const mx::array &biases, // Input array biases
                            const int group_size,    // Group size

src/extensions_ref/test.py (+18, new file)

@@ -0,0 +1,18 @@
+from tiny_llm_ext_ref import quantized_matmul
+import mlx.core as mx
+import numpy as np
+
+precision = np.float16
+input = mx.array(np.random.randn(3, 64).astype(precision))
+weight = mx.array(np.random.randn(5, 64).astype(precision))
+w_q, scales, biases = mx.quantize(weight)
+user_out = quantized_matmul(
+    scales=scales,
+    biases=biases,
+    group_size=64,
+    bits=4,
+    a=input,
+    b=w_q,
+    transpose_b=True,
+)
+print(user_out)
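As a follow-up to test.py, the output could be cross-checked against MLX's built-in quantized matmul. A hedged sketch, not part of the commit, assuming mx.quantized_matmul with transpose=True matches the transpose_b=True convention used by the extension:

import mlx.core as mx
import numpy as np
from tiny_llm_ext_ref import quantized_matmul

a = mx.array(np.random.randn(3, 64).astype(np.float16))
w = mx.array(np.random.randn(5, 64).astype(np.float16))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

user_out = quantized_matmul(
    scales=scales, biases=biases, group_size=64, bits=4, a=a, b=w_q, transpose_b=True
)
# Reference result from MLX itself; a loose tolerance covers float16 accumulation differences.
ref_out = mx.quantized_matmul(a, w_q, scales, biases, transpose=True, group_size=64, bits=4)
print(mx.allclose(user_out, ref_out, atol=1e-2))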

src/extensions_ref/tiny_llm_ext_ref/__init__.py (+4)

@@ -3,3 +3,7 @@
 import mlx.core as mx

 from ._ext import *
+from pathlib import Path
+
+current_path = Path(__file__).parent
+load_library(mx.gpu, str(current_path))

src/tiny_llm_week2_ref/quantize.py (+2)

@@ -26,6 +26,8 @@ def quantized_matmul(
 ) -> mx.array:
     *N, D = a.shape
     a = a.reshape(-1, D)
+    a = mx.contiguous(a)
+    b = mx.contiguous(b)
     return tiny_llm_ext_ref.quantized_matmul(
         scales, biases, group_size, bits, a, b, transpose_b
     ).reshape(*N, -1)

tests/test_week_2_day_2.py (+8)

@@ -42,3 +42,11 @@ def test_task_1_quantized_matmul_simple_f16_cpu():

 def test_task_1_quantized_matmul_complex_f16_cpu():
     quantized_matmul_helper(mx.cpu, False, np.float16)
+
+
+def test_task_2_quantized_matmul_simple_f16_gpu():
+    quantized_matmul_helper(mx.gpu, True, np.float16)
+
+
+def test_task_2_quantized_matmul_complex_f16_gpu():
+    quantized_matmul_helper(mx.gpu, False, np.float16)
