[Feature][OP] Add batch-invariant RMSNorm kernel and TP embedding Custom AR path#6749
Open
gongweibao wants to merge 4 commits intoPaddlePaddle:developfrom
Open
[Feature][OP] Add batch-invariant RMSNorm kernel and TP embedding Custom AR path#6749gongweibao wants to merge 4 commits intoPaddlePaddle:developfrom
gongweibao wants to merge 4 commits intoPaddlePaddle:developfrom
Conversation
…AR path - Add Triton-based rms_norm_batch_invariant kernel for M-invariant RMSNorm - Add linear/linear_v2 tracking wrappers in batch_invariant_mode - Route TP VocabParallelEmbedding through Custom AR instead of NCCL - Increase FD_CUSTOM_AR_MAX_SIZE_MB default from 8 to 64 - Add unit tests for RMSNorm and TP embedding invariance
|
Thanks for your contribution! |
|
gongweibao seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## develop #6749 +/- ##
==========================================
Coverage ? 72.03%
==========================================
Files ? 392
Lines ? 53919
Branches ? 8475
==========================================
Hits ? 38839
Misses ? 12302
Partials ? 2778
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
- Relax bfloat16 atol from 1e-3 to 1e-2 for D=3584 in RMSNorm numerical correctness test (0.0078125 diff is expected at bfloat16 precision) - Update test_communication expected buffer size from 8MB to 64MB to match FD_CUSTOM_AR_MAX_SIZE_MB default change in envs.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Collaborator
There was a problem hiding this comment.
deterministic 的端到端单测尽量合并成一个?减少ci压力
Collaborator
Author
There was a problem hiding this comment.
这个比较小。CI 增加的有限。时间很短。
主要是跟其他的单测没有关系。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
In deterministic inference mode, batch composition changes (e.g., prefix cache hit vs miss) can cause different padding/concatenation patterns, leading to non-deterministic outputs from RMSNorm and TP all-reduce operations. This PR makes these operators M-invariant (batch-size independent) to ensure identical outputs regardless of batch composition.
Modifications
batch_invariant_ops.py): Addrms_norm_batch_invariant— a per-row Triton kernel that computes RMSNorm independently per row, eliminating cross-row reduction non-determinism.normalization.py): Whenbatch_invariant_modeis enabled, route through the Triton kernel instead of the default fused kernel.embeddings.py): In deterministic mode with TP>1, bypass Paddle's_mp_allreduce(NCCL) and use Custom All-Reduce for deterministic embedding all-reduce.envs.py): IncreaseFD_CUSTOM_AR_MAX_SIZE_MBdefault from 8 to 64 MB to accommodate larger tensors without NCCL fallback.multiquery_attention_c16_impl.cuh): Minor type adjustments in the append attention kernel.Usage or Command
Enable deterministic mode:
export FD_DETERMINISTIC_MODE=1The batch-invariant RMSNorm is automatically activated when
batch_invariant_modeis enabled during deterministic forward passes.Accuracy Tests
tests/batch_invariant/test_batch_invariance_op_rmsnorm.py: Verifies RMSNorm output is identical across different batch compositions (padding patterns).tests/e2e/4cards_cases/vocab_parallel_embedding_deterministic.py: Verifies TP embedding produces deterministic results with Custom AR path.Checklist
pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.