-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathlayernorm.py
More file actions
522 lines (453 loc) · 17.3 KB
/
layernorm.py
File metadata and controls
522 lines (453 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
from typing import Tuple, Optional
import torch
from torch import Tensor
from torch.overrides import (
has_torch_function_unary,
handle_torch_function,
)
from atom.config import QuantizationConfig, LayerQuantConfig
from atom.utils.decorators import mark_trace
from torch import nn
from aiter import (
rmsnorm2d_fwd,
rmsnorm2d_fwd_with_add,
layernorm2d_fwd,
layernorm2d_fwd_with_add,
)
from aiter.dist.communication_op import tensor_model_parallel_fused_allreduce_rmsnorm
from aiter.dist.parallel_state import get_tensor_model_parallel_world_size
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter import (
QuantType,
)
def silu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Element-wise Sigmoid Linear Unit (a.k.a. swish).

    .. math::
        \text{silu}(x) = x \cdot \sigma(x),
        \text{ where } \sigma \text{ is the logistic sigmoid.}

    The name was coined in `Gaussian Error Linear Units (GELUs)
    <https://arxiv.org/abs/1606.08415>`_ and the function was later studied in
    `Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
    Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish: a
    Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_.
    See :class:`~torch.nn.SiLU` for more details.

    Args:
        input: tensor to activate.
        inplace: when True, overwrite ``input`` with the result.

    Returns:
        The activated tensor (``input`` itself when ``inplace`` is True).
    """
    # Give __torch_function__ overrides (tensor subclasses / modes) first say.
    if has_torch_function_unary(input):
        return handle_torch_function(silu, (input,), input, inplace=inplace)
    kernel = torch._C._nn.silu_ if inplace else torch._C._nn.silu
    return kernel(input)
@torch_compile_guard()
def rmsnorm2d_fwd_(
    x: torch.Tensor, weight: torch.Tensor, eps: float, dim: int
) -> torch.Tensor:
    """Run aiter's ``rmsnorm2d_fwd`` on an input of arbitrary rank.

    ``x`` is flattened to ``(-1, dim)`` for the kernel, and the kernel's result
    is viewed back to the caller's original shape.
    """
    input_shape = x.shape
    normed_2d = rmsnorm2d_fwd(x.reshape(-1, dim), weight, eps)
    return normed_2d.view(input_shape)
@torch_compile_guard()
def rmsnorm2d_fwd_with_add_(
    x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, eps: float, dim: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via aiter's ``rmsnorm2d_fwd_with_add``.

    Both ``x`` and ``residual`` are flattened to ``(-1, dim)`` before the kernel
    call, so the kernel sees matching 2-D inputs. Fresh output buffers are
    allocated for the normalized result and the residual output.

    Returns:
        ``(out, residual_out)``, each viewed back to ``x``'s original shape.
        ``residual_out`` is presumably ``x + residual`` (the kernel's
        residual output) — confirm against aiter.
    """
    ori_shape = x.shape
    x = x.reshape(-1, dim)
    # Flatten the residual exactly like x. Previously it was handed to the
    # kernel at its original rank while x was 2-D.
    residual = residual.reshape(-1, dim)
    out = torch.empty_like(x)
    residual_out = torch.empty_like(x)
    rmsnorm2d_fwd_with_add(out, x, residual, residual_out, weight, eps)
    return out.view(ori_shape), residual_out.view(ori_shape)
def fused_rmsnorm_pad_fake_tensors(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    x_pad_to_multiple: int = 0,
) -> torch.Tensor:
    """Fake (shape-only) implementation of ``fused_rmsnorm_pad_`` for tracing.

    Returns an uninitialised ``(M, N_out)`` tensor with ``x``'s dtype and
    device. ``N_out`` is ``N`` rounded up to a multiple of ``x_pad_to_multiple``.
    A non-positive ``x_pad_to_multiple`` (including the default ``0``) means no
    padding. Previously the default raised ZeroDivisionError.
    """
    M, N = x.shape
    if x_pad_to_multiple > 0:
        N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
    else:
        N_out = N
    return torch.empty((M, N_out), dtype=x.dtype, device=x.device)
@torch_compile_guard(gen_fake=fused_rmsnorm_pad_fake_tensors)
def fused_rmsnorm_pad_(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    x_pad_to_multiple: int = 0,
) -> torch.Tensor:
    """RMSNorm with output padding, via aiter's Triton ``fused_add_rmsnorm_pad``.

    The residual-free variant: the kernel is called with ``res=None`` and its
    result is returned unchanged.
    """
    no_residual = None
    return fused_add_rmsnorm_pad(
        x, weight, epsilon, no_residual, x_pad_to_multiple
    )
def fused_add_rmsnorm_pad_fake_tensors(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    res: torch.Tensor,
    x_pad_to_multiple: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of ``fused_add_rmsnorm_pad_``.

    Returns two uninitialised tensors:

    * an ``(M, N_out)`` output with ``x``'s dtype/device, where ``N_out`` is
      ``N`` rounded up to a multiple of ``x_pad_to_multiple``. A non-positive
      value (including the default ``0``) means no padding; previously the
      default raised ZeroDivisionError.
    * an unpadded ``(M, N)`` residual output with ``res``'s dtype/device.
    """
    M, N = x.shape
    if x_pad_to_multiple > 0:
        N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
    else:
        N_out = N
    out = torch.empty((M, N_out), dtype=x.dtype, device=x.device)
    res_out = torch.empty((M, N), dtype=res.dtype, device=res.device)
    return out, res_out
@torch_compile_guard(gen_fake=fused_add_rmsnorm_pad_fake_tensors)
def fused_add_rmsnorm_pad_(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    res: torch.Tensor,
    x_pad_to_multiple: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Residual-add + RMSNorm with output padding (aiter Triton kernel).

    Forwards every argument positionally to ``fused_add_rmsnorm_pad`` and
    returns its result as-is.
    """
    kernel_args = (x, weight, epsilon, res, x_pad_to_multiple)
    return fused_add_rmsnorm_pad(*kernel_args)
def mxfp4_rms_quant_fuse_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    shuffle: bool = False,
    res1: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of ``mxfp4_rms_quant_fuse``.

    Allocates uninitialised outputs on ``x``'s device:

    * a ``float4_e2m1fn_x2`` tensor of shape ``(M, N // 2)``;
    * ``float8_e8m0fnu`` scales with one column per 32-element block of ``N``.
      With ``shuffle`` the scale tensor is padded to a multiple of 256 rows
      and 8 columns;
    * ``torch.empty_like(res1)`` when ``res1`` is given, otherwise ``None``.
    """

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    block_size = 32  # MXFP4 quantization block size along N
    rows, cols = x.shape
    quantized = torch.empty(
        (rows, cols // 2), dtype=torch.float4_e2m1fn_x2, device=x.device
    )

    scale_rows = rows
    scale_cols = ceil_div(cols, block_size)
    if shuffle:
        # Shuffled scale layout needs padded dimensions.
        scale_rows = ceil_div(rows, 256) * 256
        scale_cols = ceil_div(scale_cols, 8) * 8
    scales = torch.empty(
        (scale_rows, scale_cols),
        dtype=torch.float8_e8m0fnu,
        device=x.device,
    )

    residual_out = torch.empty_like(res1) if res1 is not None else None
    return (quantized, scales, residual_out)
# Passing mutates_args=[] explicitly is important: it avoids functionized_v2
# op generation for this custom op.
@torch_compile_guard(gen_fake=mxfp4_rms_quant_fuse_fake, mutates_args=[])
def mxfp4_rms_quant_fuse(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    shuffle: bool = False,
    res1: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """RMSNorm (optionally with residual add) fused with MXFP4 quantization.

    Thin wrapper over aiter's Triton ``fused_rms_mxfp4_quant``. It keeps only
    the quantized tensor, its scales and the residual output from the
    kernel's result, and returns ``(x_quant, x_scale, residual_out)``.
    """
    from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant

    kernel_result = fused_rms_mxfp4_quant(x, weight, eps, shuffle=shuffle, res1=res1)
    (x_quant, x_scale), _, _, residual_out = kernel_result
    return x_quant, x_scale, residual_out
class RMSNorm(nn.Module):
    """RMS normalization layer backed by aiter kernels, with optional fusions.

    ``forward`` selects one implementation, checking these conditions in order:

    1. ``x_pad_to_multiple > 0``: Triton ``fused_add_rmsnorm_pad`` path. The
       output's last dimension is padded up to a multiple of
       ``x_pad_to_multiple``. This is incompatible with ``fused_allreduce``.
    2. ``fused_allreduce`` with a tensor-parallel world size > 1: fused
       all-reduce + RMSNorm through
       ``tensor_model_parallel_fused_allreduce_rmsnorm``. Requires ``residual``.
    3. ``fused_quant`` with an ``x_scale`` supplied: RMSNorm fused with static
       per-tensor FP8 quantization. The output is ``(x_quant, x_scale)``.
    4. ``fused_quant`` with no ``x_scale`` and a global quant type of
       ``QuantType.per_1x32``: RMSNorm fused with MXFP4 quantization (shuffled
       scales). The output is ``(x_quant, x_scale)``.
    5. Otherwise: plain aiter RMSNorm, fused with the residual add when a
       residual is given.

    When ``residual`` is passed, the residual tensor produced by the chosen
    kernel is returned as an extra trailing element.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        x_pad_to_multiple: int = 0,
        fused_allreduce: bool = False,
        fused_quant: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the layer.

        Args:
            dim: size of the normalized (last) dimension; ``weight`` has this
                length and starts as ones.
            eps: epsilon added inside the RMS computation.
            x_pad_to_multiple: if > 0, pad the output's last dim to a multiple
                of this value (Triton pad path).
            fused_allreduce: fuse the tensor-parallel all-reduce with the norm
                (only takes effect when the TP world size is > 1).
            fused_quant: fuse FP8 / MXFP4 quantization into the norm.
            quant_config: its ``global_quant_config`` supplies the quant type;
                a default ``LayerQuantConfig`` is used when None.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.x_pad_to_multiple = x_pad_to_multiple
        self.fused_allreduce = fused_allreduce
        self.use_fused_quant = fused_quant
        # Read once at construction; decides whether the fused all-reduce
        # path is taken in forward().
        self.tp_size = get_tensor_model_parallel_world_size()
        layer_quant_config = (
            LayerQuantConfig()
            if quant_config is None
            else quant_config.global_quant_config
        )
        quant_type = layer_quant_config["quant_type"]
        params_dtype = layer_quant_config["quant_dtype"]
        self.quant_type = quant_type
        # NOTE(review): stored from "quant_dtype" but not used in this class.
        self.params_dtype = params_dtype

    @mark_trace(prefix="rmsnorm", torch_compile=True)
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
        x_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x`` (see the class docstring for path selection).

        Args:
            x: input to normalize.
            residual: optional residual to add before normalizing; when given,
                the kernel's residual output is also returned.
            x_scale: static FP8 scale; enables path 3 when ``fused_quant`` is set.

        Returns:
            ``out``, ``(out, residual)``, ``(x_quant, x_scale)`` or
            ``((x_quant, x_scale), residual)``, depending on the path taken.
        """
        # Path 1: padded output via the Triton kernel.
        if self.x_pad_to_multiple > 0:
            assert (
                not self.fused_allreduce
            ), "fused_allreduce_rmsnorm is not supported with rms_norm padding!"
            if residual is None:
                x = fused_rmsnorm_pad_(x, self.weight, self.eps, self.x_pad_to_multiple)
                return x
            else:
                x, residual = fused_add_rmsnorm_pad_(
                    x, self.weight, self.eps, residual, self.x_pad_to_multiple
                )
                return x, residual
        # Path 2: all-reduce fused with the norm (only meaningful with TP > 1).
        if self.fused_allreduce and self.tp_size > 1:
            assert (
                residual is not None
            ), "fused_allreduce_rmsnorm requires residual input!"
            x, residual = tensor_model_parallel_fused_allreduce_rmsnorm(
                x,
                residual,
                self.weight,
                self.eps,
            )
            return x, residual
        else:
            # Path 3: a caller-supplied scale selects static FP8 quantization.
            if x_scale is not None and self.use_fused_quant:
                from aiter.ops.triton.fused_fp8_quant import (
                    fused_rms_fp8_per_tensor_static_quant,
                )
                import aiter as rocm_aiter

                rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
                # static FP8 quantization
                # NOTE(review): self.eps is passed twice; the second presumably
                # belongs to the unused second input (passed as None) — confirm
                # against aiter's signature.
                if residual is None:
                    x, _, _, _ = fused_rms_fp8_per_tensor_static_quant(
                        x,
                        self.weight,
                        self.eps,
                        x_scale,
                        None,
                        None,
                        self.eps,
                        dtype_quant=rocm_aiter_fp8_dtype,
                        res1=None,
                    )
                    return (x, x_scale)
                else:
                    x, _, _, residual = fused_rms_fp8_per_tensor_static_quant(
                        x,
                        self.weight,
                        self.eps,
                        x_scale,
                        None,
                        None,
                        self.eps,
                        dtype_quant=rocm_aiter_fp8_dtype,
                        res1=residual,
                    )
                    return (x, x_scale), residual
            # Path 4: MXFP4 (per_1x32) quantization; the kernel computes the scales.
            elif self.use_fused_quant and (
                x_scale is None and self.quant_type.value == QuantType.per_1x32.value
            ):
                if residual is None:
                    x, x_scale, _ = mxfp4_rms_quant_fuse(
                        x, self.weight, self.eps, shuffle=True
                    )
                    return x, x_scale
                else:
                    x, x_scale, residual = mxfp4_rms_quant_fuse(
                        x, self.weight, self.eps, shuffle=True, res1=residual
                    )
                    return (x, x_scale), residual
            # Path 5: plain aiter RMSNorm, with or without the residual add.
            else:
                if residual is None:
                    # rmsnorm2d_fwd_ flattens x to (-1, dim) and restores its shape.
                    x = rmsnorm2d_fwd_(x, self.weight, self.eps, self.dim)
                    return x
                else:
                    # Fused residual-add + RMSNorm; returns (out, residual_out).
                    x, residual = rmsnorm2d_fwd_with_add_(
                        x, self.weight, residual, self.eps, self.dim
                    )
                    return x, residual
class RMSNormGated(nn.Module):
"""RMS Normalization with optional gating.
This is a native PyTorch implementation that supports:
- Standard RMS normalization
- Group RMS normalization
- Optional gating with SiLU activation
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
group_size: int | None = None,
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize RMSNormGated.
Args:
hidden_size: Size of the hidden dimension
eps: Epsilon for numerical stability
group_size: If not None, do GroupNorm with each group
having group_size elements.
group_size=None is equivalent to group_size=hidden_size
(i.e. there's only 1 group).
norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
If False and z is provided: out = norm(x * silu(z))
dtype: Data type for parameters
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size))
self.register_parameter("bias", None)
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward_native(
self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
"""
Native PyTorch implementation of RMS normalization with gating.
Args:
x: Input tensor
z: Optional gating tensor
Returns:
Normalized (and optionally gated) tensor
If z is not None:
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
"""
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * silu(z)
# RMS Normalization
if self.group_size is None:
# Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps)
out = x_normed * self.weight
else:
# Group RMS norm
from einops import rearrange
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps)
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * silu(z)
return out
def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
if torch.compiler.is_compiling():
return self.forward_native(x, z)
return self.forward_native(x, z)
# from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
# return rmsnorm_fn(
# x,
# self.weight,
# self.bias,
# z=z,
# eps=self.eps,
# group_size=self.group_size,
# norm_before_gate=self.norm_before_gate,
# )
def forward(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
return self.forward_cuda(x, z)
class GemmaRMSNorm(nn.Module):
"""RMS normalization for Gemma.
Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
@staticmethod
def forward_static(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = (
x.float() + residual.float()
if orig_dtype == torch.float16
else x + residual
)
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
def forward_cuda(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile(self.forward_static) # type: ignore
self._is_compiled = True
return self.forward_native(x, residual)
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(x, residual)
@torch_compile_guard()
def layernorm2d_fwd_(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, dim: int
) -> torch.Tensor:
    """Run aiter's ``layernorm2d_fwd`` on an input of arbitrary rank.

    ``x`` is flattened to ``(-1, dim)`` for the kernel, and the result is viewed
    back to the caller's original shape.
    """
    input_shape = x.shape
    normed_2d = layernorm2d_fwd(x.reshape(-1, dim), weight, bias, eps)
    return normed_2d.view(input_shape)
@torch_compile_guard()
def layernorm2d_fwd_with_add_(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    dim: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + LayerNorm via aiter's ``layernorm2d_fwd_with_add``.

    Both ``x`` and ``residual`` are flattened to ``(-1, dim)`` before the kernel
    call, so the kernel sees matching 2-D inputs. Fresh output buffers are
    allocated for the normalized result and the residual output.

    Returns:
        ``(out, residual_out)``, each viewed back to ``x``'s original shape.
        ``residual_out`` is presumably ``x + residual`` (the kernel's residual
        output) — confirm against aiter.
    """
    ori_shape = x.shape
    x = x.reshape(-1, dim)
    # Flatten the residual exactly like x. Previously it was handed to the
    # kernel at its original rank while x was 2-D.
    residual = residual.reshape(-1, dim)
    out = torch.empty_like(x)
    residual_out = torch.empty_like(x)
    layernorm2d_fwd_with_add(out, x, residual, residual_out, weight, bias, eps)
    return out.view(ori_shape), residual_out.view(ori_shape)
class LayerNorm(nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return layernorm2d_fwd_(x, self.weight, self.bias, self.eps, self.dim)
else:
return layernorm2d_fwd_with_add_(
x, self.weight, residual, self.bias, self.eps, self.dim
)