Move FP8 MoE weight requantization from CPU to TPU #1842

rohan-reddy wants to merge 3 commits into vllm-project:main from
Conversation
e284ace to b8e7769
awesome! thanks for working on this. for context, one of the reasons we switched to CPU was that we sometimes saw a sharding / layout issue that caused a crash or OOM in certain edge cases. can you try testing this code in various configs (different models, different sharding configs: tp, ep, dp) and verify it's working correctly for all use cases? @lk-chen, can you also review this pr?
Thanks Rohan, the weight loading was intentionally put on CPU to avoid OOM. Could you check if #1848 can be finished before changing the weight loading logic here? Just FYI, if we do observe unnecessary device usage, https://docs.jax.dev/en/latest/notebooks/host-offloading.html might be helpful
@kyuyeunk Got it, are there any particular models I should make sure to check? If not, I will just try a few different ones with different sharding configs.
using v6e-8 should be okay! i think any model config that makes the last dim of a sharded weight become non-divisible by 128 should be a good test. so something like |
@lk-chen I updated the description with data recorded by the weight memory monitor utility; it shows some spikes, but a max peak delta of only 1.2 GB. For a large model, I assume the user would be running on a larger hardware config too, so the per-device transient shouldn't blow up too much, as it would be split among more devices.

In the previous implementation of the TPU requantization, if I read correctly, we were pushing the whole weight to one device and then redistributing, which OOM'd? Does the CI include any tests that can inform us about the result for a large model? I don't think I have the requisite hardware to try myself.

@kyuyeunk Also updated the description with more sharding config tests, including the one you suggested
No, the tests in CI are intended to be small and efficient. Most OOM issues I encountered were with DeepSeek-R1, which is 685B and barely fits on Ironwood (192*4=768GiB). And any unnecessary HBM usage will negatively impact concurrency, which impacts overall throughput.
I'll do some testing of my own by patching this commit and see if it works correctly. at the very least, we can introduce an option that's disabled by default, so users can enable it only when the model is small and they know it won't cause oom.
@kyuyeunk I also have a prototype working now (not yet in this PR) of doing expert-batched requant on TPU. This way the XLA op won't reserve the whole FP32 space for the shard at once. Example: if there are 16 experts on a shard, this PR's implementation would reserve the FP32 space for all experts, but with batch size = 1, we would need only 1/16 of the extra space at any one time.

This would use much less HBM, but the problem is it adds some complexity, as the user should probably configure it based on the model & hardware combo they're running. Large batch sizes speed up the quantization but could still bloat HBM for models with larger experts. Unless we hardcode a batch size of 1: then I think the runtime won't be much better for small models but could be better for large models, and safety is pretty good. Is that a fruitful direction to explore?
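A minimal sketch of the expert-batched idea above, scanning one expert per step so only one expert's FP32 intermediate is live at a time. The dequant/requant math, names, and FP8 scale convention are illustrative assumptions, not this PR's actual code:

```python
# Sketch of expert-batched requant via lax.scan with batch size 1.
# Illustrative only; math and names are assumptions.
import jax
import jax.numpy as jnp

def _requant_one_expert(carry, expert):
    w_fp8, scale_in = expert
    # Dequantize a single expert to FP32: only one expert's worth of
    # FP32 space is live at any scan step.
    w_fp32 = w_fp8.astype(jnp.float32) * scale_in
    # Requantize: recompute a per-expert scale and cast back to FP8.
    new_scale = jnp.max(jnp.abs(w_fp32)) / 448.0  # FP8 E4M3 max finite
    w_requant = (w_fp32 / new_scale).astype(jnp.float8_e4m3fn)
    return carry, (w_requant, new_scale)

@jax.jit
def requant_all_experts(w_fp8, scales):
    # Scan over the leading (expert) axis, one expert per step.
    _, out = jax.lax.scan(_requant_one_expert, None, (w_fp8, scales))
    return out
```

With a larger batch size, the same structure applies after reshaping the expert axis into `[num_batches, batch, ...]`.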
@kyuyeunk not only OOM. If the peak HBM usage is not released after weight loading, less HBM can be allocated for the KV cache, which impacts concurrency/throughput.
@lk-chen I included the logs ("Memory monitor logs - this PR"); they also show before and after, and the after is always the same as or lower than the before. So the peak usage is released by the end of weight loading
Tested with Qwen Coder 480 FP8 (482GB) on single-host v7x (768GB). Confirmed that there is no OOM and saw a model loading time speedup. However, this is still a relatively small model compared to v7x's HBM capacity. @lk-chen, maybe you can try it on DeepSeek and see how it goes?
fair point. we can easily check this by checking this log during server launch time: if
also, does qwen3 235b fp8 fit in v6e-8? if so, let's use that model instead because qwen3 30b is too small to properly stress test the hbm usage. |
(will need to review in more detail, but) i think this pr is good. the only thing i want an answer to is whether this change is safe to apply in all scenarios, or whether we need to make it an opt-in that is disabled by default.
Deepseek-R1 failed to load on Ironwood CMD |
@rohan-reddy maybe this idea can help reduce the peak hbm usage and thus avoid oom? |
I revised the approach to do expert-batched requant using lax.scan. The batch is sized so that each iteration uses the equivalent memory of about 1/2 of a full expert.
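For reference, a sketch of one way to turn such a memory budget into a concrete batch size: pick a target in full-expert-equivalents, then take the largest divisor of the expert count that fits, so every scan step processes an equal batch. The helper name and exact formula are assumptions for illustration:

```python
# Hypothetical batch-size picker for memory-budgeted expert batching.
# requant_memory_budget is in full-expert-equivalents per device.
def compute_batch_size(num_experts: int,
                       requant_memory_budget: float,
                       tp_size: int) -> int:
    target = max(1, int(requant_memory_budget * tp_size))
    # Largest divisor of num_experts that is <= target, so the expert
    # axis splits into equal scan batches.
    for batch in range(min(target, num_experts), 0, -1):
        if num_experts % batch == 0:
            return batch
    return 1
```

With 128 experts, a budget of 0.5, and TP=8, this yields batches of 4 experts per scan step.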
Verifying commit 1583e31, loading deepseek-r1 on Ironwood |
Thanks @lk-chen for running this! |
The DeepSeek OOM ("Attempting to reserve 3.06G, 2.81G free") at layer 143/163 is consistent with this: DeepSeek has ~4x larger expert dimensions than Qwen 235B, and 3060 MiB ≈ 4 × 739 MiB.

Qwen3-235B-A22B-FP8, TP=8, v6e-8 (32 GiB HBM/chip)

Qwen3-30B-A3B-FP8, TP=8, v6e-8
Tried a few other approaches. The winner right now appears to be wrapping the FP8 MoE requantization + weight processing in

Benchmark Results (Qwen3-235B-A22B, TP=8, 128 experts):

Wall clock comparison of the requant phase against main (disk I/O ~43s and engine init/warmup ~52s are identical):

Batch size reduction also had a bigger impact on memory use here, possibly because intermediates are a larger fraction of the smaller program. Used

I don't know how much further this allocation can be reduced. I tried some other approaches, including breaking the computation into smaller JIT calls wrapped in a Python loop so XLA materializes a smaller tensor, but those were ineffective (details below if interested).

Other approaches

- Fused scan (requant + process per expert in one scan step): fusing requant and
- Chunked multi-JIT with per-device assembly: breaking the computation into multiple smaller JIT calls (e.g., 4 chunks of 4 experts), then assembling results outside XLA using
- Python loop with per-expert JIT calls: 16 separate JIT calls per device, each processing one expert (
Signed-off-by: Rohan Reddy <rreddy.nyc@gmail.com>
6ccd197 to 41ab4c6
random idea: it may be possible to leverage
Testing on 41ab4c6
Summary
- `process_fp8_moe_weights` (which is `@jax.jit`) runs on TPU instead of CPU
- `jax.experimental.shard_map` + `jax.lax.scan` to process experts with per-device local shapes, minimizing XLA program reservation
- `shard_moe_weights` becomes a no-op post-requant

Approach
Current: CPU load FP8, CPU dequant FP32, CPU requant, transfer to TPU
New: CPU load FP8, shard FP8 across TPUs, `shard_map` + `lax.scan` batched dequant/requant on TPU (1 expert/step), sharding constraints applied inside JIT

`shard_map` wraps the requant + `process_moe_weights` so XLA compiles with per-device local shapes (e.g., `[16, ...]` instead of `[128, ...]`), significantly reducing XLA program reservation (`bytes_reserved`). Inside `shard_map`, `lax.scan` processes one expert per step to minimize per-step FP32 intermediates. Sharding constraints inside the JIT make `shard_moe_weights` a no-op.

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8)
Wall clock (model load to engine ready)
Per-layer requant time
XLA program reservation (`bytes_reserved`)

- `shard_map` + `lax.scan`
- `lax.scan`, batch experts
- `lax.scan`, all experts at once

Files changed
- `tpu_inference/layers/common/process_weights/moe_weights.py`
  - `shard_fp8_moe_weights_to_tpu()`: shards FP8 weights onto TPU with expert-dimension sharding before requant
  - `process_fp8_moe_weights()`: uses `shard_map` + `lax.scan` to process experts with per-device local shapes, applies `with_sharding_constraint` so the output matches the target sharding
  - `_get_moe_weight_shardings()`: shared by both `shard_moe_weights` and `process_fp8_moe_weights` to avoid duplicating sharding specs
- `tpu_inference/layers/vllm/quantization/fp8.py`: call `shard_fp8_moe_weights_to_tpu` before `process_fp8_moe_weights` in `VllmFp8MoEMethod.process_weights_after_loading`
- `tpu_inference/layers/jax/quantization/fp8.py`: same change for the JAX code path. `cpu_mesh_context()` wraps only concatenation; shard + requant runs on TPU

Test plan
- `pytest tests/layers/vllm/test_fp8.py::test_fused_moe`: 48 passed, 16 skipped
- `pytest tests/layers/jax/quantization/test_fp8.py::TestFp8FusedMoE`: 24 passed, 8 skipped
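The `shard_map` + `lax.scan` structure described in the Approach above can be sketched roughly as follows. The mesh axis name, requant math, and function names are illustrative rather than the PR's actual code, and the snippet uses however many devices are visible:

```python
# Sketch: shard_map gives the scan body per-device local shapes, and
# lax.scan requantizes one expert per step inside each shard.
# Illustrative only; not the PR's actual code.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

if hasattr(jax, "shard_map"):            # newer JAX
    shard_map = jax.shard_map
else:                                    # older JAX
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("expert",))

def _requant_local(w_fp8, scales):
    # Inside shard_map, arrays have per-device local shapes, e.g.
    # [16, ...] instead of [128, ...] on an 8-way mesh.
    def step(carry, expert):
        w, s = expert
        w32 = w.astype(jnp.float32) * s
        new_scale = jnp.max(jnp.abs(w32)) / 448.0  # FP8 E4M3 max finite
        return carry, ((w32 / new_scale).astype(jnp.float8_e4m3fn),
                       new_scale)
    _, out = jax.lax.scan(step, None, (w_fp8, scales))
    return out

@jax.jit
def requant(w_fp8, scales):
    return shard_map(_requant_local, mesh=mesh,
                     in_specs=(P("expert"), P("expert")),
                     out_specs=(P("expert"), P("expert")))(w_fp8, scales)
```

Because the scan body only ever sees the local expert count, the per-step FP32 intermediate stays one expert per device regardless of the global expert count.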
Previous PR description (lax.scan approach, before shard_map)
Summary
- `process_fp8_moe_weights` (which is `@jax.jit`) runs on TPU instead of CPU
- `jax.lax.scan` with memory-budget-based batch sizing to process experts in small batches, reducing peak HBM and XLA compilation overhead
- `shard_moe_weights` becomes a no-op post-requant

Approach
Current: CPU load FP8, CPU dequant FP32, CPU requant, transfer to TPU
New: CPU load FP8, shard FP8 across TPUs, `lax.scan` batched dequant/requant on TPU (4 experts/step), sharding constraints applied inside JIT

`lax.scan` processes a small batch of experts per step, so XLA only needs to reserve FP32 intermediates for that batch rather than all experts at once. The batch size is computed from a memory budget (`0.5 * tp_size` full-expert-equivalents), ensuring consistent memory usage across TP/EP configs. XLA compiles the scan body once and reuses it across iterations. Sharding constraints inside the JIT make `shard_moe_weights` a no-op.

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8)
94 MoE layers, 128 experts each.
- `main` (CPU requant)

Memory budget approach
`target = max(1, int(requant_memory_budget * tp_size))`, then find the largest divisor of `num_experts` that is `<= target`

Previous PR description (all-at-once approach, before lax.scan)
Summary
- `process_fp8_moe_weights` (which is `@jax.jit`) runs on TPU instead of CPU

Approach
Current: CPU load FP8, CPU dequant FP32, CPU requant FP8, transfer to TPU
New: CPU load FP8, shard FP8 across TPUs, TPU dequant FP32, TPU requant FP8
Note that `process_fp8_moe_weights` is already `@jax.jit`. By placing FP8 inputs on TPU with expert-dimension sharding before calling it, the dequant/requant dispatches to TPU automatically with SPMD parallelism. No changes are needed to the requantization logic itself.

By sharding FP8 weights across TPUs before requantization, no single device holds the full unsharded weight during the FP8-to-FP32-to-FP8 dequant/requant step. This is critical for large MoE models like DeepSeek V3, where the transient FP32 intermediate for a single layer's experts can exceed a single TPU's HBM.
Files changed
- `tpu_inference/layers/common/process_weights/moe_weights.py`: added `shard_fp8_moe_weights_to_tpu()`, which shards all `FusedMoEWeights` fields onto TPU using `general_device_put` with an expert-dimension `NamedSharding`. Falls back to the first mesh axis for meshes without an EXPERT axis.
- `tpu_inference/layers/vllm/quantization/fp8.py`: call `shard_fp8_moe_weights_to_tpu` before `process_fp8_moe_weights` in `VllmFp8MoEMethod.process_weights_after_loading`.
- `tpu_inference/layers/jax/quantization/fp8.py`: restructured `Fp8FusedMoEMethod.process_weights_after_loading` so `cpu_mesh_context()` wraps only the concatenation step. `shard_fp8_moe_weights_to_tpu` + `process_fp8_moe_weights` run outside it on TPU.

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8, EP)
- `process_fp8_moe_weights` per layer
- `shard_moe_weights` per layer

Memory usage during weight loading (v6e-8, Qwen3-30B-A3B-FP8)
Peak HBM delta measured with `_MemoryPoller` (from #1854) during `load_weights`:
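`_MemoryPoller` itself is introduced in #1854; as a rough sketch (interface assumed, not the actual implementation), such a poller samples a byte-count callable on a background thread, e.g. `lambda: jax.local_devices()[0].memory_stats()["bytes_in_use"]` on TPU:

```python
# Hypothetical peak-memory poller: records a baseline, samples a byte
# count on a background thread, and reports the peak delta.
import threading
import time

class MemoryPoller:
    def __init__(self, read_bytes, interval_s=0.01):
        self._read = read_bytes
        self._interval = interval_s
        self.baseline = 0
        self.peak = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        # Sample until stopped, tracking the max observed byte count.
        while not self._stop.is_set():
            self.peak = max(self.peak, self._read())
            time.sleep(self._interval)

    def __enter__(self):
        self.baseline = self.peak = self._read()
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        self._thread.join()
        # Take one final sample so short spikes at exit are not missed.
        self.peak = max(self.peak, self._read())
        return False

    @property
    def peak_delta(self):
        return self.peak - self.baseline
```

Usage would wrap the weight-loading call in a `with MemoryPoller(...) as p:` block and report `p.peak_delta` afterwards. Polling can miss very short-lived spikes between samples, so the reported delta is a lower bound.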