Skip to content

add an option to use aten for layernorm#1115

Open
ftynse wants to merge 3 commits into
mainfrom
users/ftynse/use-aten
Open

add an option to use aten for layernorm#1115
ftynse wants to merge 3 commits into
mainfrom
users/ftynse/use-aten

Conversation

@ftynse

@ftynse ftynse commented Aug 6, 2025

Copy link
Copy Markdown
Contributor

Add a switch to all implementations whether to use the aten native layer norm functions or an equivalent implementation from primitives. The latter is sometimes more efficient when processed by the compiler. Add a test to ensure the implementation from primitives matches the aten version.

@rkayaith rkayaith left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late review, this fell off my radar. Just some minor comments/questions.

shift = input - mean
var = (shift**2).mean(dim=self.normalized_dim, keepdim=True)
rstd = torch.rsqrt(var + self.eps)
# rstd = torch.rsqrt(input.to(self.forwarded_args_dtype).var(dim=self.normalized_dim, keepdim=True) + self.eps)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is leaving this in intended?

"bias": bias,
"dtype": dtype,
"forwarded_args_dtype": forwarded_dtype,
"forwarded_args_dtype": forwarded_dtype if not use_aten else None,

@rkayaith rkayaith Aug 14, 2025

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I understand this correctly, there's three variants we want to try:

  • aten
  • primitives+forwarded_dtype=f32
  • primitives+forwarded_dtype=f64

can you make this explicit in the test parametrization, and avoid the use_aten logic inside the test:

@pytest.mark.parametrize("use_aten, forwarded_dtype", [(True, None), (False, torch.float32), (False, torch.float64)])

(also is forwarded_dtype ignored when we use aten? should we have a check somewhere that raises an error if someone tries to use them together?)

Add an op export for combined computaiton of all gradients in layer
norm. This may be more efficient than executing them one by one in some
cases and requires separate testing.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
- plumb `use_aten` through the driver
- use bitmask-style values for the mode enum

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Add a switch to all implementations whether to use the aten native layer
norm functions or an equivalent implementation from primitives. The
latter is sometimes more efficient when processed by the compiler. Add a
test to ensure the implementation from primitives matches the aten
version.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse force-pushed the users/ftynse/combined-layernorm-backward branch from 6993e80 to 8581686 Compare September 2, 2025 12:36
@ftynse ftynse force-pushed the users/ftynse/use-aten branch from 4a5ac3f to 3ffb389 Compare September 2, 2025 12:36
@ftynse ftynse force-pushed the users/ftynse/combined-layernorm-backward branch from 8581686 to d39f7c1 Compare September 24, 2025 12:20
Base automatically changed from users/ftynse/combined-layernorm-backward to main September 24, 2025 13:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants