
Commit c9d0ac5

[quantizer] support {LayerNorm, RMSNorm} (#390)
* [quantizer] support LayerNorm/RMSNorm
* [quantizer] add doc support
* [quantizer] fix test
1 parent afee089 · commit c9d0ac5

File tree

5 files changed: +144 −3 lines changed

docs/quantization_support.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite).
 | `pow` | / |
 | `prelu` | / |
 | `reciprocal` | / |
+| `rsqrt` | / |
 | `silu` | / |
 | `sin` | / |
 | `softmax` | / |
@@ -71,8 +72,10 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite).
 | `sum` | For TFLiteConverter, set `rewrite_quantizable=True` |
 | `torch.nn.GLU` | No action needed |
 | `torch.nn.Hardsigmoid` | No action needed |
+| `torch.nn.LayerNorm` | No action needed |
 | `torch.nn.LogSoftmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
 | `torch.nn.PReLU` | No action needed |
+| `torch.nn.RMSNorm` | No action needed |
 | `torch.nn.SiLU` | No action needed |
 | `torch.nn.Softmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
 | `truediv` | For TFLiteConverter, set `rewrite_quantizable=True` |
```
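"No action needed" means the quantizer now rewrites these modules automatically. As a rough usage sketch (the `TinyModel` class and `work_dir` value below are made up; the `QATQuantizer` call follows the pattern used in TinyNeuralNetwork's examples, so treat it as an assumption rather than the canonical invocation):

```python
import torch
import torch.nn as nn

from tinynn.graph.quantization.quantizer import QATQuantizer


class TinyModel(nn.Module):  # hypothetical model, just to put a LayerNorm in the graph
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 32)
        self.norm = nn.LayerNorm(32)  # rewritten to QLayerNorm during quantization

    def forward(self, x):
        return self.norm(self.fc(x))


model = TinyModel()
dummy_input = torch.randn(1, 32)

# No extra config flag is needed for LayerNorm/RMSNorm.
quantizer = QATQuantizer(model, dummy_input, work_dir='out')
qat_model = quantizer.quantize()
```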

tests/qat_module_test.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
 import unittest
+import random
 
 import torch
 import torch.nn as nn
-from tinynn.graph.quantization.modules import QGLU, QPReLU, QSiLU
+from tinynn.graph.quantization.modules import QGLU, QPReLU, QSiLU, QLayerNorm, QRMSNorm
 
 
 class QATModuleTester(unittest.TestCase):
@@ -89,6 +90,55 @@ def test_glu(self):
 
                 self.assertTrue(False)
 
+    def test_layer_norm(self):
+        for i in range(100):
+            normalized_shape = tuple(random.randint(10, 100) for _ in range(random.randint(1, 3)))
+            non_normalized_shape = tuple(random.randint(1, 100) for _ in range(random.randint(1, 2)))
+
+            orig = nn.LayerNorm(normalized_shape)
+            quant = QLayerNorm(orig)
+
+            inp = torch.randn((*non_normalized_shape, *normalized_shape))
+
+            orig_outp = orig(inp)
+            quant_outp = quant(inp)
+
+            if not torch.allclose(orig_outp, quant_outp, atol=1e-6):
+                print(normalized_shape, non_normalized_shape)
+                print('original:')
+                print(orig_outp)
+                print('quantized:')
+                print(quant_outp)
+
+                print('diff (min, max):', torch.min(quant_outp - orig_outp), torch.max(quant_outp - orig_outp))
+
+                self.assertTrue(False)
+
+    @unittest.skipIf(not hasattr(torch.nn, 'RMSNorm'), 'RMSNorm is not supported')
+    def test_rms_norm(self):
+        for i in range(100):
+            normalized_shape = tuple(random.randint(10, 100) for _ in range(random.randint(1, 3)))
+            non_normalized_shape = tuple(random.randint(1, 100) for _ in range(random.randint(1, 2)))
+
+            orig = nn.RMSNorm(normalized_shape)
+            quant = QRMSNorm(orig)
+
+            inp = torch.randn((*non_normalized_shape, *normalized_shape))
+
+            orig_outp = orig(inp)
+            quant_outp = quant(inp)
+
+            if not torch.allclose(orig_outp, quant_outp, atol=1e-6):
+                print(normalized_shape, non_normalized_shape)
+                print('original:')
+                print(orig_outp)
+                print('quantized:')
+                print(quant_outp)
+
+                print('diff (min, max):', torch.min(quant_outp - orig_outp), torch.max(quant_outp - orig_outp))
+
+                self.assertTrue(False)
+
 
 if __name__ == '__main__':
     unittest.main()
```
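
`test_rms_norm` is skipped on PyTorch builds that predate `nn.RMSNorm`. On such builds, a hand-rolled float reference of what the test compares against could look like the sketch below (not part of the commit; `rms_norm_reference` is a made-up helper that just mirrors the RMSNorm formula):

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, mean_dims, eps=None) -> torch.Tensor:
    # RMSNorm(x) = x * rsqrt(mean(x ** 2) + eps) * weight
    ms = (x * x).mean(mean_dims, keepdim=True)
    if eps is not None:
        ms = ms + eps
    return x * torch.rsqrt(ms) * weight


# Example: normalize over the last dimension, like nn.RMSNorm((16,)).
x = torch.randn(4, 16)
weight = torch.ones(16)
out = rms_norm_reference(x, weight, mean_dims=(-1,), eps=1e-6)
```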

tests/quantizer_test.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -380,6 +380,7 @@ def forward(self, x):
                 return self.norm(y[0])
 
         model = Model()
+        torch.nn.init.uniform_(model.norm.bias, -0.1, 0.1)
         inputs = torch.randn(1, 3, 224, 224)
 
         check_quantize_rewrite(model, inputs)
```

tinynn/graph/quantization/modules.py

Lines changed: 83 additions & 0 deletions
```diff
@@ -89,3 +89,86 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x2 = self.act_r(x1)
         x3 = self.q(self.dq(x2))
         return self.f_mul.mul_scalar(x3, 1 / 6)
+
+
+class QLayerNorm(nn.Module):
+    def __init__(self, layernorm: nn.LayerNorm) -> None:
+        super().__init__()
+        self.mean_dims = tuple(range(-len(layernorm.normalized_shape), 0))
+        self.weight = torch.nn.Parameter(layernorm.weight.data.detach().clone())
+        self.bias = torch.nn.Parameter(layernorm.bias.data.detach().clone())
+        self.eps = layernorm.eps
+
+        self.q_rsqrt = torch_q.QuantStub()
+        self.q_weight = torch_q.QuantStub()
+        self.q_bias = torch_q.QuantStub()
+        self.dq_rsqrt = torch_q.DeQuantStub()
+
+        self.f_neg = nnq.FloatFunctional()
+        self.f_add_0 = nnq.FloatFunctional()
+        self.f_mul_0 = nnq.FloatFunctional()
+        self.f_add_1 = nnq.FloatFunctional()
+        self.f_mul_1 = nnq.FloatFunctional()
+        self.f_mul_2 = nnq.FloatFunctional()
+        self.f_add_2 = nnq.FloatFunctional()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # LayerNorm(input) = (input - mean(input)) * rsqrt(mean((input - mean(input)) ** 2) + eps) * alpha + beta
+        # Currently, we completely split LayerNorm and collect the quantization parameters of each
+        # intermediate activation independently, which may lead to a decrease in quantization accuracy.
+
+        mean = input.mean(self.mean_dims, keepdim=True)
+        diff = self.f_add_0.add(input, self.f_neg.mul_scalar(mean, -1.0).expand_as(input))
+        squared_difference = self.f_mul_0.mul(diff, diff)
+        var = squared_difference.mean(self.mean_dims, keepdim=True)
+        var_eps = self.f_add_1.add_scalar(var, self.eps)
+
+        fdq_var_eps = self.dq_rsqrt(var_eps)
+        std_inverse = torch.rsqrt(fdq_var_eps)
+        q_std_inverse = self.q_rsqrt(std_inverse)
+
+        weight_fq = self.q_weight(self.weight)
+        bias_fq = self.q_bias(self.bias)
+        norm = self.f_mul_1.mul(diff, q_std_inverse)
+        weight_fq_expand = weight_fq.expand_as(norm)
+        norm_alpha = self.f_mul_2.mul(norm, weight_fq_expand)
+        bias_fq_expand = bias_fq.expand_as(norm_alpha)
+        return self.f_add_2.add(norm_alpha, bias_fq_expand)
+
+
+class QRMSNorm(nn.Module):
+    def __init__(self, rmsnorm: 'nn.RMSNorm') -> None:
+        super().__init__()
+        self.mean_dims = tuple(range(-len(rmsnorm.normalized_shape), 0))
+        self.weight = torch.nn.Parameter(rmsnorm.weight.data.detach().clone())
+        self.eps = rmsnorm.eps
+
+        self.q_rsqrt = torch_q.QuantStub()
+        self.q_weight = torch_q.QuantStub()
+        self.dq_rsqrt = torch_q.DeQuantStub()
+
+        self.f_add_0 = nnq.FloatFunctional()
+        self.f_mul_0 = nnq.FloatFunctional()
+        self.f_mul_1 = nnq.FloatFunctional()
+        self.f_mul_2 = nnq.FloatFunctional()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # RMSNorm(input) = input * rsqrt(mean(input ** 2) + eps) * alpha
+
+        squared_input = self.f_mul_0.mul(input, input)
+        if self.eps is None:
+            rms_pre = squared_input.mean(self.mean_dims, keepdim=True)
+        else:
+            rms_pre = self.f_add_0.add_scalar(
+                squared_input.mean(self.mean_dims, keepdim=True),
+                self.eps,
+            )
+
+        fdq_rms_pre = self.dq_rsqrt(rms_pre)
+        rms_inverse = torch.rsqrt(fdq_rms_pre)
+        q_rms = self.q_rsqrt(rms_inverse)
+
+        weight_fq = self.q_weight(self.weight)
+        norm = self.f_mul_1.mul(input, q_rms)
+        weight_fq_expand = weight_fq.expand_as(norm)
+        return self.f_mul_2.mul(norm, weight_fq_expand)
```
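
To see that the op-by-op split in `QLayerNorm.forward` reproduces `nn.LayerNorm` in float, here is a plain-PyTorch mirror of the same decomposition (a sketch for illustration only; `layer_norm_decomposed` is a made-up helper, not part of the module):

```python
import torch
import torch.nn as nn


def layer_norm_decomposed(x: torch.Tensor, ln: nn.LayerNorm) -> torch.Tensor:
    # Mirrors QLayerNorm: mean -> diff -> squared diff -> var -> rsqrt(var + eps) -> scale/shift.
    dims = tuple(range(-len(ln.normalized_shape), 0))
    mean = x.mean(dims, keepdim=True)
    diff = x + mean * -1.0                         # f_neg + f_add_0
    var = (diff * diff).mean(dims, keepdim=True)   # f_mul_0 + mean
    inv_std = torch.rsqrt(var + ln.eps)            # f_add_1; rsqrt runs on dequantized values
    return diff * inv_std * ln.weight + ln.bias    # f_mul_1, f_mul_2, f_add_2


ln = nn.LayerNorm((16,))
x = torch.randn(4, 16)
assert torch.allclose(layer_norm_decomposed(x, ln), ln(x), atol=1e-6)
```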

tinynn/graph/quantization/quantizer.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -23,7 +23,7 @@
     FakeQuantizeBFloat16,
     FakeQuantizeTFLite,
 )
-from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU
+from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU, QLayerNorm, QRMSNorm
 from tinynn.graph.quantization.observer import (
     HistogramObserverKL,
     MinMaxObserver,
@@ -174,6 +174,7 @@
     'pow': None,
     'truediv': None,
     'sqrt': None,
+    'rsqrt': None,
     'atan2': None,
     'atan': None,
     'sin': None,
@@ -263,8 +264,11 @@
     Q_MODULES_MAPPING.update({nn.SiLU: QSiLU})
     FUNCTIONAL_MODULE_MAPPING.update({'silu': nn.SiLU})
 
+if hasattr(nn, 'LayerNorm'):
+    Q_MODULES_MAPPING.update({nn.LayerNorm: QLayerNorm})
+
 if hasattr(nn, 'RMSNorm'):
-    UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.update({nn.RMSNorm: None})
+    Q_MODULES_MAPPING.update({nn.RMSNorm: QRMSNorm})
 
 # Processed QAT fuse rules
 processed_qat_rules = {}
```
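
Conceptually, the new `Q_MODULES_MAPPING` entries tell the quantizer to swap every `nn.LayerNorm` (and, where available, `nn.RMSNorm`) for its decomposed counterpart while rewriting the model. The standalone swap below is only an illustration of that mapping (an assumption about the mechanism; the real quantizer performs the rewrite on its traced graph, not via `named_children`):

```python
import torch.nn as nn

from tinynn.graph.quantization.modules import QLayerNorm, QRMSNorm

SWAP = {nn.LayerNorm: QLayerNorm}
if hasattr(nn, 'RMSNorm'):
    SWAP[nn.RMSNorm] = QRMSNorm


def swap_norms(module: nn.Module) -> None:
    # Recursively replace LayerNorm/RMSNorm children with their Q* counterparts.
    for name, child in module.named_children():
        q_cls = SWAP.get(type(child))
        if q_cls is not None:
            setattr(module, name, q_cls(child))
        else:
            swap_norms(child)
```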
