
Move FP8 MoE weight requantization from CPU to TPU #1842

Open
rohan-reddy wants to merge 3 commits into vllm-project:main from rohan-reddy:model-loading

Conversation

@rohan-reddy
Contributor

@rohan-reddy rohan-reddy commented Mar 3, 2026

Summary

  • Shard FP8 MoE weights onto TPU before requantization so process_fp8_moe_weights (which is @jax.jit) runs on TPU instead of CPU
  • Use jax.experimental.shard_map + jax.lax.scan to process experts with per-device local shapes, minimizing XLA program reservation
  • Apply sharding constraints inside JIT so shard_moe_weights becomes a no-op post-requant
  • Applies to both vLLM and JAX code paths

Approach

Current: CPU load FP8, CPU dequant FP32, CPU requant, transfer to TPU

New: CPU load FP8, shard FP8 across TPUs, shard_map + lax.scan batched dequant/requant on TPU (1 expert/step), sharding constraints applied inside JIT

shard_map wraps the requant + process_moe_weights so XLA compiles with per-device local shapes (e.g., [16, ...] instead of [128, ...]), significantly reducing XLA program reservation (bytes_reserved). Inside shard_map, lax.scan processes one expert per step to minimize per-step FP32 intermediates. Sharding constraints inside the JIT make shard_moe_weights a no-op.
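The pattern described above can be sketched as follows. This is a toy, self-contained version (stand-in requant math, a one-device mesh so it runs anywhere, hypothetical function names), not the actual `process_fp8_moe_weights` implementation:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # shard_map moved out of jax.experimental in newer JAX releases
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# One-device mesh so the sketch runs anywhere; the PR uses the TPU mesh's
# expert axis instead.
mesh = Mesh(np.array(jax.devices()[:1]), ("expert",))

def _requant_one_expert(carry, w_fp8):
    # Dequant -> rescale -> requant for a single expert, so only one
    # expert's fp32 intermediate is live per scan step.
    w_fp32 = w_fp8.astype(jnp.float32)
    scale = jnp.max(jnp.abs(w_fp32)) / 448.0  # fp8 e4m3 max finite value
    w_out = (w_fp32 / scale).astype(jnp.float8_e4m3fn)
    return carry, (w_out, scale)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("expert"),
         out_specs=(P("expert"), P("expert")))
def requant_local(w):
    # Inside shard_map, w has the per-device local shape, e.g.
    # [local_experts, ...] rather than [global_experts, ...].
    _, outs = jax.lax.scan(_requant_one_expert, None, w)
    return outs

w = jnp.full((4, 8, 8), 0.5, dtype=jnp.float8_e4m3fn)
w_req, scales = requant_local(w)
```

Because the scan body only ever sees one expert, the transient FP32 buffer is per-expert-sized, while shard_map keeps all compiled shapes local to the device.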

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8)

python examples/offline_inference.py \
    --model Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 \
    --tensor-parallel-size 8 \
    --max-model-len 1024 --max-num-batched-tokens 128

Wall clock (model load to engine ready)

| | Main (CPU requant) | This PR (TP) | This PR (EP) |
|---|---|---|---|
| Total wall clock | ~10 min | ~6 min | ~5 min |

Per-layer requant time

| | Main (CPU requant) | This PR | Speedup |
|---|---|---|---|
| Per layer (requant only) | ~2.0s | ~0.05s | 40x |
| All 64 MoE layers (requant only) | ~128s | ~3.2s | 40x |

XLA program reservation (bytes_reserved)

| Approach | XLA reservation/chip |
|---|---|
| Main (CPU requant) | 0 MiB |
| This PR: TPU, shard_map + lax.scan | 384 MiB |
| TPU, lax.scan, batch experts | 739 MiB |
| TPU, no lax.scan, all experts at once | 1921 MiB |

Files changed

  1. tpu_inference/layers/common/process_weights/moe_weights.py

    • Added shard_fp8_moe_weights_to_tpu(): shards FP8 weights onto TPU with expert-dimension sharding before requant
    • Refactored process_fp8_moe_weights(): uses shard_map + lax.scan to process experts with per-device local shapes, applies with_sharding_constraint so output matches target sharding
    • Extracted _get_moe_weight_shardings(): shared by both shard_moe_weights and process_fp8_moe_weights to avoid duplicating sharding specs
  2. tpu_inference/layers/vllm/quantization/fp8.py: call shard_fp8_moe_weights_to_tpu before process_fp8_moe_weights in VllmFp8MoEMethod.process_weights_after_loading

  3. tpu_inference/layers/jax/quantization/fp8.py: same change for JAX code path. cpu_mesh_context() wraps only concatenation; shard + requant runs on TPU

Test plan

  • pytest tests/layers/vllm/test_fp8.py::test_fused_moe: 48 passed, 16 skipped
  • pytest tests/layers/jax/quantization/test_fp8.py::TestFp8FusedMoE: 24 passed, 8 skipped
  • Qwen3-235B-A22B-FP8, TP=8 on v6e-8
  • Qwen3-235B-A22B-FP8, TP=8, EP on v6e-8
  • Large model validation (DeepSeek-R1)

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.
Previous PR description (lax.scan approach, before shard_map)

Summary

  • Shard FP8 MoE weights onto TPU before requantization so process_fp8_moe_weights (which is @jax.jit) runs on TPU instead of CPU
  • Use jax.lax.scan with memory-budget-based batch sizing to process experts in small batches, reducing peak HBM and XLA compilation overhead
  • Apply sharding constraints inside JIT so shard_moe_weights becomes a no-op post-requant
  • Applies to both vLLM and JAX code paths

Approach

Current: CPU load FP8, CPU dequant FP32, CPU requant, transfer to TPU

New: CPU load FP8, shard FP8 across TPUs, lax.scan batched dequant/requant on TPU (4 experts/step), sharding constraints applied inside JIT

lax.scan processes a small batch of experts per step, so XLA only needs to reserve FP32 intermediates for that batch rather than all experts at once. The batch size is computed from a memory budget (0.5 * tp_size full-expert-equivalents), ensuring consistent memory usage across TP/EP configs. XLA compiles the scan body once and reuses it across iterations. Sharding constraints inside the JIT make shard_moe_weights a no-op.

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8)

94 MoE layers, 128 experts each.

| Metric | main (CPU requant) | All-at-once TPU | lax.scan TPU |
|---|---|---|---|
| Requant method | CPU-side, then transfer | All 128 experts at once on TPU | lax.scan, 4 experts/step on TPU |
| Weight loading time | 7m24s | 5m56s | 3m39s |
| Final HBM/chip | 27.87 GiB | 27.87 GiB | 27.88 GiB |
| Requant peak blowup (layer 1) | N/A | 484.25 MiB/chip | 298.85 MiB/chip |
| Requant peak blowup (steady state) | N/A | 299.71 MiB/chip | 290.19 MiB/chip |
| FP32 overhead per layer | N/A | ~9.5 MiB/chip | ~0 MiB/chip |
| shard_fp8_to_tpu (per layer) | N/A | ~1.2s | ~1.1s |
| process_fp8_moe_weights (per layer) | N/A | ~0.08s | ~0.08s |
| shard_moe_weights (per layer) | N/A | ~1.7s | 0.00s (no-op) |

Memory budget approach

  • Formula: target = max(1, int(requant_memory_budget * tp_size)), then find largest divisor of num_experts <= target
  • For TP=8: target=4, scan_batch_size=4 (128 experts, 4 divides evenly)
  • Gives consistent FP32 memory usage regardless of TP/EP config
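The selection rule above can be sketched as a standalone helper (function name and defaults are illustrative, not the PR's actual API):

```python
# Sketch of the batch-size selection described above: compute the
# memory-budget target, then pick the largest divisor of num_experts
# that does not exceed it, so the scan divides the experts evenly.
def pick_scan_batch_size(num_experts: int, tp_size: int,
                         requant_memory_budget: float = 0.5) -> int:
    target = max(1, int(requant_memory_budget * tp_size))
    return max(d for d in range(1, target + 1) if num_experts % d == 0)

print(pick_scan_batch_size(128, 8))  # -> 4, matching the TP=8 example above
```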
Previous PR description (all-at-once approach, before lax.scan)

Summary

  • Shard FP8 MoE weights onto TPU before requantization so process_fp8_moe_weights (which is @jax.jit) runs on TPU instead of CPU
  • Each TPU holds only its expert shard during dequant/requant, reducing risk of OOM on any single device
  • Applies to both vLLM and JAX code paths

Approach

Current: CPU load FP8, CPU dequant FP32, CPU requant FP8, transfer to TPU

New: CPU load FP8, shard FP8 across TPUs, TPU dequant FP32, TPU requant FP8

Note that process_fp8_moe_weights is already @jax.jit. By placing FP8 inputs on TPU with expert-dimension sharding before calling it, the dequant/requant dispatches to TPU automatically with SPMD parallelism. No changes needed to the requantization logic itself.

By sharding FP8 weights across TPUs before requantization, no single device holds the full unsharded weight during the FP8-to-FP32-to-FP8 dequant/requant step. This is critical for large MoE models like DeepSeek V3 where the transient FP32 intermediate for a single layer's experts can exceed a single TPU's HBM.
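The placement trick described above boils down to a `device_put` with an expert-dimension `NamedSharding` before calling the jitted function. A minimal sketch (axis name, shapes, and the dequant body are illustrative; the real code uses `general_device_put` and the full requant logic):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a mesh over whatever devices are present (TPUs in the PR; this
# demo also runs on a single CPU device).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("expert",))
sharding = NamedSharding(mesh, P("expert"))  # shard dim 0 (the expert dim)

@jax.jit
def dequant(w):
    # jit dispatches to wherever its inputs live -- on TPU once the FP8
    # weights have been device_put with the sharding above.
    return w.astype(jnp.float32)

n_experts = 8 * len(devices)  # divisible by the mesh for clean sharding
w_fp8 = jax.device_put(jnp.ones((n_experts, 4, 4), jnp.float8_e4m3fn), sharding)
w_fp32 = dequant(w_fp8)
```

With the input sharded on the expert dimension, each device only materializes the FP32 intermediate for its own expert shard.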

Files changed

  1. tpu_inference/layers/common/process_weights/moe_weights.py: added shard_fp8_moe_weights_to_tpu() which shards all FusedMoEWeights fields onto TPU using general_device_put with expert-dimension NamedSharding. Falls back to first mesh axis for meshes without an EXPERT axis.

  2. tpu_inference/layers/vllm/quantization/fp8.py: call shard_fp8_moe_weights_to_tpu before process_fp8_moe_weights in VllmFp8MoEMethod.process_weights_after_loading.

  3. tpu_inference/layers/jax/quantization/fp8.py: restructured Fp8FusedMoEMethod.process_weights_after_loading so cpu_mesh_context() wraps only the concatenation step. shard_fp8_moe_weights_to_tpu + process_fp8_moe_weights run outside it on TPU.

Benchmark (v6e-8, Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, TP=8, EP)

| Metric | Main (CPU) | All-at-once TPU | Speedup |
|---|---|---|---|
| process_fp8_moe_weights per layer | ~0.22s | ~0.02s | 11x |
| First layer (includes JIT compile) | ~0.48s | 0.90s | expected |
| shard_moe_weights per layer | ~0.02s | 0.00s | no-op |
| Total model load time | 43.02s | 34.02s | 21% faster |

Memory usage during weight loading (v6e-8, Qwen3-30B-A3B-FP8)

Peak HBM delta measured with _MemoryPoller (from #1854) during load_weights:

| Test case | Main (CPU requant) | All-at-once TPU |
|---|---|---|
| PP=0, PP_world=1 | 0.005 GB | 1.168 GB |
| PP=0, PP_world=4 | 1.159 GB | 1.159 GB |
| PP=1, PP_world=4 | 0.188 GB | 0.939 GB |
| PP=3, PP_world=4 | 0.001 GB | 0.001 GB |

@kyuyeunk
Collaborator

kyuyeunk commented Mar 4, 2026

awesome! thanks for working on this. for context, one of the reasons we switched to cpu was that we sometimes saw a sharding / layout issue that causes a crash or oom in certain edge cases. can you try testing this code in various different configs - different models, different sharding configs (tp, ep, dp) - and verify it's working correctly for all use cases?

@lk-chen, can you also review this pr?

@kyuyeunk kyuyeunk added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 4, 2026
@lk-chen
Collaborator

lk-chen commented Mar 4, 2026

Thanks Rohan, the weight loading was intentionally put on CPU to avoid OOM. Could you check if #1848 can be finished before changing the weight loading logic here?

Just FYI, if we do observe unnecessary device usage, https://docs.jax.dev/en/latest/notebooks/host-offloading.html might be helpful

@rohan-reddy
Contributor Author

rohan-reddy commented Mar 4, 2026

@kyuyeunk Got it, are there any particular models I should make sure to check? If not I will just try a few different ones with different sharding configs.
Also, is it okay to test just models that fit on v6e-8, or should I be testing with a larger HW config? I don't have access to v7 in my account, and per my understanding, v6e-16 and larger have smaller individual workers than v6e-8, so I would have to run a multi-host setup with Ray, which may be less well supported.

@kyuyeunk
Collaborator

kyuyeunk commented Mar 4, 2026

@kyuyeunk Got it, are there any particular models I should make sure to check? If not I will just try a few different ones with different sharding configs. Also, is it okay to test just models that can fit on v6e-8 or should I be testing with larger HW config? I don't have access to v7 in my account, and per my understanding, v6e-16 and larger have smaller individual workers than v6e-8, so I have to run with multi-host setup with Ray which may be less supported?

using v6e-8 should be okay! i think any model config that makes the last dim of a sharded weight non-divisible by 128 should be a good test. so something like Qwen/Qwen3-30B-A3B-FP8 from your example, but using tp (and not ep), will make one of the moe weight last dims come out to moe_intermediate_size/tp_size/128 = 768/8/128 = 0.75 tiles, so this might be a good candidate.
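The divisibility check above works out as follows (pure arithmetic, using the config values quoted in the comment):

```python
# With TP only (no EP), the MoE weight's last dim per device is
# moe_intermediate_size / tp_size; dividing by the 128-lane tile width
# shows whether the sharded dim is cleanly tileable.
moe_intermediate_size = 768  # Qwen/Qwen3-30B-A3B-FP8
tp_size = 8

per_device_last_dim = moe_intermediate_size // tp_size  # 96
tiles_of_128 = per_device_last_dim / 128                # 0.75 -> not divisible

print(per_device_last_dim, tiles_of_128)
```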

@rohan-reddy
Contributor Author

@lk-chen I updated the description with data recorded by the weight memory monitor utility; it shows some spikes, but a max peak delta of only ~1.2 GB.

For a large model, I assume the user would be running on a larger hardware config too, so the per-device transient shouldn't blow up too much, as it would be split among more devices. In the previous implementation of TPU requantization, if I read it correctly, we were pushing the whole weight to one device and then redistributing, which OOM'd?

Does the CI include some tests that can inform us about the result of a large model? I don't think I have requisite hardware to try myself.

@kyuyeunk Also updated the description with more sharding config tests including the one you suggested

@lk-chen
Collaborator

lk-chen commented Mar 5, 2026

Does the CI include some tests that can inform us about the result of a large model?

No, the tests in CI are intended to be small and efficient.

The main OOM issue I encountered is DeepSeek-R1, which is 685B and barely fits on Ironwood (192*4 = 768 GiB). And any unnecessary HBM usage will negatively impact concurrency, which impacts overall throughput.

@kyuyeunk
Collaborator

kyuyeunk commented Mar 5, 2026

Does the CI include some tests that can inform us about the result of a large model?

No, the tests in CI are intended to be small and efficient.

Most OOM issue I encountered is DeepSeek-R1, which is 685B and barely fit on Ironwood (192*4=768GiB). And any unnecessary HBM usage will negatively impact concurrency, which impacts overall throughput.

I'll do some testing of my own by patching this commit & see if it works correctly. at the very least, we can introduce an option that's disabled by default so users can enable it only when the model is small and they know it won't cause oom.

@rohan-reddy
Contributor Author

@kyuyeunk I also have a prototype working now (not yet in this PR) that does expert-batched requant on TPU, so the XLA op won't reserve the whole FP32 space for the shard at once. Example: if there are 16 experts on a shard, this PR's implementation would reserve the FP32 space for all experts, but with batch size = 1 we would need only 1/16 of the extra space at any one time. This would use much less HBM, but it adds some complexity: the user should probably configure the batch size based on the model & hardware combo they're running on. A large batch size speeds up the quantization but could still bloat HBM for models with larger experts - unless we hardcode a batch size of 1, in which case the runtime won't be much better for small models but could be for large models, and safety is pretty good. Is that a fruitful direction to explore?

@lk-chen
Collaborator

lk-chen commented Mar 5, 2026

Does the CI include some tests that can inform us about the result of a large model?

No, the tests in CI are intended to be small and efficient.
Most OOM issue I encountered is DeepSeek-R1, which is 685B and barely fit on Ironwood (192*4=768GiB). And any unnecessary HBM usage will negatively impact concurrency, which impacts overall throughput.

I'll do some testing of my own by patching this commit & see if it works correctly. at very least, we can introduce an option that's disabled by default so users can enabled it only when the model is small and they know it won't cause oom.

@kyuyeunk not only OOM. If the peak HBM usage is not released after weight loading, less HBM can be allocated for the KV cache, which impacts concurrency/throughput.

@rohan-reddy
Contributor Author

@lk-chen I included the logs ("Memory monitor logs - this PR"); they also show before and after, and the after is always the same as or lower than the before. So the peak usage is released by the end of weight loading.

@kyuyeunk
Collaborator

kyuyeunk commented Mar 5, 2026

Tested with qwen coder 480 fp8 (482GB) on singlehost v7x (768GB). Confirmed that there is no OOM and saw model loading time speed up.

before

(EngineCore_DP0 pid=2306516) INFO 03-05 10:04:21 [vllm_model_wrapper.py:209] Total time to load model weights from storage to TPU: 506.42 seconds.

after

(EngineCore_DP0 pid=2524225) INFO 03-05 22:54:47 [vllm_model_wrapper.py:209] Total time to load model weights from storage to TPU: 314.49 seconds.

however, this is still a relatively small model compared to v7x's hbm capacity. @lk-chen, maybe you can try it on deepseek and see how it goes?

@kyuyeunk not only OOM. If the peak HBM usage is not release after weight loading, less HBM can be allocated for KV cache, which impacts concurrency/throughput.

fair point. we can easily verify by checking this log during server launch time:

(EngineCore_DP0 pid=2436334) INFO 03-05 22:43:19 [tpu_runner.py:563] Init model | hbm=[(56.54, 94.75), (56.54, 94.75), (56.54, 94.75), (56.54, 94.75), (56.54, 94.75), (56.54, 94.75), (56.54, 94.75), (56.54, 94.75)]GiB
(EngineCore_DP0 pid=2436334) WARNING 03-05 22:43:19 [kv_cache_manager.py:189] Compilation num_layers = 62
(EngineCore_DP0 pid=2436334) INFO 03-05 22:43:19 [tpu_worker.py:313] Memory statistics | total_hbm_limit_gb=757.97GiB | total_hbm_limit_cap_gb=682.17GiB | total_hbm_used_gb=452.31GiB | total_hbm_avail_gb=229.86GiB
(EngineCore_DP0 pid=2436334) INFO 03-05 22:43:19 [kv_cache_utils.py:1314] GPU KV cache size: 971,840 tokens
(EngineCore_DP0 pid=2436334) INFO 03-05 22:43:19 [kv_cache_utils.py:1319] Maximum concurrency for 10,240 tokens per request: 94.91x
(EngineCore_DP0 pid=2436334) INFO 03-05 22:43:19 [compilation_manager.py:509] Compiling sampling with different input shapes.

if total_hbm_used_gb increases in any way with this pr, it means it will indeed impact concurrency & throughput. @rohan-reddy, can you check this log - compare before & after this pr and verify it doesn't change that value?

@kyuyeunk
Collaborator

kyuyeunk commented Mar 5, 2026

also, does qwen3 235b fp8 fit in v6e-8? if so, let's use that model instead because qwen3 30b is too small to properly stress test the hbm usage.

@kyuyeunk
Collaborator

kyuyeunk commented Mar 5, 2026

(will need to review in more detail, but) i think this pr is good; the only thing i want an answer on is whether this change is safe to apply in all scenarios, or whether we need to make it an opt-in that's disabled by default.

@lk-chen
Collaborator

lk-chen commented Mar 6, 2026

Deepseek-R1 failed to load on Ironwood

...
Loading safetensors checkpoint shards:  24% Completed | 39/163 [15:29<33:19, 16.13s/it]
...
ValueError: RESOURCE_EXHAUSTED: Error loading program 'jit_process_fp8_moe_weights': Attempting to reser
ve 56.00G at the bottom of memory. That was not possible. There are 55.72G free, 0B reserved, and 55.72G reservable.: while running
replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

CMD

export CURRENT_VLLM_MODEL=deepseek-ai/DeepSeek-R1
VLLM_MLA_DISABLE=0 MOE_REQUANT_BLOCK_SIZE=256 MOE_REQUANT_WEIGHT_DTYPE=fp4 USE_UNFUSED_MEGABLOCKS=0 NEW_MODEL_DESIGN=1 TPU_BACKEND_TYPE=jax \
MODEL_IMPL_TYPE=flax_nnx vllm serve $CURRENT_VLLM_MODEL \
    --gpu-memory-utilization=0.95 --max-model-len=5120 --max-num-seqs=16 --max-num-batched-tokens 128 \
    --disable-log-requests \
    --no-enable-prefix-caching \
    --tensor-parallel-size 8 \
    --additional_config='{"sharding": {"sharding_strategy": {"enable_dp_attention": true, "expert_parallelism": 1, "tensor_parallelism": 8}}, "sparse_matmul": "True"}' \
    --no-async-scheduling

@kyuyeunk
Collaborator

kyuyeunk commented Mar 6, 2026

@kyuyeunk I also have some prototype working now (not yet in this PR) of doing expert-batched requant on TPU. This way the XLA op won't reserve the whole FP32 space for the shard at once. Example: if there are 16 experts on a shard, this PR's implementation would reserve the FP32 space for all experts, but with batch size = 1, we would need only 1/16 the extra space at any one time. This would use much less HBM, but problem is it adds some complexity as the user should probably configure it based on model & hardware combo they're running on. Large batch size speeds up the quantization but could still bloat HBM for models with larger experts. Unless we hardcode batch size of 1. Then I think the runtime won't be much better for small models but could be better for large models, and safety is pretty good. Is that a fruitful direction to explore?


@rohan-reddy maybe this idea can help reduce the peak hbm usage and thus avoid oom?

@rohan-reddy
Contributor Author

I revised the approach to do expert-batched requant using lax.scan. The batching targets roughly 1/2 of a full-expert-equivalent of memory per iteration.
Early tests with Qwen3 235B FP8 on v6e-8 look pretty good (it barely fits as is - I believe within 2 GB of the upper HBM limit of 225 GB - so a significant memory regression would cause OOM). I posted the table in the description. In theory this approach should be better for large models, but we would have to test on Ironwood again to see the impact.

@lk-chen
Collaborator

lk-chen commented Mar 9, 2026

Verifying commit 1583e31, loading deepseek-r1 on Ironwood

Loading safetensors checkpoint shards:  88% Completed | 143/163 [1:06:32<09:00, 27.01s/it]
...
(EngineCore_DP0 pid=3610565)   File "/home/lkchen_google_com/worktree/qwn_pp/tpu_inference/layers/jax/quantization/fp8.py", line 538
, in process_weights_after_loading
(EngineCore_DP0 pid=3610565)     weights = process_fp8_moe_weights(
(EngineCore_DP0 pid=3610565)               ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3610565) ValueError: RESOURCE_EXHAUSTED: Error loading program 'jit_process_fp8_moe_weights': Attempting to rese
rve 3.06G at the bottom of memory. That was not possible. There are 2.81G free, 0B reserved, and 2.81G reservable.: while running re
plica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

@rohan-reddy
Contributor Author

Thanks @lk-chen for running this!
I should have realized this when seeing little difference between batched and all-at-once. I should be using the reserved-memory counters, not just bytes_in_use from the memory stats API, since the OOM here comes from XLA reserving memory for the op, which is not tracked as bytes in use. I may also want to extend the weight-loading memory util to check for regressions in reserved memory, not just bytes in use.
The reserved-memory fluctuation can OOM even if it does not hurt the final HBM available for the KV cache. I'll track this now in a few tests: I'll try scan batch size = 1 to compare, and then maybe a Python loop instead of an XLA scan.

@rohan-reddy
Contributor Author

I tracked device.memory_stats()["bytes_reserved"]. Looks like lax.scan batch_size doesn't affect the reservation. The reservation is dominated by the full input/output tensor buffers at the JIT boundary, not the scan body's transient FP32 intermediates.

With lax.scan, the entire [num_experts, ...] tensor is passed into the @jax.jit boundary. XLA is reserving memory for both the full input arrays and full output arrays, regardless of how many experts the scan body processes per step. The scan body's FP32 intermediates (~14-55 MiB depending on batch size) are noise compared to the I/O buffers (~660+ MiB for Qwen 235B).

The DeepSeek OOM ("Attempting to reserve 3.06G, 2.81G free") at layer 143/163 is consistent with this: DeepSeek has ~4x larger expert dimensions than Qwen 235B, and 3060 MiB ≈ 4 × 739 MiB. Reducing scan_batch_size would not help, since the reservation is driven by I/O tensor sizes, not batch size.
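A back-of-envelope version of this argument: a tensor crossing the jit boundary reserves its full per-chip buffer regardless of how many experts the scan body touches per step. The shape and chip count below are illustrative (the `[128, 3072, 4096]` example shape from the discussion), not measured values:

```python
# Per-chip buffer size for a tensor at the jit boundary:
# global element count * dtype bytes / number of chips.
def per_chip_io_mib(shape, itemsize_bytes, n_chips):
    elems = 1
    for d in shape:
        elems *= d
    return elems * itemsize_bytes / n_chips / 2**20

# An fp8 [128, 3072, 4096] MoE weight sharded over 8 chips: the input and
# the requantized output are both reserved, independent of scan batch size.
w_in = per_chip_io_mib((128, 3072, 4096), 1, 8)   # -> 192.0 MiB
w_out = per_chip_io_mib((128, 3072, 4096), 1, 8)  # -> 192.0 MiB
print(w_in + w_out)
```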

Qwen3-235B-A22B-FP8, TP=8, v6e-8 (32 GiB HBM/chip)

| Approach | XLA reservation/chip |
|---|---|
| All-at-once (no scan, batch=128) | 1921 MiB |
| lax.scan (batch=4) | 739 MiB |
| lax.scan (batch=1) | 737 MiB |

Qwen3-30B-A3B-FP8, TP=8, v6e-8

| Approach | XLA reservation/chip |
|---|---|
| lax.scan (batch=4) | 169 MiB |

@rohan-reddy
Contributor Author

rohan-reddy commented Mar 10, 2026

Tried a few other approaches. The winner right now appears to be wrapping the FP8 MoE requantization + weight processing in jax.experimental.shard_map so XLA compiles with per-device local shapes instead of global shapes. Will update PR shortly.

Benchmark Results (Qwen3-235B-A22B, TP=8, 128 experts):

| Approach | XLA reservation/chip | Speed/layer | Total (64 MoE layers) |
|---|---|---|---|
| All-at-once (no scan) | 1921 MiB | - | - |
| lax.scan SPMD (batch=4) | 739 MiB | ~0.06s | ~3.8s |
| shard_map + lax.scan (batch=4) | 480 MiB | ~0.05s | ~3.2s |
| shard_map + lax.scan (batch=1) | 384 MiB | ~0.05s | ~3.2s |

Wall clock comparison of requant phase to main (disk I/O ~43s and engine init/warmup ~52s are identical):

| | Main (CPU requant) | shard_map (TPU requant) |
|---|---|---|
| MoE requant (64 layers) | ~4 min 40s | ~2 min 18s |

  • Regular jax.jit SPMD: XLA compiles with global shapes (e.g., [128, 3072, 4096]), then the SPMD partitioner splits across devices
  • shard_map: XLA compiles with per-device local shapes (e.g., [16, 3072, 4096]), resulting in smaller I/O buffers and scratch allocations

Reducing the batch size also had a bigger impact on memory use here, possibly because the intermediates are a larger fraction of the smaller program.

Used compiled.memory_analysis() to confirm temp_size = 384.19 MiB matches bytes_reserved exactly. Buffer breakdown:

  • Scan accumulator [16, ...] FP8: ~288 MiB. I think this is irreducible - it has to hold all local expert results before returning from shard_map
  • Per-step FP32 intermediates [1, ...]: ~18 MiB

I don't know how much further this allocation can be reduced. I tried some other approaches, including breaking the work into smaller JIT calls wrapped in a Python loop so XLA materializes a smaller tensor, but those were ineffective (details below if interested).

Other approaches

Fused scan (requant + process per expert in one scan step)

Fusing requant and process_moe_weights into a single scan body so each step fully processes one expert. Result: 384.59 MiB — essentially identical to the separate approach (384 MiB). The scan accumulator [16, ...] dominates regardless of what happens per step.

Chunked multi-JIT with per-device assembly

Breaking the computation into multiple smaller JIT calls (e.g., 4 chunks of 4 experts), then assembling results outside XLA using jax.make_array_from_single_device_arrays. Each chunk has low reservation, but shard_moe_weights downstream sees the full [128, ...] global tensor and compiles a reshard program → 1536 MiB. Fundamental problem: any approach that returns data needing resharding triggers a large program.

Python loop with per-expert JIT calls

16 separate JIT calls per device, each processing one expert ([1, ...]). Achieves the theoretical minimum reservation of 193 MiB, but each w13_weight[i:i+1] slice from an expert-sharded array triggers replication across all 8 devices. Result: ~4-5s/layer (80x slower than shard_map).

donate_argnums (with both scan and shard_map)

Donating input buffers reduces bytes_in_use (layer_delta: 290→2 MiB) but increases bytes_reserved (739→802 MiB for scan, 480→672 MiB for shard_map). XLA needs extra reservation to manage buffer donation. Counterproductive for the OOM problem.
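For reference, what the donation experiment looks like in miniature (toy shapes and a stand-in requant body; on backends without donation support, e.g. CPU, JAX warns and simply copies):

```python
from functools import partial

import jax
import jax.numpy as jnp

# donate_argnums marks the first argument's buffer as reusable by XLA
# for the output, which cuts bytes_in_use -- though as noted above it
# increased bytes_reserved in this PR's measurements.
@partial(jax.jit, donate_argnums=(0,))
def requant(w_fp8):
    return (w_fp8.astype(jnp.float32) * 2.0).astype(jnp.bfloat16)

w = jnp.ones((4, 4), jnp.float8_e4m3fn)
out = requant(w)  # w is invalidated after this call where donation applies
```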

Signed-off-by: Rohan Reddy <rreddy.nyc@gmail.com>
@kyuyeunk
Collaborator

random idea: it may be possible to leverage donate_argnames (or donate_argnums) in jit to reduce memory usage: https://docs.jax.dev/en/latest/_autosummary/jax.jit.html

@lk-chen
Collaborator

lk-chen commented Mar 11, 2026

Testing on 41ab4c6

(EngineCore_DP0 pid=367073)   File "/home/lkchen_google_com/worktree/qwn_pp/tpu_inference/layers/jax/quantization/fp8.py", line 538, in process_we
ights_after_loading
(EngineCore_DP0 pid=367073)     weights = process_fp8_moe_weights(
(EngineCore_DP0 pid=367073)               ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=367073) ValueError: RESOURCE_EXHAUSTED: Error loading program 'jit_process_fp8_moe_weights': Attempting to reserve 21.00G at t
he bottom of memory. That was not possible. There are 19.90G free, 0B reserved, and 19.90G reservable.: while running replica 0 and partition 0 of
 a replicated computation (other replicas may have failed as well).
Loading safetensors checkpoint shards:  69% Completed | 113/163 [52:20<23:09, 27.79s/it]

