
[Feature][OP] Add batch-invariant RMSNorm kernel and TP embedding Custom AR path#6749

Open
gongweibao wants to merge 4 commits into PaddlePaddle:develop from gongweibao:pr/batch-invariant-ops

Conversation

@gongweibao gongweibao commented Mar 10, 2026

Motivation

In deterministic inference mode, batch composition changes (e.g., prefix cache hit vs miss) can cause different padding/concatenation patterns, leading to non-deterministic outputs from RMSNorm and TP all-reduce operations. This PR makes these operators M-invariant (batch-size independent) to ensure identical outputs regardless of batch composition.
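
The non-determinism boils down to floating-point addition not being associative: if the grouping of a reduction depends on batch shape or padding, the same logical row can produce a different sum. A minimal illustration (values chosen for effect, unrelated to the PR's kernels):

```python
import numpy as np

# Float addition is not associative, so a reduction whose grouping
# depends on batch shape can change the result for the same data.
vals = np.float32([1e8, 1.0, -1e8, 1.0])

acc = np.float32(0.0)
for v in vals:        # strict left-to-right: 1e8 + 1 rounds to 1e8,
    acc = acc + v     # so the running sum ends at 1.0

regrouped = (vals[0] + vals[2]) + (vals[1] + vals[3])  # 0.0 + 2.0 = 2.0

print(acc, regrouped)  # 1.0 2.0 -- same data, different grouping
```

A batch-invariant kernel avoids this by fixing the reduction pattern per row, independent of how the batch was assembled.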

Modifications

  • Triton RMSNorm kernel (batch_invariant_ops.py): Add rms_norm_batch_invariant — a per-row Triton kernel that computes RMSNorm independently per row, eliminating cross-row reduction non-determinism.
  • RMSNorm routing (normalization.py): When batch_invariant_mode is enabled, route through the Triton kernel instead of the default fused kernel.
  • TP Embedding Custom AR (embeddings.py): In deterministic mode with TP>1, bypass Paddle's _mp_allreduce (NCCL) and use Custom All-Reduce for deterministic embedding all-reduce.
  • Custom AR buffer size (envs.py): Increase FD_CUSTOM_AR_MAX_SIZE_MB default from 8 to 64 MB to accommodate larger tensors without NCCL fallback.
  • CUDA kernel fix (multiquery_attention_c16_impl.cuh): Minor type adjustments in the append attention kernel.
  • Unit tests: Add RMSNorm batch-invariance test and TP embedding deterministic test.
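
The key property of the per-row kernel can be sketched in NumPy (a simplified model of the idea, not the PR's actual Triton code): each row's reduction touches only that row, so its output is bitwise identical no matter which other rows share the batch.

```python
import numpy as np

def rms_norm_per_row(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Per-row RMSNorm: one independent reduction per row, so a row's
    result does not depend on batch composition."""
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        row = x[i]
        rms = np.sqrt(np.mean(row * row) + eps)  # reduction over this row only
        out[i] = (row / rms) * weight
    return out

# Sanity check: full batch vs. two sub-batches, bitwise equal per row.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
w = rng.standard_normal(8)
full = rms_norm_per_row(x, w)
split = np.vstack([rms_norm_per_row(x[:2], w), rms_norm_per_row(x[2:], w)])
print(np.array_equal(full, split))  # True
```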

Usage or Command

Enable deterministic mode:

export FD_DETERMINISTIC_MODE=1

The batch-invariant RMSNorm is automatically activated when batch_invariant_mode is enabled during deterministic forward passes.
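
The routing can be pictured as a simple env-gated dispatch (a sketch; the helper names below are hypothetical, only FD_DETERMINISTIC_MODE comes from this PR):

```python
import os

def deterministic_mode_enabled() -> bool:
    # Hypothetical helper mirroring how FD_DETERMINISTIC_MODE gates
    # the batch-invariant paths.
    return os.getenv("FD_DETERMINISTIC_MODE", "0") == "1"

def rms_norm(x, weight, fused_impl, invariant_impl):
    # Use the batch-invariant kernel only in deterministic mode;
    # otherwise keep the default fused kernel.
    impl = invariant_impl if deterministic_mode_enabled() else fused_impl
    return impl(x, weight)

os.environ["FD_DETERMINISTIC_MODE"] = "1"
print(rms_norm(None, None, lambda x, w: "fused", lambda x, w: "invariant"))
# -> invariant
```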

Accuracy Tests

  • tests/batch_invariant/test_batch_invariance_op_rmsnorm.py: Verifies RMSNorm output is identical across different batch compositions (padding patterns).
  • tests/e2e/4cards_cases/vocab_parallel_embedding_deterministic.py: Verifies TP embedding produces deterministic results with Custom AR path.

Checklist

  • Add at least one tag in the PR title.
  • Format your code; run pre-commit before committing.
  • Add unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

…AR path

- Add Triton-based rms_norm_batch_invariant kernel for M-invariant RMSNorm
- Add linear/linear_v2 tracking wrappers in batch_invariant_mode
- Route TP VocabParallelEmbedding through Custom AR instead of NCCL
- Increase FD_CUSTOM_AR_MAX_SIZE_MB default from 8 to 64
- Add unit tests for RMSNorm and TP embedding invariance
@paddle-bot

paddle-bot bot commented Mar 10, 2026

Thanks for your contribution!

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


gongweibao seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@codecov-commenter

codecov-commenter commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 36.58537% with 26 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@28f7727). Learn more about missing BASE report.

Files with missing lines Patch % Lines
.../layers/batch_invariant_ops/batch_invariant_ops.py 39.28% 17 Missing ⚠️
fastdeploy/model_executor/layers/embeddings.py 28.57% 4 Missing and 1 partial ⚠️
fastdeploy/model_executor/layers/normalization.py 33.33% 3 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6749   +/-   ##
==========================================
  Coverage           ?   72.03%           
==========================================
  Files              ?      392           
  Lines              ?    53919           
  Branches           ?     8475           
==========================================
  Hits               ?    38839           
  Misses             ?    12302           
  Partials           ?     2778           
Flag Coverage Δ
GPU 72.03% <36.58%> (?)


- Relax bfloat16 atol from 1e-3 to 1e-2 for D=3584 in RMSNorm numerical
  correctness test (0.0078125 diff is expected at bfloat16 precision)
- Update test_communication expected buffer size from 8MB to 64MB to match
  FD_CUSTOM_AR_MAX_SIZE_MB default change in envs.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
gongweibao and others added 2 commits March 11, 2026 09:27
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gongweibao gongweibao requested review from SigureMo and gongshaotian and removed request for gongshaotian March 11, 2026 01:34
Collaborator

Could the deterministic end-to-end unit tests be merged into a single test, to reduce CI load?

Collaborator Author

This one is fairly small; the added CI time is limited and short. The main point is that it is independent of the other unit tests.
