
[Bugfix] Pad Marlin FP8 MoE weight dims to tile alignment under TP > 1#36807

Open
ssubhanjali wants to merge 1 commit intovllm-project:mainfrom
ssubhanjali:fix/marlin-fp8-moe-tp-tile-alignment

Conversation


@ssubhanjali ssubhanjali commented Mar 11, 2026

The Marlin kernel requires size_n % 64 == 0 (tile_n_size) and size_k % 16 == 0 (tile_k_size). When tensor-parallel sharding splits MoE weights across GPUs, per-rank dimensions can violate these constraints and cause a crash at model load time on any GPU that falls back to the Marlin FP8 MoE path (CC < 9.0: L40S, A100, A10G).

Example — Nemotron Nano 3 at TP=4 (intermediate_size=1856):
w13 gate+up: size_n = 464 per shard → 464 % 64 = 16 ✗
w2 down: size_k = 232 per shard → 232 % 16 = 8 ✗

This error is never triggered on Hopper+ (CC >= 9.0) because vLLM selects native FP8 MoE kernels (CUTLASS/Triton) on those GPUs and never enters the Marlin path.

Fix:

  • marlin_utils_fp8.py repack_weight(): pad size_n to the next multiple of 64 and size_k to the next multiple of 16 before calling gptq_marlin_repack.
  • marlin_utils_fp8.py permute_scales(): pad scales to match padded size_n.
  • fused_marlin_moe.py _fused_marlin_moe(): compute padded sizes, use them for GEMM calls, trim w13 padding before activation, pad intermediate activation output before w2 GEMM.
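The tile-rounding arithmetic behind these steps can be sketched as follows (helper and constant names here are illustrative, not necessarily the ones in the patch):

```python
# Round a per-shard dimension up to the Marlin tile multiple.
# Tile sizes follow the kernel constraints: size_n % 64 == 0, size_k % 16 == 0.
MARLIN_TILE_N = 64
MARLIN_TILE_K = 16

def pad_to_tile(size: int, tile: int) -> tuple[int, int]:
    """Return (padded_size, pad_amount) for rounding size up to a multiple of tile."""
    pad = (-size) % tile
    return size + pad, pad

# Nemotron Nano 3 at TP=4, per-shard dims from the example above:
print(pad_to_tile(464, MARLIN_TILE_N))  # w13 size_n: (512, 48)
print(pad_to_tile(232, MARLIN_TILE_K))  # w2  size_k: (240, 8)
print(pad_to_tile(512, MARLIN_TILE_N))  # already aligned: (512, 0)
```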

Padding with zeros is mathematically a no-op: zero weights and zero inputs contribute nothing to GEMM outputs. For already-aligned dimensions all padding amounts are zero and no tensor operations are performed.
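The no-op claim can be checked with a tiny pure-Python GEMM (illustrative only; the real kernels operate on packed FP8 tiles):

```python
# Zero-padding the shared K dimension (zero columns in A, zero rows in B)
# leaves the product A @ B unchanged: the padded terms contribute 0.
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

A = [[1.0, 2.0], [3.0, 4.0]]          # 2 x K with K = 2
B = [[5.0, 6.0], [7.0, 8.0]]          # K x 2
A_pad = [row + [0.0] for row in A]    # K padded 2 -> 3: extra zero column
B_pad = B + [[0.0, 0.0]]              # matching extra zero row

assert matmul(A, B) == matmul(A_pad, B_pad) == [[19.0, 22.0], [43.0, 50.0]]
```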

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they run only the fastcheck CI, a small, essential subset of CI tests that catches errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the bug (Something isn't working) label on Mar 11, 2026
@mergify

mergify bot commented Mar 11, 2026

Hi @ssubhanjali, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a crash in Marlin FP8 MoE with tensor parallelism by padding weight dimensions to meet kernel alignment requirements. The changes in both the weight loading and the forward pass appear to correctly implement this fix. My review includes suggestions to improve maintainability by centralizing the new padding logic and constants to avoid code duplication across different files and functions.

Comment on lines +92 to +95
# Compute the same tile-aligned padded sizes used at weight-load time.
_TILE_N, _TILE_K = 64, 16
_w13_n_padded = _w13_n + ((-_w13_n) % _TILE_N) # for w13 GEMM size_n
_N_padded = N + ((-N) % _TILE_K) # for w2 GEMM size_k

high

The Marlin tile sizes (_TILE_N, _TILE_K) and the padding calculation logic are duplicated from vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py. These constants and logic are specific to the Marlin kernel and should be centralized to ensure consistency and improve maintainability. Consider defining these constants in marlin_utils_fp8.py (or a shared constants file) and exposing a utility function to compute the padded dimensions, which can then be imported and used here.

Comment on lines +254 to +259
import torch.nn.functional as _F
_TILE_N, _TILE_K = 64, 16
_pad_n = (-size_n) % _TILE_N
_padded_size_n = size_n + _pad_n
_pad_k = (-size_k) % _TILE_K
_padded_size_k = size_k + _pad_k

high

This block for padding calculation is duplicated in the permute_scales function. To improve maintainability and avoid potential inconsistencies, this logic should be extracted into a single helper function. Additionally, the import torch.nn.functional as _F is a local import; it's better practice to have it at the top of the file.

ssubhanjali added a commit to ssubhanjali/vllm that referenced this pull request Mar 11, 2026
Address reviewer feedback on PR vllm-project#36807:
- Define MARLIN_TILE_N=64 and MARLIN_TILE_K=16 as module-level constants
  in marlin_utils_fp8.py instead of duplicating magic numbers
- Extract _pad_to_marlin_tile() helper to avoid repeating padding logic
  across repack_weight() and permute_scales()
- Move torch.nn.functional import to top of marlin_utils_fp8.py
- Import MARLIN_TILE_N/K into fused_marlin_moe.py from marlin_utils_fp8
  instead of re-defining inline constants

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@ssubhanjali ssubhanjali force-pushed the fix/marlin-fp8-moe-tp-tile-alignment branch 2 times, most recently from 46ed086 to a0b8102 on March 11, 2026 at 17:44
The Marlin kernel requires size_n % 64 == 0 (tile_n_size) and
size_k % 16 == 0 (tile_k_size). When tensor-parallel sharding splits
MoE weights across GPUs, per-rank dimensions can violate these constraints
and cause a crash at model load time on any GPU that falls back to the
Marlin FP8 MoE path (CC < 9.0: L40S, A100, A10G).

Example — Nemotron Nano 3 at TP=4 (intermediate_size=1856):
  w13 gate+up: size_n = 464 per shard → 464 % 64 = 16 ✗
  w2  down:    size_k = 232 per shard → 232 % 16 = 8  ✗

This error is never triggered on Hopper+ (CC >= 9.0) because vLLM selects
native FP8 MoE kernels (CUTLASS/Triton) on those GPUs and never enters the
Marlin path.

Fix:
- Define MARLIN_TILE_N=64, MARLIN_TILE_K=16 and _pad_to_marlin_tile()
  helper in marlin_utils_fp8.py
- repack_weight(): pad size_n/size_k to tile boundaries before
  calling gptq_marlin_repack
- permute_scales(): pad scales to match padded size_n
- fused_marlin_moe.py _fused_marlin_moe(): import tile constants,
  compute padded sizes, use them for GEMM calls, trim w13 padding
  before activation, pad intermediate output before w2 GEMM

Padding with zeros is mathematically a no-op: zero weights and zero
inputs contribute nothing to GEMM outputs. For already-aligned
dimensions all padding amounts are zero and no operations are performed.

Tested on B200 with VLLM_TEST_FORCE_FP8_MARLIN=1 using Nemotron Nano 3
weight shapes (E=2, K=1024, N=232, W13_N=464).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Subhanjali <subhanjalis@nvidia.com>
@ssubhanjali ssubhanjali force-pushed the fix/marlin-fp8-moe-tp-tile-alignment branch from a0b8102 to 52fe813 on March 11, 2026 at 18:01
Contributor

@alvinttang alvinttang left a comment


The padding approach is correct in principle: aligning size_n to MARLIN_TILE_N=64 and size_k to MARLIN_TILE_K=16 at weight-load time, then mirroring the same padding at inference time to keep the GEMM shape contract intact. A few things worth verifying:

  • The trim after the w13 GEMM (intermediate_cache1[:, :_w13_n].contiguous()) introduces an extra allocation — if this is on the hot path under high load, it may be worth only trimming when _w13_n_padded != _w13_n (which the code already guards, good).
  • The zero-padding of intermediate_cache2 before the w2 GEMM assumes that the repacked w2 weight's zero-padded columns produce zero output — this is only safe if the padded columns of the repacked weight are also zero; please confirm pack_fp8_to_int32 + gptq_marlin_repack preserve zero-padding in the repacked layout.
  • There are no tests added for the TP>1 path — a unit test parameterized over non-tile-aligned intermediate sizes (e.g., N=232) would make this much safer to maintain.
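A minimal version of the parameterized test the last point asks for could look like this (pure Python, checking only the alignment invariants; a real test would drive the actual Marlin MoE path):

```python
# Check that tile rounding produces aligned sizes with minimal padding,
# across both non-aligned and already-aligned intermediate sizes.
# pad_to_tile is a stand-in for the helper the PR introduces.
def pad_to_tile(size: int, tile: int) -> int:
    return size + (-size) % tile

TILE_N, TILE_K = 64, 16  # Marlin tile constraints from the PR

for n in (232, 464, 928, 1856):
    padded_n = pad_to_tile(n, TILE_N)
    padded_k = pad_to_tile(n, TILE_K)
    assert padded_n % TILE_N == 0 and 0 <= padded_n - n < TILE_N
    assert padded_k % TILE_K == 0 and 0 <= padded_k - n < TILE_K
```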

