Commit dd11e08

[NVIDIA#6187][feat] add LayerNorm module (NVIDIA#6625)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent 81f0ded commit dd11e08

2 files changed, +145 -7 lines changed
Lines changed: 120 additions & 0 deletions (new file)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch
from torch import nn


class LayerNorm(nn.Module):
    """Layer normalization module with configurable weight and bias parameters.

    This implementation provides standard layer normalization with optional
    learnable parameters and residual connection support.

    Args:
        hidden_size: The size of the hidden dimension to normalize.
        eps: Small constant for numerical stability.
        dtype: Optional data type for parameters.
        device: Optional device for parameters.
        has_weights: Whether to include learnable weight parameters.
        has_bias: Whether to include learnable bias parameters.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        eps: float,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        has_weights: bool = True,
        has_bias: bool = True,
    ):
        super().__init__()
        if has_weights:
            self.weight = nn.Parameter(
                torch.ones(hidden_size, dtype=dtype, device=device))
        else:
            self.register_buffer('weight',
                                 torch.ones(hidden_size,
                                            dtype=dtype,
                                            device=device),
                                 persistent=False)
        if has_bias:
            self.bias = nn.Parameter(
                torch.zeros(hidden_size, dtype=dtype, device=device))
        else:
            self.register_buffer('bias',
                                 torch.zeros(hidden_size,
                                             dtype=dtype,
                                             device=device),
                                 persistent=False)
        self.variance_epsilon = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = ...,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer normalization to the input tensor.

        Args:
            hidden_states: Input tensor to normalize.
            residual: Optional residual tensor to add before normalization.

        Returns:
            Normalized tensor, or a tuple of (normalized_tensor, residual)
            if a residual is provided.
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if isinstance(residual, torch.Tensor):
            hidden_states = hidden_states + residual.to(torch.float32)
            residual = hidden_states.to(input_dtype)

        hidden_states = nn.functional.layer_norm(
            hidden_states,
            hidden_states.shape[-1],
            weight=self.weight,
            bias=self.bias,
            eps=self.variance_epsilon,
        )

        if residual is ...:
            return hidden_states
        else:
            return hidden_states, residual

    def skip_forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = ...,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Skip normalization and return the inputs unchanged.

        Args:
            hidden_states: Input tensor to pass through.
            residual: Optional residual tensor to pass through.

        Returns:
            Input tensors unchanged, maintaining the same signature as forward.
        """
        if residual is ...:
            return hidden_states
        else:
            return hidden_states, residual
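For orientation, here is a minimal usage sketch of the new module. The capture above does not show the new file's path, so this assumes the LayerNorm class is already in scope rather than importing it; the hidden size and eps value are arbitrary.

import torch

# Parameters are created in float32 so they match the float32 compute
# path inside forward.
norm = LayerNorm(hidden_size=64, eps=1e-5, dtype=torch.float32)

x = torch.randn(2, 8, 64)

# No residual: the Ellipsis default applies and a single tensor is returned.
y = norm(x)

# With a residual: the residual is added in float32 before normalization,
# and the pre-norm sum is returned alongside the output so the next layer
# can reuse it.
res = torch.randn(2, 8, 64)
y, new_res = norm(x, res)

# skip_forward keeps the same signature but performs no computation, so a
# caller can disable normalization (e.g. when skipping a layer) without
# changing the call site.
y_skip, res_skip = norm.skip_forward(x, res)

Keeping skip_forward signature-compatible with forward is what lets callers swap it in without touching the surrounding residual bookkeeping.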

tensorrt_llm/_torch/modules/rms_norm.py

Lines changed: 25 additions & 7 deletions

@@ -1,3 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import enum
 from typing import Optional, Tuple, Union

@@ -9,13 +24,15 @@

 class RMSNorm(nn.Module):

-    def __init__(self,
-                 *,
-                 hidden_size: int,
-                 eps: float,
-                 dtype: Optional[torch.dtype] = None,
-                 device: Optional[torch.device] = None,
-                 has_weights: bool = True):
+    def __init__(
+        self,
+        *,
+        hidden_size: int,
+        eps: float,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        has_weights: bool = True,
+    ):
         super().__init__()
         if has_weights:
             self.weight = nn.Parameter(
@@ -48,6 +65,7 @@ def forward(
         if isinstance(residual, torch.Tensor):
             hidden_states = hidden_states + residual.to(torch.float32)
             residual = hidden_states.to(input_dtype)
+
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
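For reference, the arithmetic visible in the last hunk (mean of squares, then multiplication by the reciprocal square root) is the usual RMSNorm formula y = x / sqrt(mean(x^2) + eps). A standalone sketch of that math follows; the final weight multiplication and the cast back to the input dtype are not visible in this hunk, so that step is an assumption based on common RMSNorm implementations.

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    # Match the module's float32 compute path.
    input_dtype = x.dtype
    x = x.to(torch.float32)
    # Mean of squares over the hidden dimension; unlike LayerNorm,
    # no mean is subtracted first.
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Assumed final step: scale by the learnable weight and cast back.
    return (weight.to(torch.float32) * x).to(input_dtype)

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8)
y = rms_norm_reference(x, w, eps=1e-6)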
