Commit 6008387

Add Qwen3 MoE lesson
Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent e8da33f commit 6008387

18 files changed

Lines changed: 1003 additions & 19 deletions

batch-main.py

Lines changed: 14 additions & 2 deletions
@@ -35,9 +35,11 @@
 random.shuffle(prompts)

 parser.add_argument("--solution", type=str, default="tiny_llm")
+parser.add_argument("--loader", type=str, choices=["week2", "week3"], default="week2")
 parser.add_argument("--device", type=str, default="gpu")
 parser.add_argument("--batch-size", type=int, default=5)
 parser.add_argument("--prefill-step", type=int, default=128)
+parser.add_argument("--max-seq-len", type=int, default=512)
 parser.add_argument("--enable-flash-attn", action="store_true")
 parser.add_argument("--enable-thinking", action="store_true")
 args = parser.parse_args()
@@ -57,11 +59,20 @@
 mlx_model, tokenizer = load(args.model)

 with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu):
+    dispatch_kwargs = {}
+    if args.loader == "week2":
+        dispatch_kwargs["enable_flash_attn"] = args.enable_flash_attn
+    elif args.enable_flash_attn:
+        print("--enable-flash-attn is only used by the week2 loader; ignoring it")
+
     print(
-        f"Using week2 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}"
+        f"Using {args.loader} loader with thinking={args.enable_thinking} for {args.model}"
     )
     tiny_llm_model = models.dispatch_model(
-        args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn
+        args.model,
+        mlx_model,
+        week=int(args.loader.removeprefix("week")),
+        **dispatch_kwargs,
     )
     encoded_prompts = []
     for idx, prompt in enumerate(prompts):
@@ -81,6 +92,7 @@
         tiny_llm_model,
         tokenizer,
         encoded_prompts,
+        max_seq_len=args.max_seq_len,
         batch_size=args.batch_size,
         prefill_step=args.prefill_step,
     )

book/src/SUMMARY.md

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@
 - [Week 3: Serving]()
 - [Paged Attention, Part 1](./week3-01-paged-attention-part1.md)
 - [Paged Attention, Part 2](./week3-02-paged-attention-part2.md)
+- [Mixture of Experts](./week3-03-moe.md)
+- [Extended: Profiling](./week3-04-profiling.md)

 ---

book/src/week2-02-quantized-matmul.md

Lines changed: 3 additions & 1 deletion
@@ -321,7 +321,9 @@ src/tiny_llm/qwen3_week2.py

 Integrate your quantized matmul into the Week 2 Qwen3 model so that inference runs on quantized weights end-to-end.

-Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
+Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes these projection weights (since Week 1 layers expect plain `mx.array`), while the Week 2 loader keeps them quantized.
+
+The input embedding is the main exception. `embed_tokens(input_ids)` is a row lookup, not a matrix multiplication, so it is not the operator implemented by `quantized_matmul`. For Week 2, first keep the input embedding on the existing `Embedding` path and focus quantized matmul on projection layers. If the model has a separate `lm_head`, that head is a normal linear projection and should use `quantized_linear`. If output weights are tied, `embedding.as_linear(h)` is the projection side of the embedding table; a later optimization can keep that table quantized, use `mx.quantized_matmul` for `as_linear`, and dequantize only the selected rows during lookup.

 Qwen3 MLX quantized layers may use **float16** or **bfloat16** for the tensors involved in dequantization. Your kernel should accept `scales`, `biases`, and activations in either dtype, require them to match, and return the same dtype. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.

book/src/week3-03-moe.md

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
# Week 3 Day 3: Mixture of Experts

In this chapter, we will implement the feed-forward shape of **Mixture of
Experts**, or **MoE**, for the Qwen3 family.

So far, every transformer block in tiny-llm has used the same dense Qwen3 MLP:

```plain
x -> gate_proj
x -> up_proj
SiLU(gate_proj(x)) * up_proj(x) -> down_proj
```

That is a SwiGLU MLP. Every token visits the same weights.

MoE changes only the feed-forward half of the transformer block. Instead of one
dense MLP, the model owns many expert MLPs. A small router chooses which experts
each token should use:

```plain
token hidden state -> router -> top-k experts -> weighted expert outputs
```

The attention path does not change. KV cache does not change. The sparse work is
inside the MLP half of the block.

**Readings**

- [Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer](https://arxiv.org/abs/1701.06538)
- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668)
- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961)

## Dense MLP vs MoE MLP

The dense Qwen3 MLP from Week 1 has one set of weights:

```plain
w_gate: hidden_dim, dim
w_up: hidden_dim, dim
w_down: dim, hidden_dim
```

A Qwen3-MoE sparse block has a bank of those weights:

```plain
expert_gate: num_experts, moe_hidden_dim, dim
expert_up: num_experts, moe_hidden_dim, dim
expert_down: num_experts, dim, moe_hidden_dim
```

The router produces one score per expert:

```plain
router_logits: B, L, num_experts
router_probs: softmax(router_logits)
```

Then the model picks `num_experts_per_tok` experts for each token:

```plain
expert_ids: B, L, num_experts_per_tok
expert_scores: B, L, num_experts_per_tok
```

For each token, only those selected experts run. Their outputs are weighted and
summed:

```plain
output[token] = sum(score_i * expert_i(token))
```

That is the central MoE idea: the model can contain many parameters, but each
token activates only a small subset of them.

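
To make the "many parameters, few active" point concrete, here is a
back-of-the-envelope count for one sparse layer. The helper name
`moe_layer_params` is hypothetical, and the sizes passed in are illustrative
assumptions; check them against the `config.json` of the checkpoint you load.
Quantization metadata (scales and biases) is ignored.

```python
def moe_layer_params(dim, moe_hidden_dim, num_experts, experts_per_tok):
    """Count weights in one sparse MoE layer (router + SwiGLU experts)."""
    per_expert = 3 * dim * moe_hidden_dim           # gate, up, and down projections
    router = dim * num_experts                      # one logit per expert
    total = router + num_experts * per_expert       # weights stored in the layer
    active = router + experts_per_tok * per_expert  # weights one token touches
    return total, active


# Assumed sizes, for illustration only.
total, active = moe_layer_params(
    dim=2048, moe_hidden_dim=768, num_experts=128, experts_per_tok=8
)
print(f"total={total:,} active={active:,} ratio={total / active:.1f}x")
```

The ratio is roughly `num_experts / num_experts_per_tok`, because the router is
small compared with the expert bank.
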
## Qwen3-MoE Shape

Qwen3-MoE keeps the same attention structure as Qwen3, including QK norm, GQA,
RoPE, and the same KV cache interface. It replaces some dense MLP layers with a
sparse MoE block.

The useful pieces are:

- `gate`: a router linear layer from hidden size to `num_experts`
- `switch_mlp`: many SwiGLU experts with `moe_intermediate_size`
- `num_experts_per_tok`: how many experts a token uses
- `norm_topk_prob`: whether selected expert scores are renormalized
- `decoder_sparse_step` and `mlp_only_layers`: which layers are sparse vs dense

There is no shared expert in the Qwen3-MoE block we are following. The sparse
feed-forward output is just the weighted top-k expert mixture.

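
The last bullet can be read as a small predicate over the layer index. The
sketch below follows the rule that the reference Qwen3-MoE implementations use
as far as we can tell: a layer is sparse when it is not listed in
`mlp_only_layers` and `(layer_idx + 1)` is a multiple of `decoder_sparse_step`.
Treat that rule as an assumption and verify it against the loader you are
porting from. The function name `is_sparse_layer` is hypothetical.

```python
def is_sparse_layer(layer_idx, num_experts, decoder_sparse_step, mlp_only_layers):
    """Return True if this decoder layer should use the MoE block."""
    if layer_idx in mlp_only_layers:
        return False  # explicitly forced to a dense MLP
    return num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0


# Hypothetical config: every second layer is sparse, layer 3 is forced dense.
kinds = ["moe" if is_sparse_layer(i, 8, 2, [3]) else "dense" for i in range(6)]
print(kinds)
```

With `decoder_sparse_step=1` and an empty `mlp_only_layers`, every layer is
sparse.
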
## The MLX Primitive

MLX does not give us a single high-level MoE block in `mlx.nn`. The relevant
primitive for this chapter is `mx.gather_qmm`: it performs quantized matrix
multiplication while selecting a different matrix for each row.

For MoE, that means:

```plain
token rows: N, D
expert ids: N
weights: E, O, D packed as 4-bit QuantizedWeights
output: N, O
```

The row with `expert_ids[i] = e` should multiply by `weights[e]`.

When the expert ids are sorted, pass `sorted_indices=True`. Keep the inverse
order from the sort so the result can be restored to the original token order.

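
The sort and restore bookkeeping is easy to get wrong, so here is a minimal
sketch with plain Python lists instead of MLX arrays. The strings stand in for
token rows and matmul results, and the `argsort` helper is hypothetical. In MLX
the same permutations would typically come from `mx.argsort`, applied once to
the expert ids and once more to the resulting order.

```python
def argsort(values):
    """Stable argsort for a plain Python list."""
    return sorted(range(len(values)), key=values.__getitem__)


expert_ids = [2, 0, 1, 0, 2]            # one expert id per token row
rows = ["r0", "r1", "r2", "r3", "r4"]   # stand-ins for the token rows

order = argsort(expert_ids)             # positions that sort the ids
inverse = argsort(order)                # inverse permutation of `order`

grouped_ids = [expert_ids[i] for i in order]
grouped_rows = [rows[i] for i in order]

# The grouped matmul would run here, on rows that are contiguous per expert.
grouped_out = [f"W{e}@{r}" for e, r in zip(grouped_ids, grouped_rows)]

# Undo the sort so output i belongs to original row i again.
restored = [grouped_out[i] for i in inverse]
print(restored)
```

Indexing the grouped output with `inverse` puts every result back at the
position of the row that produced it.
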
## Router Step

The router is just a quantized linear layer:

```python
router_logits = quantized_linear(x, w_router)
router_probs = softmax(router_logits, axis=-1)
```

For a batch of tokens:

```plain
x: B, L, D
router_logits: B, L, E
router_probs: B, L, E
```

where `E = num_experts`.

Qwen3-MoE then uses top-k selection along the expert axis, with
`k = num_experts_per_tok`:

```python
expert_ids = mx.argpartition(-router_probs, kth=k - 1, axis=-1)[..., :k]
expert_scores = mx.take_along_axis(router_probs, expert_ids, axis=-1)
```

`mx.argpartition` only guarantees that the first `k` entries are the `k` largest
probabilities; it does not order them. That is fine here because the expert
outputs are summed.

If `norm_topk_prob` is true, renormalize `expert_scores` so the selected scores
sum to 1 for each token.

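
As a cross-check for the router, the following dependency-free sketch routes a
single token with plain Python floats. The name `route_topk_reference` is
hypothetical, and it uses a full sort where the real implementation would use
`mx.argpartition`; the selected set is the same, only the order within the
top-k may differ.

```python
import math


def route_topk_reference(logits, k, norm_topk_prob):
    """Route one token: softmax over experts, keep the k largest, maybe renormalize."""
    peak = max(logits)                           # subtract the max for stability
    exps = [math.exp(v - peak) for v in logits]
    denom = sum(exps)
    probs = [e / denom for e in exps]

    # Full sort for clarity; a partition would avoid ordering the top-k.
    ids = sorted(range(len(probs)), key=lambda e: -probs[e])[:k]
    scores = [probs[e] for e in ids]
    if norm_topk_prob:
        kept = sum(scores)
        scores = [s / kept for s in scores]      # selected scores now sum to 1
    return probs, ids, scores


probs, ids, scores = route_topk_reference(
    [2.0, 0.5, 1.0, -1.0], k=2, norm_topk_prob=True
)
print(ids, [round(s, 3) for s in scores])
```

With `norm_topk_prob=False` the returned scores are the raw softmax
probabilities of the selected experts and generally sum to less than 1.
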
## Expert Step

Each expert is the same kind of SwiGLU MLP we already know:

```plain
expert(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x))
```

The implementation should build token-expert jobs, group them by expert, and run
the expert projections with `mx.gather_qmm`:

```plain
selected expert ids -> expanded token-expert rows
expanded rows -> sort/group by expert id
grouped expert rows -> grouped gate/up projection
SiLU(gate) * up -> grouped down projection
restore original token/top-k order -> weighted sum
```

The reorder is part of the model implementation. It keeps all token rows for the
same expert contiguous so the expert bank can be applied with grouped matrix
multiplication.

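
Before optimizing with grouping, it can help to have a slow reference that
applies the definition directly: run each selected expert on the token and take
the weighted sum. This is a sketch with plain Python lists and unquantized
float weights, not the tiny-llm implementation; all function names and the toy
weights are hypothetical.

```python
import math


def matvec(w, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def silu(v):
    return v / (1.0 + math.exp(-v))


def swiglu_expert(x, w_gate, w_up, w_down):
    """down_proj(SiLU(gate_proj(x)) * up_proj(x)) for one token."""
    gate = matvec(w_gate, x)
    up = matvec(w_up, x)
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)


def moe_token_reference(x, experts, expert_ids, expert_scores):
    """Weighted sum of the selected experts' outputs for one token."""
    out = [0.0] * len(x)
    for e, score in zip(expert_ids, expert_scores):
        y = swiglu_expert(x, *experts[e])
        out = [o + score * yi for o, yi in zip(out, y)]
    return out


# Two toy experts with dim=2 and moe_hidden_dim=1, as (w_gate, w_up, w_down).
experts = [
    ([[1.0, 0.0]], [[0.0, 1.0]], [[1.0], [0.0]]),
    ([[0.0, 1.0]], [[1.0, 0.0]], [[0.0], [2.0]]),
]
y = moe_token_reference(
    [1.0, 2.0], experts, expert_ids=[0, 1], expert_scores=[0.75, 0.25]
)
print(y)
```

A grouped implementation should produce the same values up to quantization and
floating-point error, which makes this loop a convenient test oracle.
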
## Task 1: Grouped Expert Linear

```
src/tiny_llm/moe.py
```

Implement `grouped_expert_linear`. This is the MLX-shaped core of MoE.

The function accepts:

```plain
x: ..., D
w_experts: QuantizedWeights for num_experts, output_dim, D
expert_ids: ...
```

It returns:

```plain
out: ..., output_dim
```

The implementation should:

```plain
1. flatten token rows and expert ids,
2. sort rows by expert id,
3. call mx.gather_qmm with sorted_indices=True,
4. restore the original order.
```

For the grouped matmul, the shape should look like:

```python
out = mx.gather_qmm(
    mx.expand_dims(grouped_rows, -2),
    w_experts.weight,
    w_experts.scales,
    w_experts.biases,
    lhs_indices=mx.arange(grouped_rows.shape[0]),
    rhs_indices=grouped_expert_ids,
    transpose=True,
    group_size=w_experts.group_size,
    bits=w_experts.bits,
    mode=w_experts.mode,
    sorted_indices=True,
).squeeze(-2)
```

This task maps to the same idea as `QuantizedSwitchLinear` in `mlx-lm`: each
token row uses a different packed expert matrix, and the expert ids choose the
right matrix.

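
When testing this task, a slow per-row reference is useful: `out[i]` is simply
`w_experts[expert_ids[i]] @ x[i]`. The sketch below uses plain Python lists and
unquantized float weights, so it only states the expected semantics. To compare
against the real function you would first dequantize the expert bank (for
example with `mx.dequantize`) and allow a small tolerance. The function name is
hypothetical.

```python
def grouped_expert_linear_reference(rows, w_experts, expert_ids):
    """out[i] = w_experts[expert_ids[i]] @ rows[i], with plain float weights."""
    out = []
    for x, e in zip(rows, expert_ids):
        w = w_experts[e]  # shape: output_dim x D
        out.append([sum(wi * xi for wi, xi in zip(w_row, x)) for w_row in w])
    return out


w_experts = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],   # expert 0: output_dim=3, D=2
    [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]],   # expert 1
]
rows = [[1.0, 2.0], [3.0, 4.0]]
out = grouped_expert_linear_reference(rows, w_experts, expert_ids=[1, 0])
print(out)
```

Note that the reference never sorts anything. Sorting is an optimization for
the grouped kernel, so the sorted and restored result must match this unsorted
loop row for row.
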
## Task 2: Router Top-k

```
src/tiny_llm/moe.py
```

Implement `route_topk`. It accepts hidden states and router weights, then
returns:

- router probabilities
- selected expert ids
- selected expert scores

Use `quantized_linear` and `softmax`. Use `mx.argpartition` to select the top
`num_experts_per_tok` experts, then `mx.take_along_axis` to gather their scores.

Keep `norm_topk_prob` as an argument because Qwen3-MoE stores this behavior in
the model config.

## Task 3: Qwen3 Sparse MoE Block

```
src/tiny_llm/moe.py
```

Implement `Moe` by composing Task 1 and Task 2:

```plain
hidden states -> route_topk
hidden states + expert ids -> grouped gate projection
hidden states + expert ids -> grouped up projection
SiLU(gate) * up -> grouped down projection
weighted sum over num_experts_per_tok
```

This completes the Qwen3-MoE sparse feed-forward block. There is no shared expert
branch in this block.

## Task 4: Integrate Qwen3-MoE Layers

```
src/tiny_llm/qwen3_week3.py
src/tiny_llm/models.py
```

Add a Qwen3-MoE loader path that reuses the Week 3 Qwen3 attention and paged KV
cache behavior, but swaps selected block MLPs for `Moe`.

The model wrapper should:

- keep Qwen3 attention unchanged,
- use regular `Qwen3MLP` for `mlp_only_layers`,
- use `Moe` for sparse layers selected by
  `decoder_sparse_step`,
- load router and expert weights as `QuantizedWeights` from the Qwen3-MoE MLX
  model,
- preserve the same decode call shape:

```python
logits = model(tokens, offset, cache)
```

No scheduler API change in `src/tiny_llm/batch.py` is required for correctness.

Run this task through the normal generation entrypoints instead of adding a
separate unit test. For example:

```bash
hf download Qwen/Qwen3-30B-A3B-MLX-4bit

pdm run main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \
  --prompt "Give me a short introduction to mixture of experts."

pdm run batch-main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \
  --batch-size 2 --prefill-step 16
```

{{#include copyright.md}}
