add an option to use aten for layernorm #1115
Open
ftynse wants to merge 3 commits into
rkayaith
reviewed
Aug 14, 2025
rkayaith
left a comment
Member
Sorry for the late review, this fell off my radar. Just some minor comments/questions.
    shift = input - mean
    var = (shift**2).mean(dim=self.normalized_dim, keepdim=True)
    rstd = torch.rsqrt(var + self.eps)
    # rstd = torch.rsqrt(input.to(self.forwarded_args_dtype).var(dim=self.normalized_dim, keepdim=True) + self.eps)
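For readers following the diff, the mean / shift / var / rstd steps quoted above can be sketched on a single row. This is only an illustration: it uses plain Python lists instead of tensors so it runs without dependencies, and the function name and the default `eps` are assumptions, not code from the PR.

```python
import math


def layer_norm_from_primitives(row, eps=1e-5):
    """Normalize one row with the same steps as the quoted snippet."""
    mean = sum(row) / len(row)
    shift = [x - mean for x in row]
    # Biased variance: the mean of the squared deviations,
    # mirroring (shift**2).mean(...) in the diff.
    var = sum(s * s for s in shift) / len(row)
    rstd = 1.0 / math.sqrt(var + eps)
    return [s * rstd for s in shift]


normalized = layer_norm_from_primitives([1.0, 2.0, 3.0, 4.0])
```

One detail worth keeping in mind when reading the commented-out line: `torch.var` applies Bessel's correction by default, so it matches the mean of squared deviations only when called with `correction=0`.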
    "bias": bias,
    "dtype": dtype,
    "forwarded_args_dtype": forwarded_dtype,
    "forwarded_args_dtype": forwarded_dtype if not use_aten else None,
Member
If I understand this correctly, there are three variants we want to try:
- aten
- primitives+forwarded_dtype=f32
- primitives+forwarded_dtype=f64
Can you make this explicit in the test parametrization and avoid the `use_aten` logic inside the test?
    @pytest.mark.parametrize("use_aten, forwarded_dtype", [(True, None), (False, torch.float32), (False, torch.float64)])
(Also, is `forwarded_dtype` ignored when we use aten? Should we have a check somewhere that raises an error if someone tries to use them together?)
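One possible shape for the check asked about above, as a sketch only. It assumes the answer is that `use_aten` and a non-`None` `forwarded_dtype` should not be combined, which the thread does not confirm. The helper name `check_layernorm_options` is hypothetical, and strings stand in for the torch dtypes so the snippet runs without dependencies.

```python
def check_layernorm_options(use_aten, forwarded_dtype):
    """Hypothetical guard: reject combinations where forwarded_dtype would be ignored."""
    if use_aten and forwarded_dtype is not None:
        raise ValueError(
            "forwarded_dtype is only used by the primitives implementation; "
            "pass None when use_aten=True"
        )


# The three variants listed in the review comment; strings stand in for torch dtypes.
VARIANTS = [(True, None), (False, "float32"), (False, "float64")]

for use_aten, forwarded_dtype in VARIANTS:
    check_layernorm_options(use_aten, forwarded_dtype)  # none of these raise
```

With a guard like this in the constructor, the parametrized test could pass each pair straight through without any `use_aten` branching of its own.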
Add an op export for combined computation of all gradients in layer norm. This may be more efficient than executing them one by one in some cases and requires separate testing.
Signed-off-by: Alex Zinenko <git@ozinenko.com>
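To illustrate why a combined backward can save work, here is a rough sketch that computes the input, weight, and bias gradients in one pass and reuses the normalized input and `rstd` for all three. It is written with plain Python lists and the textbook layer norm gradient formulas. The function name and argument layout are assumptions and do not reflect the exported op.

```python
import math


def layer_norm_backward_combined(x, weight, grad_out, eps=1e-5):
    """Return (grad_x, grad_weight, grad_bias) for row-wise layer norm.

    x and grad_out are lists of rows; weight has one entry per column.
    """
    n = len(weight)
    grad_x = []
    grad_weight = [0.0] * n
    grad_bias = [0.0] * n
    for row, g_row in zip(x, grad_out):
        mean = sum(row) / n
        shift = [v - mean for v in row]
        rstd = 1.0 / math.sqrt(sum(s * s for s in shift) / n + eps)
        xhat = [s * rstd for s in shift]
        # xhat and rstd are computed once and shared by all three gradients.
        for j in range(n):
            grad_weight[j] += g_row[j] * xhat[j]
            grad_bias[j] += g_row[j]
        g = [g_row[j] * weight[j] for j in range(n)]
        g_mean = sum(g) / n
        gx_mean = sum(g[j] * xhat[j] for j in range(n)) / n
        grad_x.append(
            [rstd * (g[j] - g_mean - xhat[j] * gx_mean) for j in range(n)]
        )
    return grad_x, grad_weight, grad_bias
```

Computing the three gradients separately would repeat the mean, variance, and normalization for each one, which is the redundancy a combined op can avoid.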
- plumb `use_aten` through the driver
- use bitmask-style values for the mode enum
Signed-off-by: Alex Zinenko <git@ozinenko.com>
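As a rough illustration of what "bitmask-style values" could look like, the sketch below uses Python's `enum.IntFlag`. The class and member names are hypothetical and are not taken from the PR; the point is only that each mode occupies one bit, so modes can be combined and tested with bitwise operators.

```python
import enum


class Mode(enum.IntFlag):
    """Hypothetical mode enum: one bit per computation."""

    FORWARD = 1
    INPUT_BACKWARD = 2
    WEIGHT_BACKWARD = 4
    BIAS_BACKWARD = 8
    # A combined mode is simply the union of the individual bits.
    FULL_BACKWARD = INPUT_BACKWARD | WEIGHT_BACKWARD | BIAS_BACKWARD


requested = Mode.INPUT_BACKWARD | Mode.BIAS_BACKWARD
needs_weight_grad = Mode.WEIGHT_BACKWARD in requested
```

A layout like this lets a combined-gradient mode be expressed without adding a separate enum value for every combination.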
Add a switch to all implementations that selects whether to use the aten native layer norm functions or an equivalent implementation built from primitives. The latter is sometimes more efficient when processed by the compiler. Add a test to ensure the implementation from primitives matches the aten version.
Signed-off-by: Alex Zinenko <git@ozinenko.com>
6993e80 to 8581686
4a5ac3f to 3ffb389
8581686 to d39f7c1
Base automatically changed from users/ftynse/combined-layernorm-backward to main September 24, 2025 13:22