[Composite] RMSNorm OP and Enable composites by default #2213
base: main
Conversation
sgligorijevicTT
left a comment
Where is the Jax support?
This is targeting the torch version of rms_norm.
42f36c8 to cdacb10
a4f70da to e5bda4b
e5bda4b to 538e6d0
nvukobratTT
left a comment
A few minor comments; approving.
nvukobratTT
left a comment
FYI, this PR should close this issue as well:
@hshahTT @AleksKnezevic, can you take a look at the PR so we can merge this?
@pytest.mark.parametrize("use_weight", [True, False])
def test_composite_rms_norm(use_weight):
The names of the tests are inverted, no? As in, this is the non-composite version.
The idea is to test two variants: one where I call the torch op, which then gets turned into a composite, and one where I call the composite wrapper directly. In this case I call the composite wrapper directly, which is why I named the test that way.
Makes sense.
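For readers following the thread, the two-variant structure described above can be sketched in plain Python, so it runs without torch or the device backend. Both function names here are hypothetical: `rms_norm_op` stands in for the torch op that gets turned into a composite, and `composite_rms_norm` stands in for the composite wrapper called directly. The `eps` default is also an assumption, and the PR's actual tests will look different.

```python
import math


def rms_norm_op(x, weight=None, eps=1e-6):
    """Reference RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * weight.

    Hypothetical stand-in for the framework op (not the PR's code).
    """
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    out = [v * scale for v in x]
    if weight is not None:
        # The optional elementwise scale mirrors the use_weight parametrization.
        out = [o * w for o, w in zip(out, weight)]
    return out


def composite_rms_norm(x, weight=None, eps=1e-6):
    """Hypothetical stand-in for the composite wrapper; here it simply delegates."""
    return rms_norm_op(x, weight, eps)


def run_variant(fn, use_weight):
    """Run one entry point on fixed inputs, with or without a weight."""
    x = [1.0, -2.0, 3.0, -4.0]
    weight = [0.5, 1.0, 1.5, 2.0] if use_weight else None
    return fn(x, weight)


def check_variants():
    """Both entry points should agree for use_weight in {True, False}."""
    for use_weight in (True, False):
        via_op = run_variant(rms_norm_op, use_weight)
        via_wrapper = run_variant(composite_rms_norm, use_weight)
        assert all(math.isclose(a, b) for a, b in zip(via_op, via_wrapper))
    return True
```

Under this reading, a test named after the composite is the one that calls the wrapper directly, and the other test exercises the op-to-composite path.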
Tests example model that has a composite RMS norm operation.
"""

class RMSNormModel(torch.nn.Module):
Can you please add a multichip test? RMSNorm is typically replicated across devices so batch parallel is fine.
This won't be needed for each op, but our infra in handling composites is new and relatively untested, so we should add it for the first few.
@sshonTT FYI.
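As a rough illustration of why batch parallel is safe here: RMSNorm normalizes each row over its last dimension only, so splitting the batch across devices while replicating the weight should give the same result as a single-device run. The sketch below simulates that in plain Python. The helper names, shard layout, and `eps` value are assumptions for illustration and do not reflect the project's multichip test infrastructure.

```python
import math


def rms_norm_rows(batch, weight, eps=1e-6):
    """Apply RMSNorm independently to each row (normalizing over the last dim)."""
    out = []
    for row in batch:
        scale = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * scale * w for v, w in zip(row, weight)])
    return out


def shard_batch(batch, num_devices):
    """Split the batch dimension into equal contiguous shards, one per device."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by num_devices")
    per_device = len(batch) // num_devices
    return [batch[i * per_device:(i + 1) * per_device] for i in range(num_devices)]


def batch_parallel_rms_norm(batch, weight, num_devices):
    """Each simulated device gets one shard plus a full (replicated) copy of weight."""
    gathered = []
    for shard in shard_batch(batch, num_devices):
        gathered.extend(rms_norm_rows(shard, weight))
    return gathered


batch = [
    [1.0, 2.0, 3.0, 4.0],
    [0.5, -0.5, 2.0, -2.0],
    [4.0, 3.0, 2.0, 1.0],
    [-1.0, 0.0, 1.0, 0.0],
]
weight = [1.0, 0.5, 2.0, 1.5]

single = rms_norm_rows(batch, weight)
sharded = batch_parallel_rms_norm(batch, weight, num_devices=2)
assert sharded == single  # no cross-row reduction, so sharding the batch changes nothing
```

A real multichip test would compare device output against a CPU golden result rather than simulate the shards, but the invariant being checked is the same.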
Ticket
fixes: #1846
fixes: #2261
Problem description
explained in #1846
What's changed
Adding support for RMSNorm. Also added a temporary option to turn off composite ops (mainly for debugging while this feature is not stable).
Checklist