[Bugfix] Pad Marlin FP8 MoE weight dims to tile alignment under TP > 1 #36807
ssubhanjali wants to merge 1 commit into vllm-project:main
Conversation

👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of tests runs, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Hi @ssubhanjali, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Code Review
This pull request addresses a crash in Marlin FP8 MoE with tensor parallelism by padding weight dimensions to meet kernel alignment requirements. The changes in both the weight loading and the forward pass appear to correctly implement this fix. My review includes suggestions to improve maintainability by centralizing the new padding logic and constants to avoid code duplication across different files and functions.
# Compute the same tile-aligned padded sizes used at weight-load time.
_TILE_N, _TILE_K = 64, 16
_w13_n_padded = _w13_n + ((-_w13_n) % _TILE_N)  # for w13 GEMM size_n
_N_padded = N + ((-N) % _TILE_K)  # for w2 GEMM size_k
The Marlin tile sizes (_TILE_N, _TILE_K) and the padding calculation logic are duplicated from vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py. These constants and logic are specific to the Marlin kernel and should be centralized to ensure consistency and improve maintainability. Consider defining these constants in marlin_utils_fp8.py (or a shared constants file) and exposing a utility function to compute the padded dimensions, which can then be imported and used here.
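A centralized helper along these lines is easy to sketch (illustrative only; the constant and helper names follow the commit message later in this PR, but the exact signatures and placement are assumptions):

```python
# Sketch of a shared helper for marlin_utils_fp8.py (names follow the
# commit message in this PR; exact placement is an assumption).
MARLIN_TILE_N = 64  # Marlin kernel tile size along the N dimension
MARLIN_TILE_K = 16  # Marlin kernel tile size along the K dimension


def pad_to_marlin_tile(size: int, tile: int) -> int:
    """Round ``size`` up to the next multiple of ``tile``."""
    return size + ((-size) % tile)


# Per-shard sizes from the Nemotron Nano 3 TP=4 example in this PR:
print(pad_to_marlin_tile(464, MARLIN_TILE_N))  # 464 -> 512
print(pad_to_marlin_tile(232, MARLIN_TILE_K))  # 232 -> 240
```

With this in one place, both call sites import the same constants instead of re-declaring `_TILE_N, _TILE_K = 64, 16` locally.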
import torch.nn.functional as _F
_TILE_N, _TILE_K = 64, 16
_pad_n = (-size_n) % _TILE_N
_padded_size_n = size_n + _pad_n
_pad_k = (-size_k) % _TILE_K
_padded_size_k = size_k + _pad_k
This block for padding calculation is duplicated in the permute_scales function. To improve maintainability and avoid potential inconsistencies, this logic should be extracted into a single helper function. Additionally, the import torch.nn.functional as _F is a local import; it's better practice to have it at the top of the file.
Address reviewer feedback on PR vllm-project#36807:
- Define MARLIN_TILE_N=64 and MARLIN_TILE_K=16 as module-level constants in marlin_utils_fp8.py instead of duplicating magic numbers
- Extract _pad_to_marlin_tile() helper to avoid repeating padding logic across repack_weight() and permute_scales()
- Move torch.nn.functional import to top of marlin_utils_fp8.py
- Import MARLIN_TILE_N/K into fused_marlin_moe.py from marlin_utils_fp8 instead of re-defining inline constants

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed 46ed086 to a0b8102
The Marlin kernel requires size_n % 64 == 0 (tile_n_size) and size_k % 16 == 0 (tile_k_size). When tensor-parallel sharding splits MoE weights across GPUs, per-rank dimensions can violate these constraints and cause a crash at model load time on any GPU that falls back to the Marlin FP8 MoE path (CC < 9.0: L40S, A100, A10G).

Example — Nemotron Nano 3 at TP=4 (intermediate_size=1856):
- w13 gate+up: size_n = 464 per shard → 464 % 64 = 16 ✗
- w2 down: size_k = 232 per shard → 232 % 16 = 8 ✗

This error is never triggered on Hopper+ (CC >= 9.0) because vLLM selects native FP8 MoE kernels (CUTLASS/Triton) on those GPUs and never enters the Marlin path.

Fix:
- Define MARLIN_TILE_N=64, MARLIN_TILE_K=16 and _pad_to_marlin_tile() helper in marlin_utils_fp8.py
- repack_weight(): pad size_n/size_k to tile boundaries before calling gptq_marlin_repack
- permute_scales(): pad scales to match padded size_n
- fused_marlin_moe.py _fused_marlin_moe(): import tile constants, compute padded sizes, use them for GEMM calls, trim w13 padding before activation, pad intermediate output before w2 GEMM

Padding with zeros is mathematically a no-op: zero weights and zero inputs contribute nothing to GEMM outputs. For already-aligned dimensions all padding amounts are zero and no operations are performed.

Tested on B200 with VLLM_TEST_FORCE_FP8_MARLIN=1 using Nemotron Nano 3 weight shapes (E=2, K=1024, N=232, W13_N=464).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Subhanjali <subhanjalis@nvidia.com>
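The pad-at-load / trim-and-pad-at-inference flow described in this commit can be mirrored end-to-end with a toy GEMM (a pure-Python sketch with shrunken tile sizes; function and variable names are illustrative, not the actual vLLM code):

```python
def matmul(a, b):
    """Naive GEMM: a is (m, k), b is (k, n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def pad_to_tile(size, tile):
    """Round size up to the next multiple of tile."""
    return size + ((-size) % tile)


TILE_N, TILE_K = 4, 2  # toy stand-ins for MARLIN_TILE_N=64, MARLIN_TILE_K=16

x = [[1.0, 2.0]]                          # (1, 2) activations
w13 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # (2, n=3): n not tile-aligned
w2 = [[7.0], [8.0], [9.0]]                # (k=3, 1): k not tile-aligned

# Load time: zero-pad w13 along n to the tile boundary (3 -> 4).
n_pad = pad_to_tile(3, TILE_N)
w13_p = [row + [0.0] * (n_pad - len(row)) for row in w13]

# Inference: w13 GEMM at padded n, then trim the padding before activation.
inter = [row[:3] for row in matmul(x, w13_p)]

# Pad the intermediate and w2 along k (3 -> 4) before the w2 GEMM.
k_pad = pad_to_tile(3, TILE_K)
inter_p = [row + [0.0] * (k_pad - len(row)) for row in inter]
w2_p = w2 + [[0.0]] * (k_pad - len(w2))

# The padded pipeline matches the unpadded reference exactly.
assert matmul(inter_p, w2_p) == matmul(matmul(x, w13), w2)
```

This only checks the shape/arithmetic contract of the padding scheme; it says nothing about the FP8 packing or the Marlin repack layout themselves.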
Force-pushed a0b8102 to 52fe813
alvinttang left a comment
The padding approach is correct in principle: aligning size_n to MARLIN_TILE_N=64 and size_k to MARLIN_TILE_K=16 at weight-load time, then mirroring the same padding at inference time, keeps the GEMM shape contract intact. A few things worth verifying:

(1) The trim after the w13 GEMM (intermediate_cache1[:, :_w13_n].contiguous()) introduces an extra allocation — if this is on the hot path under high load, it is worth only trimming when _w13_n_padded != _w13_n (which the code already guards, good).

(2) The zero-padding of intermediate_cache2 before the w2 GEMM assumes that the repacked w2 weight's zero-padded columns produce zero output — this is only safe if the padded columns of the repacked weight are also zero; please confirm pack_fp8_to_int32 + gptq_marlin_repack preserve zero-padding in the repacked layout.

(3) There are no tests added for the TP>1 path — a unit test parameterized over non-tile-aligned intermediate sizes (e.g., N=232) would make this much safer to maintain.
The Marlin kernel requires size_n % 64 == 0 (tile_n_size) and size_k % 16 == 0 (tile_k_size). When tensor-parallel sharding splits MoE weights across GPUs, per-rank dimensions can violate these constraints and cause a crash at model load time on any GPU that falls back to the Marlin FP8 MoE path (CC < 9.0: L40S, A100, A10G).

Example — Nemotron Nano 3 at TP=4 (intermediate_size=1856):

w13 gate+up: size_n = 464 per shard → 464 % 64 = 16 ✗
w2 down: size_k = 232 per shard → 232 % 16 = 8 ✗

This error is never triggered on Hopper+ (CC >= 9.0) because vLLM selects native FP8 MoE kernels (CUTLASS/Triton) on those GPUs and never enters the Marlin path.

Fix:
- Define MARLIN_TILE_N=64, MARLIN_TILE_K=16 and a _pad_to_marlin_tile() helper in marlin_utils_fp8.py
- repack_weight(): pad size_n/size_k to tile boundaries before calling gptq_marlin_repack
- permute_scales(): pad scales to match padded size_n
- fused_marlin_moe.py _fused_marlin_moe(): import tile constants, compute padded sizes, use them for GEMM calls, trim w13 padding before activation, pad intermediate output before w2 GEMM

Padding with zeros is mathematically a no-op: zero weights and zero inputs contribute nothing to GEMM outputs. For already-aligned dimensions all padding amounts are zero and no tensor operations are performed.