Skip to content

Commit c0badba

Browse files
authored
mimo v2.5 pro cookbook (#980)
1 parent 0c97a91 commit c0badba

1 file changed

Lines changed: 275 additions & 0 deletions

File tree

docs/basic_usage/mimo_v2.5_pro.md

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# MiMo-V2.5-Pro on SGL-JAX
2+
3+
MiMo-V2.5-Pro is Xiaomi's large-scale MoE model with hybrid attention (full attention + sliding window attention) and FP8-quantized weights, optimized for long-context reasoning. SGL-JAX supports it on TPU v6e and v7x with FP8 dequantization, tensor parallelism, and expert parallelism.
4+
5+
This cookbook walks through setting up a multi-node TPU slice and serving MiMo-V2.5-Pro end-to-end. The primary reference configuration is **TPU v7x-16** (`2x2x4`, 4 nodes); a TPU v6e-64 (`4x4x4`, 16 nodes) launch command is also provided.
6+
7+
## Hardware Requirements
8+
9+
| TPU Type | Topology | Chips per node | Nodes | Total chips |
10+
|----------|----------|----------------|-------|-------------|
11+
| v7x | 2x2x4 | 4 | 4 | 16 |
12+
| v6e | 4x4x4 | 4 | 16 | 64 |
13+
14+
All nodes must be in the same TPU slice and able to reach each other on the JAX init port (`5000` by default below) and the TPU process port (`8471`).
15+
16+
## Environment Setup
17+
18+
### 1. Start a JAX TPU container on each node
19+
20+
Use the official JAX 0.8.1 TPU image:
21+
22+
```bash
23+
docker run -it --privileged \
24+
--shm-size=32g \
25+
--ipc=host \
26+
--network=host \
27+
-v /dev:/dev \
28+
us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.1-rev1 bash
29+
```
30+
31+
> The same image is used by SGL-JAX's GKE / SkyPilot launchers; pinning to `jax0.8.1-rev1` keeps the JAX runtime in lockstep with the SGL-JAX `[tpu]` extras.
32+
33+
### 2. Clone and install SGL-JAX
34+
35+
```bash
36+
git clone https://github.com/sgl-project/sglang-jax.git
37+
cd sglang-jax
38+
git fetch origin
39+
pip install -e "python[tpu]"
40+
```
41+
42+
This installs `sgl-jax` together with its TPU-specific dependencies (matching `jax==0.8.1`).
43+
44+
### 3. (Optional) Install evalscope for accuracy evaluation
45+
46+
```bash
47+
pip install evalscope==0.17.1
48+
```
49+
50+
## Launching the Server
51+
52+
Run the **same** command on every node, only varying `${NODE_RANK}` and pointing all nodes at the rank-0 host via `${MASTER_ADDR}` (e.g. `node0.cluster.local:5000`).
53+
54+
### TPU v7x (16 chips, 4 nodes, `2x2x4`)
55+
56+
```bash
57+
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \
58+
--model-path XiaomiMiMo/MiMo-V2.5-Pro \
59+
--trust-remote-code \
60+
--tp-size 32 --ep-size 32 \
61+
--moe-backend fused \
62+
--host 0.0.0.0 --port 30271 \
63+
--disable-radix-cache \
64+
--page-size 256 --context-length 262144 \
65+
--chunked-prefill-size 4096 \
66+
--dtype bfloat16 --mem-fraction-static 0.95 \
67+
--swa-full-tokens-ratio 0.25 \
68+
--log-level info --max-running-requests 512 \
69+
--nnodes 4 --node-rank ${NODE_RANK} \
70+
--dist-init-addr ${MASTER_ADDR}
71+
```
72+
73+
`${NODE_RANK}` ranges from `0` to `3`.
74+
75+
### TPU v6e (64 chips, 16 nodes, `4x4x4`)
76+
77+
```bash
78+
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \
79+
--model-path XiaomiMiMo/MiMo-V2.5-Pro \
80+
--trust-remote-code \
81+
--port 30271 \
82+
--tp-size 64 --ep-size 64 \
83+
--context-length 262144 --max-seq-len 4096 \
84+
--chunked-prefill-size 4096 --max-prefill-tokens 16384 \
85+
--page-size 256 \
86+
--mem-fraction-static 0.92 \
87+
--max-running-requests 512 \
88+
--swa-full-tokens-ratio 0.15 \
89+
--attention-backend fa --moe-backend fused \
90+
--disable-radix-cache \
91+
--nnodes 16 --node-rank ${NODE_RANK} \
92+
--dist-init-addr ${MASTER_ADDR}
93+
```
94+
95+
`${NODE_RANK}` ranges from `0` to `15`. Compared to the v7x recipe, the v6e command lowers `mem-fraction-static` to `0.92` and tightens `swa-full-tokens-ratio` to `0.15` because v6e has less HBM per chip.
96+
97+
### Key flags
98+
99+
- `--tp-size / --ep-size`: Match the total JAX device count across all nodes. v7x exposes 2 logical devices per chip (32 = 16 chips × 2); v6e exposes 1 (64 = 64 chips × 1).
100+
- `--moe-backend fused`: Uses the fused Pallas MoE kernel (recommended). `epmoe` is also supported but slower at this scale.
101+
- `--page-size 256`: Required page size for the SWA pool's eviction logic. Smaller page sizes are not supported with MiMo-V2.5-Pro.
102+
- `--context-length 262144`: 256K context. Match this to your workload's max prompt + output length.
103+
- `--chunked-prefill-size 4096`: Splits long prefills into 4K-token chunks to bound peak HBM during prefill.
104+
- `--swa-full-tokens-ratio`: Fraction of the KV cache pool reserved for full-attention layers. `0.25` for v7x, `0.15` for v6e.
105+
- `--mem-fraction-static`: Fraction of HBM reserved for weights + KV cache. Lower if the host shares the TPU.
106+
- `--max-running-requests 512`: Upper bound on concurrent decoding requests.
107+
- `--disable-radix-cache`: Required for now — radix cache is not yet supported.
108+
- `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache`: Persists XLA / Pallas compilation cache so subsequent restarts skip the ~4-minute precompile step.
109+
110+
## Running on GKE (Indexed Job)
111+
112+
For Kubernetes / GKE deployments, the four nodes are launched as an Indexed Job + headless Service so that pod `${index}` resolves to a stable DNS name. Below is a minimal manifest for a TPU v7x `2x2x4` slice. Adjust `nodeSelector`, `claimName`, and `image` for your cluster.
113+
114+
```yaml
115+
---
116+
apiVersion: v1
117+
kind: Service
118+
metadata:
119+
name: mimo-v2-5-pro-headless-svc
120+
spec:
121+
clusterIP: None
122+
selector:
123+
job-name: mimo-v2-5-pro
124+
ports:
125+
- name: dist-init
126+
port: 5000
127+
- name: tpu-process
128+
port: 8471
129+
---
130+
apiVersion: batch/v1
131+
kind: Job
132+
metadata:
133+
name: mimo-v2-5-pro
134+
spec:
135+
backoffLimit: 0
136+
completionMode: Indexed
137+
parallelism: 4
138+
completions: 4
139+
template:
140+
metadata:
141+
annotations:
142+
gke-gcsfuse/volumes: "true"
143+
spec:
144+
subdomain: mimo-v2-5-pro-headless-svc
145+
restartPolicy: Never
146+
serviceAccountName: gcs-account
147+
nodeSelector:
148+
cloud.google.com/gke-tpu-accelerator: tpu7x
149+
cloud.google.com/gke-tpu-topology: 2x2x4
150+
containers:
151+
- name: mimo-v2-5-pro
152+
image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.1-rev1
153+
command: ["bash", "-lc"]
154+
args:
155+
- |
156+
set -euxo pipefail
157+
158+
# --- 1. Clone & install ---
159+
REPO_DIR=/tmp/sglang-jax
160+
if [ ! -d "$REPO_DIR/.git" ]; then
161+
git clone https://github.com/sgl-project/sglang-jax.git "$REPO_DIR"
162+
fi
163+
cd "$REPO_DIR" && git fetch origin && pip install -e "python[tpu]"
164+
165+
# --- 2. Launch the server ---
166+
export NODE_RANK=${JOB_COMPLETION_INDEX}
167+
export MASTER_ADDR=mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc:5000
168+
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \
169+
--model-path /models/MiMo-V2.5-Pro \
170+
--trust-remote-code \
171+
--tp-size 32 --ep-size 32 \
172+
--moe-backend fused \
173+
--host 0.0.0.0 --port 30271 \
174+
--disable-radix-cache \
175+
--page-size 256 --context-length 262144 \
176+
--chunked-prefill-size 4096 \
177+
--dtype bfloat16 --mem-fraction-static 0.95 \
178+
--swa-full-tokens-ratio 0.25 \
179+
--log-level info --max-running-requests 512 \
180+
--nnodes 4 --node-rank ${NODE_RANK} \
181+
--dist-init-addr ${MASTER_ADDR}
182+
env:
183+
- name: TPU_PROCESS_ADDRESSES
184+
value: mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-1.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-2.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-3.mimo-v2-5-pro-headless-svc:8471
185+
- name: TPU_WORKER_HOSTNAMES
186+
value: mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-1.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-2.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-3.mimo-v2-5-pro-headless-svc
187+
- name: TPU_PROCESS_PORT
188+
value: "8471"
189+
- name: JOB_COMPLETION_INDEX
190+
valueFrom:
191+
fieldRef:
192+
fieldPath: metadata.labels['batch.kubernetes.io/job-completion-index']
193+
- name: TPU_WORKER_ID
194+
valueFrom:
195+
fieldRef:
196+
fieldPath: metadata.labels['batch.kubernetes.io/job-completion-index']
197+
ports:
198+
- containerPort: 30271
199+
name: http
200+
- containerPort: 5000
201+
name: dist-init
202+
resources:
203+
requests:
204+
google.com/tpu: "4"
205+
limits:
206+
google.com/tpu: "4"
207+
volumeMounts:
208+
- mountPath: /models
209+
name: model-storage
210+
- mountPath: /dev/shm
211+
name: dev-shm
212+
volumes:
213+
- name: dev-shm
214+
emptyDir:
215+
medium: Memory
216+
- name: gke-gcsfuse-cache
217+
emptyDir:
218+
medium: Memory
219+
- name: model-storage
220+
persistentVolumeClaim:
221+
claimName: <your-model-pvc>
222+
```
223+
224+
Apply with:
225+
226+
```bash
227+
kubectl apply -f mimo-v2-5-pro.yaml
228+
kubectl wait --for=condition=Ready pod -l job-name=mimo-v2-5-pro --timeout=600s
229+
```
230+
231+
The server is ready once `mimo-v2-5-pro-0` logs `Uvicorn running on http://0.0.0.0:30271`.
232+
233+
## Sending a Request
234+
235+
```bash
236+
curl -X POST http://<rank0-ip>:30271/v1/chat/completions \
237+
-H "Content-Type: application/json" \
238+
-d '{
239+
"model": "XiaomiMiMo/MiMo-V2.5-Pro",
240+
"messages": [{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
241+
"temperature": 1, "top_p": 0.95, "max_tokens": 4096
242+
}'
243+
```
244+
245+
## Accuracy Evaluation
246+
247+
### AIME 2025 (with thinking enabled)
248+
249+
```bash
250+
evalscope eval \
251+
--model XiaomiMiMo/MiMo-V2.5-Pro \
252+
--api-url http://127.0.0.1:30271/v1/chat/completions \
253+
--api-key EMPTY \
254+
--eval-type service \
255+
--datasets aime25 \
256+
--eval-batch-size 16 \
257+
--timeout 6000000 \
258+
--generation-config '{"temperature":1,"top_p":0.95,"max_tokens":131072,"chat_template_kwargs":{"enable_thinking":true}}'
259+
```
260+
261+
Reference numbers measured on TPU v7x-16 (`2x2x4`, `tp=32 ep=32`):
262+
263+
| Model | Dataset | Metric | Subset | Num | Score |
264+
|:--------------|:--------|:--------------|:------------|:----|:--------|
265+
| MiMo-V2.5-Pro | aime25 | AveragePass@1 | AIME2025-I | 15 | 0.8667 |
266+
| MiMo-V2.5-Pro | aime25 | AveragePass@1 | AIME2025-II | 15 | 1.0000 |
267+
| MiMo-V2.5-Pro | aime25 | AveragePass@1 | OVERALL | 30 | 0.9334 |
268+
269+
## Known Issues
270+
271+
- Radix cache cannot be enabled yet — always pass `--disable-radix-cache`.
272+
273+
## Additional Resources
274+
275+
- [MiMo-V2.5-Pro Model Card](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro)

0 commit comments

Comments
 (0)