|
| 1 | +# Week 3 Day 3: Mixture of Experts |
| 2 | + |
| 3 | +In this chapter, we will implement the feed-forward shape of **Mixture of |
| 4 | +Experts**, or **MoE**, for the Qwen3 family. |
| 5 | + |
| 6 | +So far, every transformer block in tiny-llm has used the same dense Qwen3 MLP: |
| 7 | + |
| 8 | +```plain |
| 9 | +x -> gate_proj |
| 10 | +x -> up_proj |
| 11 | +SiLU(gate_proj(x)) * up_proj(x) -> down_proj |
| 12 | +``` |
| 13 | + |
| 14 | +That is a SwiGLU MLP. Every token visits the same weights. |
| 15 | + |
| 16 | +MoE changes only the feed-forward half of the transformer block. Instead of one |
| 17 | +dense MLP, the model owns many expert MLPs. A small router chooses which experts |
| 18 | +each token should use: |
| 19 | + |
| 20 | +```plain |
| 21 | +token hidden state -> router -> top-k experts -> weighted expert outputs |
| 22 | +``` |
| 23 | + |
| 24 | +The attention path does not change. KV cache does not change. The sparse work is |
| 25 | +inside the MLP half of the block. |
| 26 | + |
| 27 | +**Readings** |
| 28 | + |
| 29 | +- [Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer](https://arxiv.org/abs/1701.06538) |
| 30 | +- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668) |
| 31 | +- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) |
| 32 | + |
| 33 | +## Dense MLP vs MoE MLP |
| 34 | + |
| 35 | +The dense Qwen3 MLP from Week 1 has one set of weights: |
| 36 | + |
| 37 | +```plain |
| 38 | +w_gate: hidden_dim, dim |
| 39 | +w_up: hidden_dim, dim |
| 40 | +w_down: dim, hidden_dim |
| 41 | +``` |
| 42 | + |
| 43 | +A Qwen3-MoE sparse block has a bank of those weights: |
| 44 | + |
| 45 | +```plain |
| 46 | +expert_gate: num_experts, moe_hidden_dim, dim |
| 47 | +expert_up: num_experts, moe_hidden_dim, dim |
| 48 | +expert_down: num_experts, dim, moe_hidden_dim |
| 49 | +``` |
| 50 | + |
| 51 | +The router produces one score per expert: |
| 52 | + |
| 53 | +```plain |
| 54 | +router_logits: B, L, num_experts |
| 55 | +router_probs: softmax(router_logits) |
| 56 | +``` |
| 57 | + |
| 58 | +Then the model picks `num_experts_per_tok` experts for each token: |
| 59 | + |
| 60 | +```plain |
| 61 | +expert_ids: B, L, num_experts_per_tok |
| 62 | +expert_scores: B, L, num_experts_per_tok |
| 63 | +``` |
| 64 | + |
| 65 | +For each token, only those selected experts run. Their outputs are weighted and |
| 66 | +summed: |
| 67 | + |
| 68 | +```plain |
| 69 | +output[token] = sum(score_i * expert_i(token)) |
| 70 | +``` |
| 71 | + |
| 72 | +That is the central MoE idea: the model can contain many parameters, but each |
| 73 | +token activates only a small subset of them. |
| 74 | + |
| 75 | +## Qwen3-MoE Shape |
| 76 | + |
| 77 | +Qwen3-MoE keeps the same attention structure as Qwen3, including QK norm, GQA, |
| 78 | +RoPE, and the same KV cache interface. It replaces some dense MLP layers with a |
| 79 | +sparse MoE block. |
| 80 | + |
| 81 | +The useful pieces are: |
| 82 | + |
| 83 | +- `gate`: a router linear layer from hidden size to `num_experts` |
| 84 | +- `switch_mlp`: many SwiGLU experts with `moe_intermediate_size` |
| 85 | +- `num_experts_per_tok`: how many experts a token uses |
| 86 | +- `norm_topk_prob`: whether selected expert scores are renormalized |
| 87 | +- `decoder_sparse_step` and `mlp_only_layers`: which layers are sparse vs dense |
| 88 | + |
| 89 | +There is no shared expert in the Qwen3-MoE block we are following. The sparse |
| 90 | +feed-forward output is just the weighted top-k expert mixture. |
| 91 | + |
| 92 | +## The MLX Primitive |
| 93 | + |
| 94 | +MLX does not give us a single high-level MoE block in `mlx.nn`. The relevant |
| 95 | +primitive for this chapter is `mx.gather_qmm`: it performs quantized matrix |
| 96 | +multiplication while selecting a different matrix for each row. |
| 97 | + |
| 98 | +For MoE, that means: |
| 99 | + |
| 100 | +```plain |
| 101 | +token rows: N, D |
| 102 | +expert ids: N |
| 103 | +weights: E, O, D packed as 4-bit QuantizedWeights |
| 104 | +output: N, O |
| 105 | +``` |
| 106 | + |
| 107 | +The row with `expert_ids[i] = e` should multiply by `weights[e]`. |
| 108 | + |
| 109 | +When the expert ids are sorted, pass `sorted_indices=True`. Keep the inverse |
| 110 | +order from the sort so the result can be restored to the original token order. |
| 111 | + |
| 112 | +## Router Step |
| 113 | + |
| 114 | +The router is just a quantized linear layer: |
| 115 | + |
| 116 | +```python |
| 117 | +router_logits = quantized_linear(x, w_router) |
| 118 | +router_probs = softmax(router_logits, axis=-1) |
| 119 | +``` |
| 120 | + |
| 121 | +For a batch of tokens: |
| 122 | + |
| 123 | +```plain |
| 124 | +x: B, L, D |
| 125 | +router_logits: B, L, E |
| 126 | +router_probs: B, L, E |
| 127 | +``` |
| 128 | + |
| 129 | +where `E = num_experts`. |
| 130 | + |
| 131 | +Qwen3-MoE then uses top-k selection: |
| 132 | + |
| 133 | +```python |
| 134 | +expert_ids = argpartition(-router_probs, k)[:k] |
| 135 | +expert_scores = take_along_axis(router_probs, expert_ids) |
| 136 | +``` |
| 137 | + |
| 138 | +If `norm_topk_prob` is true, renormalize `expert_scores` so the selected scores |
| 139 | +sum to 1 for each token. |
| 140 | + |
| 141 | +## Expert Step |
| 142 | + |
| 143 | +Each expert is the same kind of SwiGLU MLP we already know: |
| 144 | + |
| 145 | +```plain |
| 146 | +expert(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)) |
| 147 | +``` |
| 148 | + |
| 149 | +The implementation should build token-expert jobs, group them by expert, and run |
| 150 | +the expert projections with `mx.gather_qmm`: |
| 151 | + |
| 152 | +```plain |
| 153 | +selected expert ids -> expanded token-expert rows |
| 154 | +expanded rows -> sort/group by expert id |
| 155 | +grouped expert rows -> grouped gate/up projection |
| 156 | +SiLU(gate) * up -> grouped down projection |
| 157 | +restore original token/top-k order -> weighted sum |
| 158 | +``` |
| 159 | + |
| 160 | +The reorder is part of the model implementation. It keeps all token rows for the |
| 161 | +same expert contiguous so the expert bank can be applied with grouped matrix |
| 162 | +multiplication. |
| 163 | + |
| 164 | +## Task 1: Grouped Expert Linear |
| 165 | + |
| 166 | +``` |
| 167 | +src/tiny_llm/moe.py |
| 168 | +``` |
| 169 | + |
| 170 | +Implement `grouped_expert_linear`. This is the MLX-shaped core of MoE. |
| 171 | + |
| 172 | +The function accepts: |
| 173 | + |
| 174 | +```plain |
| 175 | +x: ..., D |
| 176 | +w_experts: QuantizedWeights for num_experts, output_dim, D |
| 177 | +expert_ids: ... |
| 178 | +``` |
| 179 | + |
| 180 | +It returns: |
| 181 | + |
| 182 | +```plain |
| 183 | +out: ..., output_dim |
| 184 | +``` |
| 185 | + |
| 186 | +The implementation should: |
| 187 | + |
| 188 | +```plain |
| 189 | +1. flatten token rows and expert ids, |
| 190 | +2. sort rows by expert id, |
| 191 | +3. call mx.gather_qmm with sorted_indices=True, |
| 192 | +4. restore the original order. |
| 193 | +``` |
| 194 | + |
| 195 | +For the grouped matmul, the shape should look like: |
| 196 | + |
| 197 | +```python |
| 198 | +out = mx.gather_qmm( |
| 199 | + mx.expand_dims(grouped_rows, -2), |
| 200 | + w_experts.weight, |
| 201 | + w_experts.scales, |
| 202 | + w_experts.biases, |
| 203 | + lhs_indices=mx.arange(grouped_rows.shape[0]), |
| 204 | + rhs_indices=grouped_expert_ids, |
| 205 | + transpose=True, |
| 206 | + group_size=w_experts.group_size, |
| 207 | + bits=w_experts.bits, |
| 208 | + mode=w_experts.mode, |
| 209 | + sorted_indices=True, |
| 210 | +).squeeze(-2) |
| 211 | +``` |
| 212 | + |
| 213 | +This task maps to the same idea as `QuantizedSwitchLinear` in `mlx-lm`: each |
| 214 | +token row uses a different packed expert matrix, and the expert ids choose the |
| 215 | +right matrix. |
| 216 | + |
| 217 | +## Task 2: Router Top-k |
| 218 | + |
| 219 | +``` |
| 220 | +src/tiny_llm/moe.py |
| 221 | +``` |
| 222 | + |
| 223 | +Implement `route_topk`. It accepts hidden states and router weights, then |
| 224 | +returns: |
| 225 | + |
| 226 | +- router probabilities |
| 227 | +- selected expert ids |
| 228 | +- selected expert scores |
| 229 | + |
| 230 | +Use `quantized_linear` and `softmax`. Use `mx.argpartition` to select the top |
| 231 | +`num_experts_per_tok` experts, then `mx.take_along_axis` to gather their scores. |
| 232 | + |
| 233 | +Keep `norm_topk_prob` as an argument because Qwen3-MoE stores this behavior in |
| 234 | +the model config. |
| 235 | + |
| 236 | +## Task 3: Qwen3 Sparse MoE Block |
| 237 | + |
| 238 | +``` |
| 239 | +src/tiny_llm/moe.py |
| 240 | +``` |
| 241 | + |
| 242 | +Implement `Moe` by composing Task 1 and Task 2: |
| 243 | + |
| 244 | +```plain |
| 245 | +hidden states -> route_topk |
| 246 | +hidden states + expert ids -> grouped gate projection |
| 247 | +hidden states + expert ids -> grouped up projection |
| 248 | +SiLU(gate) * up -> grouped down projection |
| 249 | +weighted sum over num_experts_per_tok |
| 250 | +``` |
| 251 | + |
| 252 | +This completes the Qwen3-MoE sparse feed-forward block. There is no shared expert |
| 253 | +branch in this block. |
| 254 | + |
| 255 | +## Task 4: Integrate Qwen3-MoE Layers |
| 256 | + |
| 257 | +``` |
| 258 | +src/tiny_llm/qwen3_week3.py |
| 259 | +src/tiny_llm/models.py |
| 260 | +``` |
| 261 | + |
| 262 | +Add a Qwen3-MoE loader path that reuses the Week 3 Qwen3 attention and paged KV |
| 263 | +cache behavior, but swaps selected block MLPs for `Moe`. |
| 264 | + |
| 265 | +The model wrapper should: |
| 266 | + |
| 267 | +- keep Qwen3 attention unchanged, |
| 268 | +- use regular `Qwen3MLP` for `mlp_only_layers`, |
| 269 | +- use `Moe` for sparse layers selected by |
| 270 | + `decoder_sparse_step`, |
| 271 | +- load router and expert weights as `QuantizedWeights` from the Qwen3-MoE MLX |
| 272 | + model, |
| 273 | +- preserve the same decode call shape: |
| 274 | + |
| 275 | +```python |
| 276 | +logits = model(tokens, offset, cache) |
| 277 | +``` |
| 278 | + |
| 279 | +No scheduler API change in `src/tiny_llm/batch.py` is required for correctness. |
| 280 | + |
| 281 | +Run this task through the normal generation entrypoints instead of adding a |
| 282 | +separate unit test. For example: |
| 283 | + |
| 284 | +```bash |
| 285 | +hf download Qwen/Qwen3-30B-A3B-MLX-4bit |
| 286 | + |
| 287 | +pdm run main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ |
| 288 | + --prompt "Give me a short introduction to mixture of experts." |
| 289 | + |
| 290 | +pdm run batch-main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ |
| 291 | + --batch-size 2 --prefill-step 16 |
| 292 | +``` |
| 293 | + |
| 294 | +{{#include copyright.md}} |
0 commit comments