Commit cc850bc

Support float16 quantized matmul
Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent 7fac5ff commit cc850bc

5 files changed

Lines changed: 98 additions & 58 deletions

book/src/week2-02-quantized-matmul.md

Lines changed: 31 additions & 21 deletions
@@ -43,7 +43,8 @@ We're using only ~4% of available compute!
 
 ### The Solution: Quantization
 
-By compressing weights from 16 bits (bfloat16) to 4 bits (int4), we:
+By compressing weights from 16-bit floating point (`float16` or `bfloat16`) to
+4-bit integers (int4), we:
 
 - **Reduce memory bandwidth by 4×**: 880 MB → ~220 MB per token
 - **Improve arithmetic intensity by 4×**: 1.0 → ~4.0 FLOPs/Byte
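Note on the 4× figure in this hunk: the following is a rough check, not part of the commit. It assumes the layout the chapter describes (group size 128, one 16-bit scale and one 16-bit bias per group); the helper name `effective_bits_per_weight` is hypothetical.

```python
def effective_bits_per_weight(bits: int = 4, group_size: int = 128, meta_bits: int = 16) -> float:
    """Bits stored per weight, counting one scale and one bias per group."""
    return bits + 2 * meta_bits / group_size


eff = effective_bits_per_weight()   # 4 + 32 / 128
ratio = 16 / eff                    # compression relative to 16-bit weights
print(f"{eff:.2f} bits/weight, {ratio:.2f}x smaller than 16-bit")
```

Under these assumptions the scale/bias overhead brings the saving to about 3.76× rather than exactly 4×, so 880 MB would shrink to roughly 234 MB; the "~220 MB" figure counts only the packed 4-bit values.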
@@ -58,7 +59,7 @@ Instead of quantizing all weights uniformly, we divide them into **groups** and
 For a weight matrix $W$ of shape $(K, N)$, we divide each row into groups of size $G$. In this course we use Qwen3 MLX 4-bit weights, whose group size is fixed at 128:
 
 ```plain
-Original weight matrix W: K × N (bfloat16)
+Original weight matrix W: K × N (float16 or bfloat16)
 
 Group size: G = 128
 Number of groups per row = N / G
@@ -69,7 +70,10 @@ For each group of G consecutive values in a row:
 3. Quantize each value using: quantized = round((value - bias) / scale)
 ```
 
-All quantized matmul tests use `group_size = 128`, matching the Qwen3 MLX 4-bit weights used by the rest of the course.
+All quantized matmul tests use `group_size = 128`, matching the Qwen3 MLX
+4-bit weights used by the rest of the course. The tests cover both `float16`
+and `bfloat16` because different MLX checkpoints store their scales, biases,
+and activations in different 16-bit dtypes.
 
 ### Affine Quantization
 
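Note on the per-group formula in this hunk: a minimal NumPy sketch of the quantization step, using the formulas the chapter states (`scale = (v_max - v_min) / 15`, `bias = v_min`, `quantized = round((value - bias) / scale)`). The function names are hypothetical, and the result is not guaranteed to be bit-identical to what `mx.quantize` produces.

```python
import numpy as np


def quantize_group(values: np.ndarray, bits: int = 4):
    """Affine-quantize one group of values; returns (q, scale, bias)."""
    levels = (1 << bits) - 1                    # 15 for 4-bit
    v_min, v_max = float(values.min()), float(values.max())
    scale = (v_max - v_min) / levels
    if scale == 0.0:                            # constant group: avoid dividing by zero
        scale = 1.0
    bias = v_min
    q = np.clip(np.round((values - bias) / scale), 0, levels).astype(np.uint8)
    return q, scale, bias


def dequantize_group(q: np.ndarray, scale: float, bias: float) -> np.ndarray:
    """Inverse mapping used by the matmul kernels: q * scale + bias."""
    return q.astype(np.float32) * scale + bias


rng = np.random.default_rng(0)
group = rng.standard_normal(128).astype(np.float32)   # one group, G = 128
q, scale, bias = quantize_group(group)
restored = dequantize_group(q, scale, bias)
max_err = float(np.abs(restored - group).max())
print(f"scale={scale:.4f} bias={bias:.4f} max_err={max_err:.4f}")
```

Because every value is rounded to the nearest of 16 evenly spaced levels, the reconstruction error in this sketch is bounded by half a step, `scale / 2`.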
@@ -119,15 +123,15 @@ Quantized: [0, 2, 7, 10, 15] (4 bits each)
 For efficient storage and computation, quantized weights are packed:
 
 ```plain
-Original: K × N bfloat16 (2 bytes each) = 2KN bytes
+Original: K × N float16/bfloat16 (2 bytes each) = 2KN bytes
 Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes
 
 Packing: 8 × 4-bit values fit in one uint32 (32 bits)
 
 Weight matrix shape: K × N
 Quantized storage shape: K × (N / 8) uint32
-Scales shape: K × (N / G) bfloat16
-Biases shape: K × (N / G) bfloat16
+Scales shape: K × (N / G) float16/bfloat16
+Biases shape: K × (N / G) float16/bfloat16
 ```
 
 Example packing for 8 consecutive 4-bit values `[a, b, c, d, e, f, g, h]`:
@@ -150,9 +154,9 @@ Unpacking:
 
 For standard matrix multiplication $C = AB^T$ where:
 
-- $A$: shape $(M, N)$, bfloat16 (activations)
+- $A$: shape $(M, N)$, float16 or bfloat16 (activations)
 - $B$: shape $(K, N)$, **quantized** to int4 (weights)
-- $C$: shape $(M, K)$, bfloat16 (output)
+- $C$: shape $(M, K)$, same 16-bit dtype as $A$ (output)
 
 Each element $C[i, k]$ is computed as:
 
@@ -186,13 +190,13 @@ This shows we can factor out the scale and bias per group, reducing the number o
 
 ```plain
 Input:
-  A: M × N (bfloat16, activations)
+  A: M × N (float16 or bfloat16, activations)
   B_quantized: K × (N/8) (uint32, packed weights)
-  scales: K × (N/G) (bfloat16)
-  biases: K × (N/G) (bfloat16)
+  scales: K × (N/G) (same dtype as A)
+  biases: K × (N/G) (same dtype as A)
 
 Output:
-  C: M × K (bfloat16)
+  C: M × K (same dtype as A)
 
 For each output element C[i, k]:
   sum = 0 # float accumulator
@@ -211,7 +215,7 @@ For each output element C[i, k]:
 a_value = A[i, g*G + p*8 + bit_offset/4]
 sum = sum + a_value * b_value
 
-C[i, k] = bfloat16(sum)
+C[i, k] = same_dtype_as_A(sum)
 ```
 
 ## Task 1: Implement QuantizedWeights
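Note on the Computation Flow pseudocode that this hunk closes: the loop can be sketched in NumPy as below. This is an illustrative reference, not the course solution. It assumes the first value of each pack sits in the lowest 4 bits, which is consistent with the `bit_offset/4` indexing in the pseudocode; `pack_u4` and `quantized_matmul_ref` are hypothetical names.

```python
import numpy as np

G, BITS = 128, 4
PER_WORD = 32 // BITS  # 8 four-bit values per uint32


def pack_u4(q: np.ndarray) -> np.ndarray:
    """Pack (K, N) values in [0, 15] into (K, N/8) uint32, first value in the low bits."""
    K, N = q.shape
    words = q.astype(np.uint32).reshape(K, N // PER_WORD, PER_WORD)
    shifts = np.arange(PER_WORD, dtype=np.uint32) * BITS
    return np.bitwise_or.reduce(words << shifts, axis=-1)


def quantized_matmul_ref(a, packed, scales, biases):
    """C[i, k] = sum_j A[i, j] * (q[k, j] * scale[k, g] + bias[k, g]), accumulated in float32."""
    M, N = a.shape
    K = packed.shape[0]
    out = np.zeros((M, K), dtype=np.float32)
    for i in range(M):
        for k in range(K):
            acc = np.float32(0.0)                      # float accumulator
            for g in range(N // G):
                scale = np.float32(scales[k, g])
                bias = np.float32(biases[k, g])
                for p in range(G // PER_WORD):
                    word = int(packed[k, g * (G // PER_WORD) + p])
                    for bit_offset in range(0, 32, BITS):
                        q = (word >> bit_offset) & 0xF  # unpack one 4-bit value
                        b_value = np.float32(q) * scale + bias
                        a_value = np.float32(a[i, g * G + p * PER_WORD + bit_offset // BITS])
                        acc += a_value * b_value
            out[i, k] = acc
    return out.astype(a.dtype)                          # same dtype as A


rng = np.random.default_rng(0)
M, K, N = 3, 5, 256                                     # two groups per row
a = rng.standard_normal((M, N)).astype(np.float16)
q = rng.integers(0, 16, size=(K, N))
scales = rng.uniform(0.01, 0.1, size=(K, N // G)).astype(np.float16)
biases = rng.uniform(-0.5, 0.0, size=(K, N // G)).astype(np.float16)

out = quantized_matmul_ref(a, pack_u4(q), scales, biases)

# Cross-check: dequantize the full matrix first, then do an ordinary matmul.
w = (q.astype(np.float32) * np.repeat(scales, G, axis=1).astype(np.float32)
     + np.repeat(biases, G, axis=1).astype(np.float32))
expected = (a.astype(np.float32) @ w.T).astype(np.float16)
print(np.abs(out.astype(np.float32) - expected.astype(np.float32)).max())
```

NumPy has no `bfloat16`, so the sketch only exercises `float16`; the loop structure is the same for either dtype because all arithmetic happens in `float32`.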
@@ -225,8 +229,8 @@ First, familiarize yourself with the `QuantizedWeights` class, which stores quan
 | Field | Shape | Description |
 |-------|-------|-------------|
 | `weight` | $(K, N/8)$ uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $(K, N)$, and after packing, it becomes $(K, N/8)$. |
-| `scales` | $(K, N/G)$ bfloat16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ |
-| `biases` | $(K, N/G)$ bfloat16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = v_{min}$ |
+| `scales` | $(K, N/G)$ float16/bfloat16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ |
+| `biases` | $(K, N/G)$ float16/bfloat16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = v_{min}$ |
 | `group_size` | int | Number of consecutive values that share the same scale/bias. For the Qwen3 MLX 4-bit weights used here, this is `128`. |
 | `bits` | int | Quantization bit width (typically 4, meaning values are in range $[0, 15]$) |
 
@@ -251,7 +255,11 @@ You need to touch three files, all within the `tiny_llm_ext` namespace:
 - **`bindings.cpp`** — Add an `m.def(...)` call to expose the function to Python.
 - **`quantized_matmul.cpp`** — Implement the `quantized_matmul(...)` function (validate inputs, compute output shape, return a lazy `mx::array`) and the `eval_cpu` method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).
 
-The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, dequantize each packed value, accumulate the products in `float`, and write the `bfloat16` result to the output.
+The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, dequantize each packed value, accumulate the products in `float`, and write the result back as either `float16` or `bfloat16`, matching the input dtype.
+
+Follow the `axpby` dtype-dispatch pattern here: write the CPU implementation as
+a template, then dispatch with `mx::float16_t` or `mx::bfloat16_t` based on the
+output dtype.
 
 Don't forget to add `src/quantized_matmul.cpp` to `target_sources` in `CMakeLists.txt`.
 
@@ -276,22 +284,24 @@ In this task, you will write the Metal kernel for quantized matmul **and** wire
 You need to implement one kernel entry in `quantized_matmul.metal`:
 
 - Use a **one-thread-per-output-element** mapping: each thread computes `out[i, k]`.
-- The kernel should use `bfloat16_t` inputs and outputs.
+- The kernel should support both `half` and `bfloat16_t` inputs and outputs.
 - Apply the same group-wise dequantization loop as the CPU version:
   - Iterate over groups of 128 values
   - Unpack int4 values from packed `uint32`
   - Dequantize with `q * scale + bias`
-  - Accumulate the products in `float` and cast the final output back to `bfloat16_t`
+  - Accumulate the products in `float` and cast the final output back to the kernel dtype
 - Add boundary checks (`i < M`, `k < K`) before writing output.
 
 The custom kernel only needs to handle `bits = 4` and `group_size = 128`. Use that group size to compute `groups_per_row` and the packed weight offsets.
+Instantiate the same templated Metal kernel twice, once for `half` and once for
+`bfloat16_t`, and select the matching kernel name in `eval_gpu`.
 
 ### GPU Dispatch
 
 Complete the `eval_gpu` method in `quantized_matmul.cpp` to dispatch your Metal kernel. Follow the same pattern as `axpby`'s GPU dispatch:
 
 1. Get the Metal device and command encoder from the stream.
-2. Load the bfloat16 quantized matmul kernel from the Metal library.
+2. Load the quantized matmul kernel matching the output dtype from the Metal library.
 3. Set input/output buffers and dimension constants (`M`, `N`, `K`) on the encoder — make sure the buffer order matches your kernel signature.
 4. Calculate a 2D thread group configuration: use `kernel->maxTotalThreadsPerThreadgroup()` to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K).
 5. Dispatch with `dispatchThreadgroups`.
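Note on steps 4 and 5 of this hunk: the thread-group arithmetic can be sketched in plain Python as below. The helper `threadgroup_config` is hypothetical, and `1024` only stands in for whatever `kernel->maxTotalThreadsPerThreadgroup()` reports on a given device.

```python
import math


def threadgroup_config(M: int, K: int, max_threads: int, threads_m: int = 32):
    """Return ((tg_m, tg_k), (groups_m, groups_k)) for a one-thread-per-output mapping."""
    tg_m = min(threads_m, max_threads)       # threads along M (e.g. 32)
    tg_k = max(1, max_threads // tg_m)       # the rest go to K
    groups_m = math.ceil(M / tg_m)           # round up so every row is covered
    groups_k = math.ceil(K / tg_k)           # round up so every column is covered
    return (tg_m, tg_k), (groups_m, groups_k)


for M, K in ((3, 5), (100, 4096)):
    tg, groups = threadgroup_config(M, K, max_threads=1024)
    print(f"M={M} K={K}: threads per group {tg}, groups {groups}")
```

Because whole thread groups are launched, the grid is rounded up and usually covers more than `M × K` threads; that is why the kernel needs the `i < M`, `k < K` boundary checks before it writes.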
@@ -311,9 +321,9 @@ src/tiny_llm/qwen3_week2.py
 
 Integrate your quantized matmul into the Week 2 Qwen3 model so that inference runs on quantized weights end-to-end.
 
-Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full bfloat16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
+Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
 
-Qwen3 MLX quantized layers use **bfloat16** for the tensors involved in dequantization. Your kernel should take `scales`, `biases`, and activations as bfloat16. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.
+Qwen3 MLX quantized layers may use **float16** or **bfloat16** for the tensors involved in dequantization. Your kernel should accept `scales`, `biases`, and activations in either dtype, require them to match, and return the same dtype. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.
 
 Also keep the quantized layer's parameters. The model code should pass through `w.group_size` and `w.bits`; the extension should validate that they match the Qwen3 course assumptions: `group_size = 128` and `bits = 4`.
 
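Note on the dtype-mismatch warning in this hunk: the two 16-bit formats assign different meanings to the same bits (`bfloat16` has 8 exponent and 7 mantissa bits, `float16` has 5 and 10), so reading one as the other silently yields wrong numbers rather than an error. NumPy has no `bfloat16`, so this sketch emulates it by taking the upper 16 bits of a `float32`; both helper names are hypothetical.

```python
import numpy as np


def bf16_bits(x: float) -> int:
    """Upper 16 bits of the float32 encoding, i.e. bfloat16 by truncation."""
    return int(np.array([x], dtype=np.float32).view(np.uint32)[0] >> 16)


def read_as_f16(bits: int) -> float:
    """Reinterpret a 16-bit pattern as IEEE float16."""
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])


for x in (1.0, 0.5, 3.0):
    bits = bf16_bits(x)
    print(f"bfloat16 {x} = {bits:#06x}, misread as float16 = {read_as_f16(bits)}")
```

For example, `bfloat16` 1.0 is stored as `0x3f80`, which a `float16` reader decodes as 1.875. A scale or bias misread this way corrupts every dequantized weight in its group.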
src/extensions_ref/src/quantized_matmul.cpp

Lines changed: 19 additions & 11 deletions
@@ -23,8 +23,8 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
                            const bool transpose_b, // Whether to transpose b
                            mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
 ) {
-    if (scales.dtype() != mx::bfloat16) {
-        throw std::runtime_error("quantized_matmul: scales must be bfloat16");
+    if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) {
+        throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16");
     }
     if (scales.dtype() != biases.dtype()) {
         throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype");
@@ -81,6 +81,7 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
         /* const std::vector<mx::array>& inputs = */ {scales, biases, a, b});
 }
 
+template <typename T>
 void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, const mx::array &a, const mx::array &b,
                            mx::array &out, mx::Stream stream) {
     out.set_data(mx::allocator::malloc(out.nbytes()));
@@ -99,7 +100,7 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
         throw std::runtime_error("quantized_matmul: b must be contiguous");
     }
 
-    encoder.dispatch([out_ptr = out.data<mx::bfloat16_t>(), out_shape = out.shape(), out_strides = out.strides(),
+    encoder.dispatch([out_ptr = out.data<T>(), out_shape = out.shape(), out_strides = out.strides(),
                       a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b),
                       scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() {
         int M = a.shape()[0];
@@ -108,10 +109,10 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
         const int group_size = 128;
         const int bits = 4;
         const int group_per_row = N / group_size;
-        const mx::bfloat16_t *a_ptr = a.data<mx::bfloat16_t>();
+        const T *a_ptr = a.data<T>();
         const uint32_t *b_ptr = b.data<uint32_t>();
-        const mx::bfloat16_t *scales_ptr = scales.data<mx::bfloat16_t>();
-        const mx::bfloat16_t *biases_ptr = biases.data<mx::bfloat16_t>();
+        const T *scales_ptr = scales.data<T>();
+        const T *biases_ptr = biases.data<T>();
         uint32_t item_mask = (1 << bits) - 1;
         for (int i = 0; i < M; i++) {
             for (int k = 0; k < K; k++) {
@@ -121,8 +122,8 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
                         mx::elem_to_loc(k * group_per_row + group_idx, scales.shape(), scales.strides());
                     int64_t biases_loc =
                         mx::elem_to_loc(k * group_per_row + group_idx, biases.shape(), biases.strides());
-                    mx::bfloat16_t scale = scales_ptr[scales_loc];
-                    mx::bfloat16_t bias = biases_ptr[biases_loc];
+                    T scale = scales_ptr[scales_loc];
+                    T bias = biases_ptr[biases_loc];
                     int64_t b_loc = mx::elem_to_loc((k * N + group_idx * group_size) / 8, b.shape(), b.strides());
                     int64_t a_loc = mx::elem_to_loc(i * N + group_idx * group_size, a.shape(), a.strides());
                     const int packs_per_item = 32 / bits;
@@ -140,7 +141,7 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
                     }
                 }
                 int64_t out_loc = mx::elem_to_loc(i * K + k, out_shape, out_strides);
-                out_ptr[out_loc] = static_cast<mx::bfloat16_t>(sum);
+                out_ptr[out_loc] = static_cast<T>(sum);
             }
         }
     });
@@ -153,7 +154,13 @@ void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector
     auto &b = inputs[3];
     auto &out = outputs[0];
 
-    quantized_matmul_impl(scales, biases, a, b, out, stream());
+    if (out.dtype() == mx::float16) {
+        return quantized_matmul_impl<mx::float16_t>(scales, biases, a, b, out, stream());
+    } else if (out.dtype() == mx::bfloat16) {
+        return quantized_matmul_impl<mx::bfloat16_t>(scales, biases, a, b, out, stream());
+    } else {
+        throw std::runtime_error("quantized_matmul: output must be float16 or bfloat16");
+    }
 }
 
 void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
@@ -169,7 +176,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector
 
     // Make a kernel from this metal library
     auto library = d.get_library("tiny_llm_ext_ref");
-    const char* kernel_name = "quantized_matmul_w4a16_g128_bf16";
+    const char* kernel_name = out.dtype() == mx::float16 ? "quantized_matmul_w4a16_g128_f16"
+                                                         : "quantized_matmul_w4a16_g128_bf16";
     auto kernel = d.get_kernel(kernel_name, library);
 
     // Prepare to encode kernel

src/extensions_ref/src/quantized_matmul.metal

Lines changed: 1 addition & 0 deletions
@@ -51,4 +51,5 @@ template <typename T>
     }
 }
 
+instantiate_kernel("quantized_matmul_w4a16_g128_f16", quantized_matmul_w4a16_g128, half);
 instantiate_kernel("quantized_matmul_w4a16_g128_bf16", quantized_matmul_w4a16_g128, bfloat16_t);

src/extensions_ref/test.py

Lines changed: 14 additions & 13 deletions
@@ -2,16 +2,17 @@
 import mlx.core as mx
 import numpy as np
 
-input = mx.array(np.random.randn(3, 128)).astype(mx.bfloat16)
-weight = mx.array(np.random.randn(5, 128)).astype(mx.bfloat16)
-w_q, scales, biases = mx.quantize(weight, group_size=128, bits=4)
-user_out = quantized_matmul(
-    scales=scales,
-    biases=biases,
-    group_size=128,
-    bits=4,
-    a=input,
-    b=w_q,
-    transpose_b=True,
-)
-print(user_out)
+for dtype in (mx.float16, mx.bfloat16):
+    input = mx.array(np.random.randn(3, 128)).astype(dtype)
+    weight = mx.array(np.random.randn(5, 128)).astype(dtype)
+    w_q, scales, biases = mx.quantize(weight, group_size=128, bits=4)
+    user_out = quantized_matmul(
+        scales=scales,
+        biases=biases,
+        group_size=128,
+        bits=4,
+        a=input,
+        b=w_q,
+        transpose_b=True,
+    )
+    print(dtype, user_out)

tests_refsol/test_week_2_day_2.py

Lines changed: 33 additions & 13 deletions
@@ -1,16 +1,20 @@
 import mlx.core as mx
-from .tiny_llm_base import *
-from .utils import *
+from .tiny_llm_base import quantized_matmul
+from .utils import assert_allclose
 
 
-def quantized_matmul_helper(stream: mx.Stream, identity_matrix: bool):
+def quantized_matmul_helper(
+    stream: mx.Stream,
+    precision: mx.Dtype,
+    identity_matrix: bool,
+):
     with mx.stream(stream):
         group_size = 128
         if identity_matrix:
-            input = mx.eye(group_size, dtype=mx.bfloat16)
+            input = mx.eye(group_size, dtype=precision)
         else:
-            input = mx.random.normal(shape=(3, group_size), dtype=mx.bfloat16)
-        weight = mx.random.normal(shape=(5, group_size), dtype=mx.bfloat16)
+            input = mx.random.normal(shape=(3, group_size), dtype=precision)
+        weight = mx.random.normal(shape=(5, group_size), dtype=precision)
         w_q, scales, biases = mx.quantize(weight, group_size=group_size, bits=4)
         user_out = quantized_matmul(
             scales=scales,
@@ -31,28 +35,44 @@ def quantized_matmul_helper(stream: mx.Stream, identity_matrix: bool):
             transpose=True,
         )
         if identity_matrix:
-            assert_allclose(user_out, ref_out, mx.bfloat16)
+            assert_allclose(user_out, ref_out, precision)
         else:
             assert_allclose(
                 user_out,
                 ref_out,
-                mx.bfloat16,
+                precision,
                 atol=5.0e-1,
-                message="quantized matmul bf16 comparison",
+                message=f"quantized matmul {precision} comparison",
             )
 
 
 def test_task_2_quantized_matmul_simple_bf16_cpu():
-    quantized_matmul_helper(mx.cpu, True)
+    quantized_matmul_helper(mx.cpu, mx.bfloat16, True)
 
 
 def test_task_2_quantized_matmul_complex_bf16_cpu():
-    quantized_matmul_helper(mx.cpu, False)
+    quantized_matmul_helper(mx.cpu, mx.bfloat16, False)
+
+
+def test_task_2_quantized_matmul_simple_f16_cpu():
+    quantized_matmul_helper(mx.cpu, mx.float16, True)
+
+
+def test_task_2_quantized_matmul_complex_f16_cpu():
+    quantized_matmul_helper(mx.cpu, mx.float16, False)
 
 
 def test_task_3_quantized_matmul_simple_bf16_gpu():
-    quantized_matmul_helper(mx.gpu, True)
+    quantized_matmul_helper(mx.gpu, mx.bfloat16, True)
 
 
 def test_task_3_quantized_matmul_complex_bf16_gpu():
-    quantized_matmul_helper(mx.gpu, False)
+    quantized_matmul_helper(mx.gpu, mx.bfloat16, False)
+
+
+def test_task_3_quantized_matmul_simple_f16_gpu():
+    quantized_matmul_helper(mx.gpu, mx.float16, True)
+
+
+def test_task_3_quantized_matmul_complex_f16_gpu():
+    quantized_matmul_helper(mx.gpu, mx.float16, False)
