You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: book/src/week2-02-quantized-matmul.md
+31-21Lines changed: 31 additions & 21 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -43,7 +43,8 @@ We're using only ~4% of available compute!
43
43
44
44
### The Solution: Quantization
45
45
46
-
By compressing weights from 16 bits (bfloat16) to 4 bits (int4), we:
46
+
By compressing weights from 16-bit floating point (`float16` or `bfloat16`) to
47
+
4-bit integers (int4), we:
47
48
48
49
-**Reduce memory bandwidth by 4×**: 880 MB → ~220 MB per token
49
50
-**Improve arithmetic intensity by 4×**: 1.0 → ~4.0 FLOPs/Byte
@@ -58,7 +59,7 @@ Instead of quantizing all weights uniformly, we divide them into **groups** and
58
59
For a weight matrix $W$ of shape $(K, N)$, we divide each row into groups of size $G$. In this course we use Qwen3 MLX 4-bit weights, whose group size is fixed at 128:
59
60
60
61
```plain
61
-
Original weight matrix W: K × N (bfloat16)
62
+
Original weight matrix W: K × N (float16 or bfloat16)
62
63
63
64
Group size: G = 128
64
65
Number of groups per row = N / G
@@ -69,7 +70,10 @@ For each group of G consecutive values in a row:
69
70
3. Quantize each value using: quantized = round((value - bias) / scale)
70
71
```
71
72
72
-
All quantized matmul tests use `group_size = 128`, matching the Qwen3 MLX 4-bit weights used by the rest of the course.
73
+
All quantized matmul tests use `group_size = 128`, matching the Qwen3 MLX
74
+
4-bit weights used by the rest of the course. The tests cover both `float16`
75
+
and `bfloat16` because different MLX checkpoints store their scales, biases,
For efficient storage and computation, quantized weights are packed:
120
124
121
125
```plain
122
-
Original: K × N bfloat16 (2 bytes each) = 2KN bytes
126
+
Original: K × N float16/bfloat16 (2 bytes each) = 2KN bytes
123
127
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes
124
128
125
129
Packing: 8 × 4-bit values fit in one uint32 (32 bits)
126
130
127
131
Weight matrix shape: K × N
128
132
Quantized storage shape: K × (N / 8) uint32
129
-
Scales shape: K × (N / G) bfloat16
130
-
Biases shape: K × (N / G) bfloat16
133
+
Scales shape: K × (N / G) float16/bfloat16
134
+
Biases shape: K × (N / G) float16/bfloat16
131
135
```
132
136
133
137
Example packing for 8 consecutive 4-bit values `[a, b, c, d, e, f, g, h]`:
@@ -150,9 +154,9 @@ Unpacking:
150
154
151
155
For standard matrix multiplication $C = AB^T$ where:
152
156
153
-
- $A$: shape $(M, N)$, bfloat16 (activations)
157
+
- $A$: shape $(M, N)$, float16 or bfloat16 (activations)
154
158
- $B$: shape $(K, N)$, **quantized** to int4 (weights)
155
-
- $C$: shape $(M, K)$, bfloat16 (output)
159
+
- $C$: shape $(M, K)$, same 16-bit dtype as $A$ (output)
156
160
157
161
Each element $C[i, k]$ is computed as:
158
162
@@ -186,13 +190,13 @@ This shows we can factor out the scale and bias per group, reducing the number o
186
190
187
191
```plain
188
192
Input:
189
-
A: M × N (bfloat16, activations)
193
+
A: M × N (float16 or bfloat16, activations)
190
194
B_quantized: K × (N/8) (uint32, packed weights)
191
-
scales: K × (N/G) (bfloat16)
192
-
biases: K × (N/G) (bfloat16)
195
+
scales: K × (N/G) (same dtype as A)
196
+
biases: K × (N/G) (same dtype as A)
193
197
194
198
Output:
195
-
C: M × K (bfloat16)
199
+
C: M × K (same dtype as A)
196
200
197
201
For each output element C[i, k]:
198
202
sum = 0 # float accumulator
@@ -211,7 +215,7 @@ For each output element C[i, k]:
211
215
a_value = A[i, g*G + p*8 + bit_offset/4]
212
216
sum = sum + a_value * b_value
213
217
214
-
C[i, k] = bfloat16(sum)
218
+
C[i, k] = same_dtype_as_A(sum)
215
219
```
216
220
217
221
## Task 1: Implement QuantizedWeights
@@ -225,8 +229,8 @@ First, familiarize yourself with the `QuantizedWeights` class, which stores quan
225
229
| Field | Shape | Description |
226
230
|-------|-------|-------------|
227
231
|`weight`| $(K, N/8)$ uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $(K, N)$, and after packing, it becomes $(K, N/8)$. |
228
-
|`scales`| $(K, N/G)$ bfloat16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ |
|`group_size`| int | Number of consecutive values that share the same scale/bias. For the Qwen3 MLX 4-bit weights used here, this is `128`. |
231
235
|`bits`| int | Quantization bit width (typically 4, meaning values are in range $[0, 15]$) |
232
236
@@ -251,7 +255,11 @@ You need to touch three files, all within the `tiny_llm_ext` namespace:
251
255
-**`bindings.cpp`** — Add an `m.def(...)` call to expose the function to Python.
252
256
-**`quantized_matmul.cpp`** — Implement the `quantized_matmul(...)` function (validate inputs, compute output shape, return a lazy `mx::array`) and the `eval_cpu` method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).
253
257
254
-
The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, dequantize each packed value, accumulate the products in `float`, and write the `bfloat16` result to the output.
258
+
The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, dequantize each packed value, accumulate the products in `float`, and write the result back as either `float16` or `bfloat16`, matching the input dtype.
259
+
260
+
Follow the `axpby` dtype-dispatch pattern here: write the CPU implementation as
261
+
a template, then dispatch with `mx::float16_t` or `mx::bfloat16_t` based on the
262
+
output dtype.
255
263
256
264
Don't forget to add `src/quantized_matmul.cpp` to `target_sources` in `CMakeLists.txt`.
257
265
@@ -276,22 +284,24 @@ In this task, you will write the Metal kernel for quantized matmul **and** wire
276
284
You need to implement one kernel entry in `quantized_matmul.metal`:
277
285
278
286
- Use a **one-thread-per-output-element** mapping: each thread computes `out[i, k]`.
279
-
- The kernel should use`bfloat16_t` inputs and outputs.
287
+
- The kernel should support both `half` and`bfloat16_t` inputs and outputs.
280
288
- Apply the same group-wise dequantization loop as the CPU version:
281
289
- Iterate over groups of 128 values
282
290
- Unpack int4 values from packed `uint32`
283
291
- Dequantize with `q * scale + bias`
284
-
- Accumulate the products in `float` and cast the final output back to `bfloat16_t`
292
+
- Accumulate the products in `float` and cast the final output back to the kernel dtype
The custom kernel only needs to handle `bits = 4` and `group_size = 128`. Use that group size to compute `groups_per_row` and the packed weight offsets.
296
+
Instantiate the same templated Metal kernel twice, once for `half` and once for
297
+
`bfloat16_t`, and select the matching kernel name in `eval_gpu`.
288
298
289
299
### GPU Dispatch
290
300
291
301
Complete the `eval_gpu` method in `quantized_matmul.cpp` to dispatch your Metal kernel. Follow the same pattern as `axpby`'s GPU dispatch:
292
302
293
303
1. Get the Metal device and command encoder from the stream.
294
-
2. Load the bfloat16 quantized matmul kernel from the Metal library.
304
+
2. Load the quantized matmul kernel matching the output dtype from the Metal library.
295
305
3. Set input/output buffers and dimension constants (`M`, `N`, `K`) on the encoder — make sure the buffer order matches your kernel signature.
296
306
4. Calculate a 2D thread group configuration: use `kernel->maxTotalThreadsPerThreadgroup()` to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K).
297
307
5. Dispatch with `dispatchThreadgroups`.
@@ -311,9 +321,9 @@ src/tiny_llm/qwen3_week2.py
311
321
312
322
Integrate your quantized matmul into the Week 2 Qwen3 model so that inference runs on quantized weights end-to-end.
313
323
314
-
Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full bfloat16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
324
+
Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
315
325
316
-
Qwen3 MLX quantized layers use **bfloat16** for the tensors involved in dequantization. Your kernel should take`scales`, `biases`, and activations as bfloat16. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.
326
+
Qwen3 MLX quantized layers may use **float16** or **bfloat16** for the tensors involved in dequantization. Your kernel should accept`scales`, `biases`, and activations in either dtype, require them to match, and return the same dtype. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.
317
327
318
328
Also keep the quantized layer's parameters. The model code should pass through `w.group_size` and `w.bits`; the extension should validate that they match the Qwen3 course assumptions: `group_size = 128` and `bits = 4`.
0 commit comments