Skip to content

[RMSNorm] remove invvar from rms_norm Python API return value#78438

Merged
DanielSun11 merged 3 commits intoPaddlePaddle:developfrom
DanielSun11:rms_norm
Apr 3, 2026
Merged

[RMSNorm] remove invvar from rms_norm Python API return value#78438
DanielSun11 merged 3 commits intoPaddlePaddle:developfrom
DanielSun11:rms_norm

Conversation

@DanielSun11
Copy link
Copy Markdown
Contributor

@DanielSun11 DanielSun11 commented Mar 23, 2026

PR Category

Operator Mechanism

PR Types

Improvements

Description

paddle.nn.functional.rms_norm 原来返回 (out, invvar) 两个 Tensor,其中 invvar(每行的逆标准差)是内部中间量,对外部用户没有实际意义。为了和torch对齐,本 PR 在 Python API 层面将其隐藏,使函数只返回归一化结果 out

主要改动:

  • python/paddle/nn/functional/norm.py:返回类型由 tuple[Tensor, Tensor] 改为 Tensor,动态图/PIR 模式下取 _C_ops.rms_norm(...)[0],静态图模式下只返回 out;同时将 eps 默认值由 1e-5 改为 实际计算时的数据类型的最小精度eps
  • paddle/phi/ops/yaml/ops.yaml:同步将 rms_norm op 的 epsilon 默认值改为 fp32的最小精度
  • test/legacy_test/test_rms_norm_op.pyOpTestpython_api wrapper 改为直接调 _C_ops.rms_norm 以保留 op 层双输出供框架校验;TestRMSNormAPI 中移除对 invvar 的解包和数值断言

是否引起精度变化

…ate default eps

- rms_norm now returns only `out` (Tensor) instead of `(out, invvar)` tuple
- Updated return type annotation from `tuple[Tensor, Tensor]` to `Tensor`
- Changed default eps from 1e-5 to 0.0 in both norm.py and ops.yaml
- Updated unit tests: rms_norm_wrapper now calls _C_ops directly to preserve
  op-level invvar output for OpTest framework; removed invvar assertion in TestRMSNormAPI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Mar 23, 2026

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 23, 2026
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@ae907b8). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78438   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         1           
  Lines              ?         6           
  Branches           ?         0           
===========================================
  Hits               ?         6           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 changed the title [RMSNorm] remove invvar from rms_norm Python API return value and upd… [RMSNorm] remove invvar from rms_norm Python API return value Apr 3, 2026
Copy link
Copy Markdown
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@DanielSun11 DanielSun11 merged commit 36eb239 into PaddlePaddle:develop Apr 3, 2026
283 of 297 checks passed
ShigureNyako pushed a commit to ShigureNyako/Paddle that referenced this pull request Apr 3, 2026
…Paddle#78438)

* [RMSNorm] remove invvar from rms_norm Python API return value and update default eps

- rms_norm now returns only `out` (Tensor) instead of `(out, invvar)` tuple
- Updated return type annotation from `tuple[Tensor, Tensor]` to `Tensor`
- Changed default eps from 1e-5 to 0.0 in both norm.py and ops.yaml
- Updated unit tests: rms_norm_wrapper now calls _C_ops directly to preserve
  op-level invvar output for OpTest framework; removed invvar assertion in TestRMSNormAPI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* default eps is none

* fix ut

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
@ShigureNyako
Copy link
Copy Markdown
Contributor

✅ Cherry-pick successful! Created PR: #78578

liuhao2638 pushed a commit to liuhao2638/Paddle that referenced this pull request Apr 7, 2026
…Paddle#78438)

* [RMSNorm] remove invvar from rms_norm Python API return value and update default eps

- rms_norm now returns only `out` (Tensor) instead of `(out, invvar)` tuple
- Updated return type annotation from `tuple[Tensor, Tensor]` to `Tensor`
- Changed default eps from 1e-5 to 0.0 in both norm.py and ops.yaml
- Updated unit tests: rms_norm_wrapper now calls _C_ops directly to preserve
  op-level invvar output for OpTest framework; removed invvar assertion in TestRMSNormAPI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* default eps is none

* fix ut

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
DanielSun11 added a commit to DanielSun11/Paddle that referenced this pull request Apr 8, 2026
…Paddle#78438)

* [RMSNorm] remove invvar from rms_norm Python API return value and update default eps

- rms_norm now returns only `out` (Tensor) instead of `(out, invvar)` tuple
- Updated return type annotation from `tuple[Tensor, Tensor]` to `Tensor`
- Changed default eps from 1e-5 to 0.0 in both norm.py and ops.yaml
- Updated unit tests: rms_norm_wrapper now calls _C_ops directly to preserve
  op-level invvar output for OpTest framework; removed invvar assertion in TestRMSNormAPI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* default eps is none

* fix ut

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
sneaxiy pushed a commit that referenced this pull request Apr 9, 2026
…value (#78608)

* [RMSNorm] remove invvar from rms_norm Python API return value (#78438)

* [RMSNorm] remove invvar from rms_norm Python API return value and update default eps

- rms_norm now returns only `out` (Tensor) instead of `(out, invvar)` tuple
- Updated return type annotation from `tuple[Tensor, Tensor]` to `Tensor`
- Changed default eps from 1e-5 to 0.0 in both norm.py and ops.yaml
- Updated unit tests: rms_norm_wrapper now calls _C_ops directly to preserve
  op-level invvar output for OpTest framework; removed invvar assertion in TestRMSNormAPI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* default eps is none

* fix ut

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix doc

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants