|
| 1 | +# MiMo-V2.5-Pro on SGL-JAX |
| 2 | + |
| 3 | +MiMo-V2.5-Pro is Xiaomi's large-scale MoE model with hybrid attention (full attention + sliding window attention) and FP8-quantized weights, optimized for long-context reasoning. SGL-JAX supports it on TPU v6e and v7x with FP8 dequantization, tensor parallelism, and expert parallelism. |
| 4 | + |
| 5 | +This cookbook walks through setting up a multi-node TPU slice and serving MiMo-V2.5-Pro end-to-end. The primary reference configuration is **TPU v7x-16** (`2x2x4`, 4 nodes); a TPU v6e-64 (`4x4x4`, 16 nodes) launch command is also provided. |
| 6 | + |
| 7 | +## Hardware Requirements |
| 8 | + |
| 9 | +| TPU Type | Topology | Chips per node | Nodes | Total chips | |
| 10 | +|----------|----------|----------------|-------|-------------| |
| 11 | +| v7x | 2x2x4 | 4 | 4 | 16 | |
| 12 | +| v6e | 4x4x4 | 4 | 16 | 64 | |
| 13 | + |
| 14 | +All nodes must be in the same TPU slice and able to reach each other on the JAX init port (`5000` in the examples below) and the TPU process port (`8471`). |
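One way to check the connectivity requirement above is to probe both ports from each node. The following is a minimal sketch, not part of SGL-JAX: the `probe` and `probe_slice` helpers and the hostnames in the usage comment are hypothetical. A port only reports `open` while a process is listening on it, so before the server is launched a `refused` result usually means the host is reachable but nothing is bound to the port yet.

```python
import socket


def probe(host, port, timeout=3.0):
    """Classify a TCP connection attempt to host:port.

    Returns "open" if the connection succeeds, "refused" if the host
    answered but nothing is listening, and "unreachable" for DNS
    failures, timeouts, or other socket errors.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return "open"
    except ConnectionRefusedError:
        return "refused"
    except OSError:
        return "unreachable"


def probe_slice(hosts, ports=(5000, 8471)):
    """Probe every (host, port) pair; 5000 and 8471 are the ports used in this cookbook."""
    return {(host, port): probe(host, port) for host in hosts for port in ports}


# Example usage (hypothetical hostnames; replace with your own nodes):
# for (host, port), status in probe_slice(["node0.cluster.local", "node1.cluster.local"]).items():
#     print(f"{host}:{port} -> {status}")
```

Running the probe again after rank 0 has started should show port `5000` as `open` from the other nodes.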
| 15 | + |
| 16 | +## Environment Setup |
| 17 | + |
| 18 | +### 1. Start a JAX TPU container on each node |
| 19 | + |
| 20 | +Use the official JAX 0.8.1 TPU image: |
| 21 | + |
| 22 | +```bash |
| 23 | +docker run -it --privileged \ |
| 24 | + --shm-size=32g \ |
| 25 | + --ipc=host \ |
| 26 | + --network=host \ |
| 27 | + -v /dev:/dev \ |
| 28 | + us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.1-rev1 bash |
| 29 | +``` |
| 30 | + |
| 31 | +> The same image is used by SGL-JAX's GKE / SkyPilot launchers; pinning to `jax0.8.1-rev1` keeps the JAX runtime in lockstep with the SGL-JAX `[tpu]` extras. |
| 32 | +
|
| 33 | +### 2. Clone and install SGL-JAX |
| 34 | + |
| 35 | +```bash |
| 36 | +git clone https://github.com/sgl-project/sglang-jax.git |
| 37 | +cd sglang-jax |
| 38 | +git fetch origin |
| 39 | +pip install -e "python[tpu]" |
| 40 | +``` |
| 41 | + |
| 42 | +This installs `sgl-jax` together with its TPU-specific dependencies (matching `jax==0.8.1`). |
| 43 | + |
| 44 | +### 3. (Optional) Install evalscope for accuracy evaluation |
| 45 | + |
| 46 | +```bash |
| 47 | +pip install evalscope==0.17.1 |
| 48 | +``` |
| 49 | + |
| 50 | +## Launching the Server |
| 51 | + |
| 52 | +Run the **same** command on every node, varying only `${NODE_RANK}`. On every node, set `${MASTER_ADDR}` to the rank-0 host and JAX init port (e.g. `node0.cluster.local:5000`). |
| 53 | + |
| 54 | +### TPU v7x (16 chips, 4 nodes, `2x2x4`) |
| 55 | + |
| 56 | +```bash |
| 57 | +JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \ |
| 58 | + --model-path XiaomiMiMo/MiMo-V2.5-Pro \ |
| 59 | + --trust-remote-code \ |
| 60 | + --tp-size 32 --ep-size 32 \ |
| 61 | + --moe-backend fused \ |
| 62 | + --host 0.0.0.0 --port 30271 \ |
| 63 | + --disable-radix-cache \ |
| 64 | + --page-size 256 --context-length 262144 \ |
| 65 | + --chunked-prefill-size 4096 \ |
| 66 | + --dtype bfloat16 --mem-fraction-static 0.95 \ |
| 67 | + --swa-full-tokens-ratio 0.25 \ |
| 68 | + --log-level info --max-running-requests 512 \ |
| 69 | + --nnodes 4 --node-rank ${NODE_RANK} \ |
| 70 | + --dist-init-addr ${MASTER_ADDR} |
| 71 | +``` |
| 72 | + |
| 73 | +`${NODE_RANK}` ranges from `0` to `3`. |
| 74 | + |
| 75 | +### TPU v6e (64 chips, 16 nodes, `4x4x4`) |
| 76 | + |
| 77 | +```bash |
| 78 | +JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \ |
| 79 | + --model-path XiaomiMiMo/MiMo-V2.5-Pro \ |
| 80 | + --trust-remote-code \ |
| 81 | + --port 30271 \ |
| 82 | + --tp-size 64 --ep-size 64 \ |
| 83 | + --context-length 262144 --max-seq-len 4096 \ |
| 84 | + --chunked-prefill-size 4096 --max-prefill-tokens 16384 \ |
| 85 | + --page-size 256 \ |
| 86 | + --mem-fraction-static 0.92 \ |
| 87 | + --max-running-requests 512 \ |
| 88 | + --swa-full-tokens-ratio 0.15 \ |
| 89 | + --attention-backend fa --moe-backend fused \ |
| 90 | + --disable-radix-cache \ |
| 91 | + --nnodes 16 --node-rank ${NODE_RANK} \ |
| 92 | + --dist-init-addr ${MASTER_ADDR} |
| 93 | +``` |
| 94 | + |
| 95 | +`${NODE_RANK}` ranges from `0` to `15`. Compared to the v7x recipe, the v6e command lowers `--mem-fraction-static` to `0.92` and `--swa-full-tokens-ratio` to `0.15` because v6e has less HBM per chip. |
| 96 | + |
| 97 | +### Key flags |
| 98 | + |
| 99 | +- `--tp-size` / `--ep-size`: Set both to the total JAX device count across all nodes. v7x exposes 2 logical devices per chip (32 = 16 chips × 2); v6e exposes 1 (64 = 64 chips × 1). |
| 100 | +- `--moe-backend fused`: Uses the fused Pallas MoE kernel (recommended). `epmoe` is also supported but slower at this scale. |
| 101 | +- `--page-size 256`: Required page size for the SWA pool's eviction logic. Smaller page sizes are not supported with MiMo-V2.5-Pro. |
| 102 | +- `--context-length 262144`: 256K context. Match this to your workload's max prompt + output length. |
| 103 | +- `--chunked-prefill-size 4096`: Splits long prefills into 4K-token chunks to bound peak HBM during prefill. |
| 104 | +- `--swa-full-tokens-ratio`: Ratio of KV tokens allocated to sliding-window-attention layers relative to full-attention layers (a value between 0 and 1). `0.25` for v7x, `0.15` for v6e. |
| 105 | +- `--mem-fraction-static`: Fraction of HBM reserved for weights + KV cache. Lower this value if other workloads on the host share the TPU. |
| 106 | +- `--max-running-requests 512`: Upper bound on concurrent decoding requests. |
| 107 | +- `--disable-radix-cache`: Required for now — radix cache is not yet supported. |
| 108 | +- `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache`: Persists XLA / Pallas compilation cache so subsequent restarts skip the ~4-minute precompile step. |
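The device-count arithmetic behind `--tp-size` and `--ep-size` can be sketched as follows. This is an illustrative helper only: `parallel_size` and `DEVICES_PER_CHIP` are hypothetical names, and the per-chip device counts are taken from the flag notes above rather than queried from hardware.

```python
# Logical JAX devices exposed per chip, as stated in the flag notes above.
DEVICES_PER_CHIP = {"v7x": 2, "v6e": 1}


def parallel_size(tpu_type, nodes, chips_per_node=4):
    """Total JAX device count, which --tp-size and --ep-size should both equal."""
    return nodes * chips_per_node * DEVICES_PER_CHIP[tpu_type]


print(parallel_size("v7x", nodes=4))   # 32
print(parallel_size("v6e", nodes=16))  # 64
```

On a running slice, `jax.device_count()` reports the global device count and should agree with the value computed here.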
| 109 | + |
| 110 | +## Running on GKE (Indexed Job) |
| 111 | + |
| 112 | +For Kubernetes / GKE deployments, the four nodes are launched as an Indexed Job behind a headless Service, so that each pod resolves to a stable DNS name of the form `mimo-v2-5-pro-<index>.mimo-v2-5-pro-headless-svc`. Below is a minimal manifest for a TPU v7x `2x2x4` slice. Adjust `nodeSelector`, `claimName`, and `image` for your cluster. |
| 113 | + |
| 114 | +```yaml |
| 115 | +--- |
| 116 | +apiVersion: v1 |
| 117 | +kind: Service |
| 118 | +metadata: |
| 119 | + name: mimo-v2-5-pro-headless-svc |
| 120 | +spec: |
| 121 | + clusterIP: None |
| 122 | + selector: |
| 123 | + job-name: mimo-v2-5-pro |
| 124 | + ports: |
| 125 | + - name: dist-init |
| 126 | + port: 5000 |
| 127 | + - name: tpu-process |
| 128 | + port: 8471 |
| 129 | +--- |
| 130 | +apiVersion: batch/v1 |
| 131 | +kind: Job |
| 132 | +metadata: |
| 133 | + name: mimo-v2-5-pro |
| 134 | +spec: |
| 135 | + backoffLimit: 0 |
| 136 | + completionMode: Indexed |
| 137 | + parallelism: 4 |
| 138 | + completions: 4 |
| 139 | + template: |
| 140 | + metadata: |
| 141 | + annotations: |
| 142 | + gke-gcsfuse/volumes: "true" |
| 143 | + spec: |
| 144 | + subdomain: mimo-v2-5-pro-headless-svc |
| 145 | + restartPolicy: Never |
| 146 | + serviceAccountName: gcs-account |
| 147 | + nodeSelector: |
| 148 | + cloud.google.com/gke-tpu-accelerator: tpu7x |
| 149 | + cloud.google.com/gke-tpu-topology: 2x2x4 |
| 150 | + containers: |
| 151 | + - name: mimo-v2-5-pro |
| 152 | + image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.1-rev1 |
| 153 | + command: ["bash", "-lc"] |
| 154 | + args: |
| 155 | + - | |
| 156 | + set -euxo pipefail |
| 157 | +
|
| 158 | + # --- 1. Clone & install --- |
| 159 | + REPO_DIR=/tmp/sglang-jax |
| 160 | + if [ ! -d "$REPO_DIR/.git" ]; then |
| 161 | + git clone https://github.com/sgl-project/sglang-jax.git "$REPO_DIR" |
| 162 | + fi |
| 163 | + cd "$REPO_DIR" && git fetch origin && pip install -e "python[tpu]" |
| 164 | +
|
| 165 | + # --- 2. Launch the server --- |
| 166 | + export NODE_RANK=${JOB_COMPLETION_INDEX} |
| 167 | + export MASTER_ADDR=mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc:5000 |
| 168 | + JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python -m sgl_jax.launch_server \ |
| 169 | + --model-path /models/MiMo-V2.5-Pro \ |
| 170 | + --trust-remote-code \ |
| 171 | + --tp-size 32 --ep-size 32 \ |
| 172 | + --moe-backend fused \ |
| 173 | + --host 0.0.0.0 --port 30271 \ |
| 174 | + --disable-radix-cache \ |
| 175 | + --page-size 256 --context-length 262144 \ |
| 176 | + --chunked-prefill-size 4096 \ |
| 177 | + --dtype bfloat16 --mem-fraction-static 0.95 \ |
| 178 | + --swa-full-tokens-ratio 0.25 \ |
| 179 | + --log-level info --max-running-requests 512 \ |
| 180 | + --nnodes 4 --node-rank ${NODE_RANK} \ |
| 181 | + --dist-init-addr ${MASTER_ADDR} |
| 182 | + env: |
| 183 | + - name: TPU_PROCESS_ADDRESSES |
| 184 | + value: mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-1.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-2.mimo-v2-5-pro-headless-svc:8471,mimo-v2-5-pro-3.mimo-v2-5-pro-headless-svc:8471 |
| 185 | + - name: TPU_WORKER_HOSTNAMES |
| 186 | + value: mimo-v2-5-pro-0.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-1.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-2.mimo-v2-5-pro-headless-svc,mimo-v2-5-pro-3.mimo-v2-5-pro-headless-svc |
| 187 | + - name: TPU_PROCESS_PORT |
| 188 | + value: "8471" |
| 189 | + - name: JOB_COMPLETION_INDEX |
| 190 | + valueFrom: |
| 191 | + fieldRef: |
| 192 | + fieldPath: metadata.labels['batch.kubernetes.io/job-completion-index'] |
| 193 | + - name: TPU_WORKER_ID |
| 194 | + valueFrom: |
| 195 | + fieldRef: |
| 196 | + fieldPath: metadata.labels['batch.kubernetes.io/job-completion-index'] |
| 197 | + ports: |
| 198 | + - containerPort: 30271 |
| 199 | + name: http |
| 200 | + - containerPort: 5000 |
| 201 | + name: dist-init |
| 202 | + resources: |
| 203 | + requests: |
| 204 | + google.com/tpu: "4" |
| 205 | + limits: |
| 206 | + google.com/tpu: "4" |
| 207 | + volumeMounts: |
| 208 | + - mountPath: /models |
| 209 | + name: model-storage |
| 210 | + - mountPath: /dev/shm |
| 211 | + name: dev-shm |
| 212 | + volumes: |
| 213 | + - name: dev-shm |
| 214 | + emptyDir: |
| 215 | + medium: Memory |
| 216 | + - name: gke-gcsfuse-cache |
| 217 | + emptyDir: |
| 218 | + medium: Memory |
| 219 | + - name: model-storage |
| 220 | + persistentVolumeClaim: |
| 221 | + claimName: <your-model-pvc> |
| 222 | +``` |
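The long `TPU_WORKER_HOSTNAMES` and `TPU_PROCESS_ADDRESSES` values in the manifest follow a regular pattern, so they can be generated rather than typed by hand when the node count changes. The sketch below assumes the `<job>-<index>.<service>` naming used in the manifest; `worker_hostnames` and `tpu_env` are hypothetical helper names, not part of SGL-JAX or GKE tooling.

```python
def worker_hostnames(job, service, nodes):
    """Pod DNS names for an Indexed Job behind a headless Service."""
    return [f"{job}-{index}.{service}" for index in range(nodes)]


def tpu_env(job, service, nodes, port=8471):
    """Build the TPU_* environment values used in the manifest above."""
    hosts = worker_hostnames(job, service, nodes)
    return {
        "TPU_WORKER_HOSTNAMES": ",".join(hosts),
        "TPU_PROCESS_ADDRESSES": ",".join(f"{host}:{port}" for host in hosts),
        "TPU_PROCESS_PORT": str(port),
    }


env = tpu_env("mimo-v2-5-pro", "mimo-v2-5-pro-headless-svc", nodes=4)
for name, value in env.items():
    print(f"{name}={value}")
```

For a different slice size, change `nodes` here and keep `parallelism`, `completions`, and `--nnodes` in the manifest in sync with it.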
| 223 | +
|
| 224 | +Apply with: |
| 225 | +
|
| 226 | +```bash |
| 227 | +kubectl apply -f mimo-v2-5-pro.yaml |
| 228 | +kubectl wait --for=condition=Ready pod -l job-name=mimo-v2-5-pro --timeout=600s |
| 229 | +``` |
| 230 | + |
| 231 | +The server is ready once `mimo-v2-5-pro-0` logs `Uvicorn running on http://0.0.0.0:30271`. |
| 232 | + |
| 233 | +## Sending a Request |
| 234 | + |
| 235 | +```bash |
| 236 | +curl -X POST http://<rank0-ip>:30271/v1/chat/completions \ |
| 237 | + -H "Content-Type: application/json" \ |
| 238 | + -d '{ |
| 239 | + "model": "XiaomiMiMo/MiMo-V2.5-Pro", |
| 240 | + "messages": [{"role": "user", "content": "Prove that sqrt(2) is irrational."}], |
| 241 | + "temperature": 1, "top_p": 0.95, "max_tokens": 4096 |
| 242 | + }' |
| 243 | +``` |
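The same request can be sent from Python using only the standard library. This is a minimal sketch: `build_chat_request` and `chat` are hypothetical helper names, and the response parsing assumes the usual OpenAI-style shape (`choices[0].message.content`), which the source does not spell out.

```python
import json
import urllib.request


def build_chat_request(base_url, prompt, max_tokens=4096):
    """Build a POST request mirroring the curl example above."""
    payload = {
        "model": "XiaomiMiMo/MiMo-V2.5-Pro",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 1,
        "top_p": 0.95,
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(base_url, prompt, timeout=600.0):
    """Send the request and return the assistant's reply text."""
    request = build_chat_request(base_url, prompt)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    # Assumption: OpenAI-style response layout.
    return body["choices"][0]["message"]["content"]


# Example usage (replace <rank0-ip> with the address of the rank-0 node):
# print(chat("http://<rank0-ip>:30271", "Prove that sqrt(2) is irrational."))
```

The generous default timeout reflects that long reasoning outputs can take a while to generate; adjust it to your workload.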
| 244 | + |
| 245 | +## Accuracy Evaluation |
| 246 | + |
| 247 | +### AIME 2025 (with thinking enabled) |
| 248 | + |
| 249 | +```bash |
| 250 | +evalscope eval \ |
| 251 | + --model XiaomiMiMo/MiMo-V2.5-Pro \ |
| 252 | + --api-url http://127.0.0.1:30271/v1/chat/completions \ |
| 253 | + --api-key EMPTY \ |
| 254 | + --eval-type service \ |
| 255 | + --datasets aime25 \ |
| 256 | + --eval-batch-size 16 \ |
| 257 | + --timeout 6000000 \ |
| 258 | + --generation-config '{"temperature":1,"top_p":0.95,"max_tokens":131072,"chat_template_kwargs":{"enable_thinking":true}}' |
| 259 | +``` |
| 260 | + |
| 261 | +Reference numbers measured on TPU v7x-16 (`2x2x4`, `tp=32 ep=32`): |
| 262 | + |
| 263 | +| Model | Dataset | Metric | Subset | Num | Score | |
| 264 | +|:--------------|:--------|:--------------|:------------|:----|:--------| |
| 265 | +| MiMo-V2.5-Pro | aime25 | AveragePass@1 | AIME2025-I | 15 | 0.8667 | |
| 266 | +| MiMo-V2.5-Pro | aime25 | AveragePass@1 | AIME2025-II | 15 | 1.0000 | |
| 267 | +| MiMo-V2.5-Pro | aime25 | AveragePass@1 | OVERALL | 30 | 0.9334 | |
| 268 | + |
| 269 | +## Known Issues |
| 270 | + |
| 271 | +- Radix cache cannot be enabled yet — always pass `--disable-radix-cache`. |
| 272 | + |
| 273 | +## Additional Resources |
| 274 | + |
| 275 | +- [MiMo-V2.5-Pro Model Card](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) |