
Conversation

@mstojkovicTT mstojkovicTT commented Nov 17, 2025

Ticket

fixes: #1846
fixes: #2261

Problem description

explained in #1846

What's changed

Adds support for RMSNorm as a composite op.
Also adds a temporary option to turn off composite ops (mainly for debugging while this feature is not yet stable).
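For context, RMSNorm normalizes by the root-mean-square of the activations instead of centering by mean and variance as LayerNorm does. A minimal pure-Python sketch of the math (illustrative only, not the PR's actual implementation, which lowers through the composite-op path):

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    """Root-mean-square normalization: x / sqrt(mean(x^2) + eps), optionally scaled."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    y = [v / rms for v in x]
    if weight is not None:
        # Optional elementwise learned scale, analogous to the weight in torch's rms_norm.
        y = [w * v for w, v in zip(weight, y)]
    return y

# Example: for x = [3.0, 4.0], mean(x^2) = 12.5, so each element is divided by ~3.5355.
print(rms_norm([3.0, 4.0]))
```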



@sgligorijevicTT sgligorijevicTT left a comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the Jax support?

@mstojkovicTT

> Where is the Jax support?

This targets the torch version of rms_norm.

@mstojkovicTT mstojkovicTT force-pushed the mstojkovic/rms_norm_composite branch from 42f36c8 to cdacb10 on November 20, 2025 10:32
@mstojkovicTT mstojkovicTT marked this pull request as ready for review November 20, 2025 15:39
@mstojkovicTT mstojkovicTT force-pushed the mstojkovic/rms_norm_composite branch from a4f70da to e5bda4b on November 21, 2025 13:42
@mstojkovicTT mstojkovicTT force-pushed the mstojkovic/rms_norm_composite branch from e5bda4b to 538e6d0 on November 24, 2025 10:56
@tenstorrent tenstorrent deleted a comment from github-actions bot Nov 24, 2025

github-actions bot commented Nov 24, 2025

| Tests | Passed ✅ | Skipped ⚠️ | Failed |
| --- | --- | --- | --- |
| TT-XLA Tests (9 ran) | 5 passed | 4 skipped | 0 failed |

No test annotations available.


@nvukobratTT nvukobratTT left a comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor comments; approving.


@nvukobratTT nvukobratTT left a comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this PR should close this issue as well:

@mstojkovicTT

@hshahTT @AleksKnezevic, can you take a look at the PR so we can merge it?



```python
@pytest.mark.parametrize("use_weight", [True, False])
def test_composite_rms_norm(use_weight):
```
The reason will be displayed to describe this comment to others. Learn more.

The names of the tests are inverted, no? As in, this is the non_composite version.


@mstojkovicTT mstojkovicTT Nov 25, 2025

The reason will be displayed to describe this comment to others. Learn more.

The idea is to test two variants: one where I call the torch op that gets turned into a composite, and another where I call the wrapper directly. In this case, I call the composite wrapper directly, which is why I named the test that way.
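The two variants described here can be sketched roughly as follows. Note that `rms_norm_composite` and `rms_norm_reference` are illustrative stand-ins, not the PR's actual wrapper API:

```python
import math
import pytest

def rms_norm_reference(x, eps=1e-6):
    # Plain math reference: x / sqrt(mean(x^2) + eps).
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def rms_norm_composite(x, eps=1e-6):
    # Stand-in for calling the composite wrapper directly
    # (in the PR this would be the wrapper that emits the composite op).
    return rms_norm_reference(x, eps)

@pytest.mark.parametrize("use_weight", [True, False])
def test_composite_rms_norm(use_weight):
    # This variant calls the composite wrapper directly, hence the test name;
    # the sibling variant would call the torch op that lowers to a composite.
    x = [3.0, 4.0]
    weight = [2.0, 0.5] if use_weight else None
    out = rms_norm_composite(x)
    expected = rms_norm_reference(x)
    if weight is not None:
        out = [w * v for w, v in zip(weight, out)]
        expected = [w * v for w, v in zip(weight, expected)]
    assert all(abs(a - b) < 1e-6 for a, b in zip(out, expected))
```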

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

```python
Tests example model that has a composite RMS norm operation.
"""

class RMSNormModel(torch.nn.Module):
```
The reason will be displayed to describe this comment to others. Learn more.

Can you please add a multichip test? RMSNorm is typically replicated across devices so batch parallel is fine.

This won't be needed for every op, but our infra for handling composites is new and relatively untested, so we should add it for the first few.

@sshonTT FYI.

@mstojkovicTT mstojkovicTT changed the title [Composite] RMSNorm OP [Composite] RMSNorm OP and Enable composites by default Nov 25, 2025


Development

Successfully merging this pull request may close these issues:

- Lower torch gelu as composite
- Lower RMSNorm as composite op

5 participants