Skip to content

Commit 0aed232

Browse files
authored
RMS Normalization and Skip RMS Normalization fusion optimizations (microsoft#1974)
Implements RMS Normalization and Skip RMS Normalization fusion optimizations, enabling use of onnxruntime's custom fused operators for these patterns.
1 parent 86644e9 commit 0aed232

File tree

10 files changed

+661
-0
lines changed

10 files changed

+661
-0
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ exclude_patterns = [
5050
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
53+
'onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py', # onnxscript code
5354
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
5455
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5556
'onnxscript/tools/function_unittest_producer.py', # FIXME
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""
5+
A one-layer SmolLM model test case.
6+
This is an onnxscript version of the model.
7+
"""
8+
9+
import numpy
10+
from onnx.helper import make_tensor
11+
12+
import onnxscript.ir as ir
13+
from onnxscript import script
14+
from onnxscript.onnx_opset import opset18
15+
from onnxscript.onnx_types import FLOAT, INT64
16+
17+
18+
def make_model(
    input_layernorm_weight_0,
    post_attention_layernorm_weight0,
    norm_weight,
    head_weight,
    self_attn_q_proj_weight0,
    self_attn_k_proj_weight0,
    self_attn_v_proj_weight0,
    self_attn_o_proj_weight0,
    mlp_gate_proj_weight0,
    mlp_up_proj_weight0,
    mlp_down_proj_weight0,
):
    """Build a one-layer SmolLM-style decoder as an ONNX ModelProto.

    The graph is fixed-shape (batch=1, seq_len=10, hidden=2048, 32 heads of
    head-dim 64, vocab=49152) and spells out the *decomposed* forms of
    RMSNorm, rotary position embedding, scaled dot-product attention and a
    SwiGLU MLP, so that fusion rewrite rules can be exercised against it.

    Args:
        input_layernorm_weight_0: RMSNorm scale of the pre-attention norm.
        post_attention_layernorm_weight0: RMSNorm scale of the post-attention norm.
        norm_weight: RMSNorm scale of the final norm before the LM head.
        head_weight: LM-head weight; also used as the token embedding table
            (tied embeddings — see the Gather below).
        self_attn_q_proj_weight0: query projection weight.
        self_attn_k_proj_weight0: key projection weight.
        self_attn_v_proj_weight0: value projection weight.
        self_attn_o_proj_weight0: output projection weight.
        mlp_gate_proj_weight0: SwiGLU gate projection weight.
        mlp_up_proj_weight0: SwiGLU up projection weight.
        mlp_down_proj_weight0: SwiGLU down projection weight.

    Returns:
        The ONNX ModelProto produced by ``main_graph.to_model_proto()``.
    """

    @script()
    def main_graph(
        input0: INT64[1, 10], input1: FLOAT[1, 10], input2: INT64[1, 10]
    ) -> (FLOAT[1, 10, 49152], FLOAT[1, 32, 10, 64], FLOAT[1, 32, 10, 64]):
        # input0: token ids, input1: attention mask, input2: position ids.
        # Bake the weights into the graph as Constant nodes.
        model_layers_0_input_layernorm_weight = opset18.Constant(
            value=input_layernorm_weight_0
        )
        model_layers_0_post_attention_layernorm_weight = opset18.Constant(
            value=post_attention_layernorm_weight0
        )
        model_norm_weight = opset18.Constant(value=norm_weight)
        lm_head_weight = opset18.Constant(value=head_weight)
        model_layers_0_self_attn_q_proj_weight = opset18.Constant(
            value=self_attn_q_proj_weight0
        )
        model_layers_0_self_attn_k_proj_weight = opset18.Constant(
            value=self_attn_k_proj_weight0
        )
        model_layers_0_self_attn_v_proj_weight = opset18.Constant(
            value=self_attn_v_proj_weight0
        )
        model_layers_0_self_attn_o_proj_weight = opset18.Constant(
            value=self_attn_o_proj_weight0
        )
        model_layers_0_mlp_gate_proj_weight = opset18.Constant(value=mlp_gate_proj_weight0)
        model_layers_0_mlp_up_proj_weight = opset18.Constant(value=mlp_up_proj_weight0)
        model_layers_0_mlp_down_proj_weight = opset18.Constant(value=mlp_down_proj_weight0)

        # Token embedding: the LM-head weight doubles as the embedding table
        # (tied word embeddings).
        embedding = opset18.Gather(lm_head_weight, input0, axis=0)
        # Causal mask: strictly-upper-triangular matrix of -inf (float32 max
        # magnitude), then combined with the padding mask from input1.
        minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
        mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
        slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
        unsqueeze_2 = opset18.Unsqueeze(input1, 1)
        unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
        add = slice_5 + unsqueeze_3
        eq = add == 0.0
        slice_10 = slice_5
        # Positions where mask + causal bias == 0 get -inf (fully masked).
        masked_fill = opset18.Where(eq, -3.4028235e38, slice_10)
        # The following transpose pairs are identity round-trips left over
        # from the slice_scatter decomposition in the original export.
        val_179 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3])
        slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
        val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
        slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
        # Rotary embedding angles: position ids (cast to float) times the
        # per-head inverse frequencies below.
        unsqueeze_6 = opset18.Unsqueeze(input2, 1)
        _to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
        # RoPE inverse frequencies: 10000**(-i/32) for i = 0..31
        # (rope_theta=10000, head_dim=64).
        view_1 = opset18.Constant(
            value=make_tensor(
                "value",
                1,
                dims=[1, 32, 1],
                vals=[
                    1.0,
                    0.7498942017555237,
                    0.5623413324356079,
                    0.4216965138912201,
                    0.3162277638912201,
                    0.23713736236095428,
                    0.17782793939113617,
                    0.1333521455526352,
                    0.10000000149011612,
                    0.07498941570520401,
                    0.05623412877321243,
                    0.04216964915394783,
                    0.03162277862429619,
                    0.0237137358635664,
                    0.017782794311642647,
                    0.01333521492779255,
                    0.009999999776482582,
                    0.007498942315578461,
                    0.005623413249850273,
                    0.0042169648222625256,
                    0.003162277862429619,
                    0.0023713738191872835,
                    0.0017782794311642647,
                    0.0013335214462131262,
                    0.0010000000474974513,
                    0.0007498941849917173,
                    0.000562341301701963,
                    0.00042169648804701865,
                    0.0003162277862429619,
                    0.0002371373848291114,
                    0.00017782794020604342,
                    0.0001333521504420787,
                ],
            )
        )
        view_2 = opset18.Reshape(_to_copy_1, [1, 1, 10], allowzero=0)
        bmm = view_1 @ view_2
        view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0)
        transpose = opset18.Transpose(view_3, perm=[0, 2, 1])
        # Duplicate the angles so cos/sin cover the full head dimension.
        cat = opset18.Concat(transpose, transpose, axis=-1)
        cos = opset18.Cos(cat)
        sin = opset18.Sin(cat)
        # --- RMSNorm #1 (input layernorm), decomposed:
        # x / sqrt(mean(x^2) + eps) * weight
        pow_1 = embedding**2.0
        mean = opset18.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
        add_1 = mean + 1e-05
        val_244 = opset18.Sqrt(add_1)
        rsqrt = opset18.Reciprocal(val_244)
        mul_3 = embedding * rsqrt
        mul_4 = model_layers_0_input_layernorm_weight * mul_3
        # --- Q/K/V projections, reshaped to (1, 32 heads, 10, 64).
        t = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0])
        view_5 = mul_4 @ t
        t_1 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0])
        view_7 = mul_4 @ t_1
        t_2 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0])
        view_9 = mul_4 @ t_2
        view_10 = opset18.Reshape(view_5, [1, 10, 32, 64], allowzero=0)
        transpose_1 = opset18.Transpose(view_10, perm=[0, 2, 1, 3])
        view_11 = opset18.Reshape(view_7, [1, 10, 32, 64], allowzero=0)
        transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
        view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
        transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
        # --- Rotary embedding applied to Q (add_2) and K (add_3):
        # x*cos + rotate_half(x)*sin, where rotate_half swaps and negates
        # the two 32-wide halves of the head dimension.
        unsqueeze_7 = opset18.Unsqueeze(cos, 1)
        unsqueeze_8 = opset18.Unsqueeze(sin, 1)
        mul_5 = transpose_1 * unsqueeze_7
        val_267 = opset18.Constant(value_ints=[1])
        slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)
        val_277 = opset18.Constant(value_ints=[1])
        slice_20 = opset18.Slice(transpose_1, [32], [9223372036854775807], [3], val_277)
        neg = opset18.Neg(slice_20)
        cat_1 = opset18.Concat(neg, slice_19, axis=-1)
        mul_6 = cat_1 * unsqueeze_8
        add_2 = mul_5 + mul_6
        mul_7 = transpose_2 * unsqueeze_7
        val_287 = opset18.Constant(value_ints=[1])
        slice_21 = opset18.Slice(transpose_2, [0], [32], [3], val_287)
        val_297 = opset18.Constant(value_ints=[1])
        slice_22 = opset18.Slice(transpose_2, [32], [9223372036854775807], [3], val_297)
        neg_1 = opset18.Neg(slice_22)
        cat_2 = opset18.Concat(neg_1, slice_21, axis=-1)
        mul_8 = cat_2 * unsqueeze_8
        add_3 = mul_7 + mul_8
        # --- Attention: softmax(Q K^T / sqrt(64) + mask) V.
        # The 1/sqrt(64) scaling is split as 0.35355338 (= 8**-0.5 ** 0.5... i.e.
        # sqrt(1/8)) applied to each operand; 0.35355338**2 == 1/8 == 1/sqrt(64).
        val_346 = opset18.Reshape(add_3, [-1, 10, 64], allowzero=0)
        val_347 = opset18.Transpose(val_346, perm=[0, 2, 1])
        val_349 = opset18.Reshape(val_347, [1, 32, 64, 10], allowzero=0)
        val_351 = add_2 * [0.35355338]
        val_353 = val_349 * [0.35355338]
        val_354 = val_351 @ val_353
        val_355 = val_354 + slice_scatter_1
        val_356 = opset18.Softmax(val_355, axis=-1)
        getitem = val_356 @ transpose_3
        transpose_4 = opset18.Transpose(getitem, perm=[0, 2, 1, 3])
        view_13 = opset18.Reshape(transpose_4, [1, 10, -1], allowzero=0)
        t_3 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0])
        view_15 = view_13 @ t_3
        # Residual connection around attention.
        add_4 = embedding + view_15
        # --- RMSNorm #2 (post-attention layernorm), same decomposition.
        pow_2 = add_4**2.0
        mean_1 = opset18.ReduceMean(pow_2, [-1], keepdims=1, noop_with_empty_axes=0)
        add_5 = mean_1 + 1e-05
        val_379 = opset18.Sqrt(add_5)
        rsqrt_1 = opset18.Reciprocal(val_379)
        mul_9 = add_4 * rsqrt_1
        mul_10 = model_layers_0_post_attention_layernorm_weight * mul_9
        # --- SwiGLU MLP: down( silu(gate(x)) * up(x) ).
        t_4 = opset18.Transpose(model_layers_0_mlp_gate_proj_weight, perm=[1, 0])
        view_17 = mul_10 @ t_4
        val_383 = opset18.Sigmoid(view_17)
        silu = view_17 * val_383
        t_5 = opset18.Transpose(model_layers_0_mlp_up_proj_weight, perm=[1, 0])
        view_19 = mul_10 @ t_5
        mul_11 = silu * view_19
        t_6 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0])
        view_21 = mul_11 @ t_6
        # Residual connection around the MLP.
        add_6 = add_4 + view_21
        # --- Final RMSNorm and LM head.
        pow_3 = add_6**2.0
        mean_2 = opset18.ReduceMean(pow_3, [-1], keepdims=1, noop_with_empty_axes=0)
        add_7 = mean_2 + 1e-05
        val_391 = opset18.Sqrt(add_7)
        rsqrt_2 = opset18.Reciprocal(val_391)
        mul_12 = add_6 * rsqrt_2
        mul_13 = model_norm_weight * mul_12
        t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0])
        view_23 = mul_13 @ t_7
        _to_copy_12 = opset18.Identity(view_23)
        # Outputs: logits, present key (rotated K), present value.
        return _to_copy_12, add_3, transpose_3

    model = main_graph.to_model_proto()
    return model
207+
208+
209+
def make_model_with_random_weights():
    """Instantiate the one-layer SmolLM test model with random float32 weights.

    Weight shapes match the SmolLM-1.7B configuration (hidden=2048,
    intermediate=8192, vocab=49152). The weights are drawn in the same order
    as the positional parameters of :func:`make_model`.

    Returns:
        The ONNX ModelProto built by :func:`make_model`.
    """

    def _rand(*shape):
        # Uniform [0, 1) float32 weights; order of calls matters for RNG state.
        return numpy.random.rand(*shape).astype(numpy.float32)

    # Arguments are evaluated left-to-right, preserving the original draw order:
    # input_layernorm, post_attention_layernorm, final norm, lm head,
    # q/k/v/o projections, then gate/up/down MLP projections.
    return make_model(
        _rand(2048),
        _rand(2048),
        _rand(2048),
        _rand(49152, 2048),
        _rand(2048, 2048),
        _rand(2048, 2048),
        _rand(2048, 2048),
        _rand(2048, 2048),
        _rand(8192, 2048),
        _rand(8192, 2048),
        _rand(2048, 8192),
    )
235+
236+
237+
class _SmollmTestData:
238+
def get_onnx_model(self):
239+
if not hasattr(self, "_onnx_model"):
240+
model_proto = make_model_with_random_weights()
241+
model = ir.serde.deserialize_model(model_proto)
242+
self._onnx_model = model
243+
return self._onnx_model
244+
245+
def get_ort_inputs(self):
246+
if not hasattr(self, "_ort_inputs"):
247+
inputs = {
248+
"input0": numpy.random.randint(0, 49152, (1, 10)).astype(numpy.int64),
249+
"input1": numpy.ones((1, 10), dtype=numpy.float32),
250+
"input2": numpy.arange(10, dtype=numpy.int64).reshape(1, 10),
251+
}
252+
self._ort_inputs = inputs
253+
return self._ort_inputs
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import os
6+
import tempfile
7+
8+
import numpy as np
9+
import onnxruntime
10+
import torch
11+
import transformers
12+
from transformers import LlamaConfig
13+
14+
import onnxscript.ir as ir
15+
import onnxscript.ir._io as io
16+
import onnxscript.optimizer
17+
18+
# Create a LlamaConfig object with the desired parameters.
# This mirrors the published SmolLM-1.7B configuration, except that
# num_hidden_layers is reduced to 1 to keep the test model small.
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=1,  # single layer for the fusion test case
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
# Per-head dimension (2048 / 32 = 64).
# NOTE(review): name lacks the private underscore used by its siblings;
# kept as-is in case it is referenced elsewhere.
dim = _hidden_size // _num_attention_heads
_vocab_size = _config.vocab_size
53+
54+
55+
class _SmollmTestData:
    """Builds a one-layer SmolLM torch model, its ONNX export, and test inputs.

    The torch model, inputs, and ONNX export are created lazily and cached on
    the instance, so repeated accessors return the same objects.
    """

    def __init__(self):
        pass

    def get_torch_model(self):
        """Return a cached randomly-initialized LlamaForCausalLM in eval mode."""
        if not hasattr(self, "_torch_model"):
            model = transformers.LlamaForCausalLM(_config)
            model.eval()
            self._torch_model = model
        return self._torch_model

    def get_onnx_model(self) -> ir.Model:
        """Export the torch model to ONNX, optimize it, and return the IR model."""
        model = self.get_torch_model()
        inputs = self.get_inputs()
        # Name inputs input0, input1, ... skipping any that are None.
        # NOTE(review): a None in the *middle* of `inputs` would shift the
        # numbering relative to positions — currently all three are non-None.
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(
            model, inputs, input_names=input_names, dynamo=True, fallback=True
        )
        # ORT Transformer optimizations are applied after basic optimization.
        exported_model = exported.model  # type: ignore[union-attr]
        onnxscript.optimizer.optimize(exported_model)
        return exported_model

    def get_inputs(self):
        """Return cached (input_ids, attention_mask, position_ids) tensors."""
        if not hasattr(self, "_inputs"):
            input_ids = torch.randint(0, _vocab_size, (_batch_size, _seq_len)).to(torch.int64)
            # All-ones mask: no padding. Note torch.ones defaults to float32.
            attention_mask = torch.ones(input_ids.shape)
            position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0)
            self._inputs = (input_ids, attention_mask, position_ids)
        return self._inputs

    def get_torch_outputs(self):
        """Run the torch model and return (logits, key, value) as numpy arrays.

        Only the first layer's KV cache entry is returned, matching the
        ONNX graph's outputs.
        """
        output = self.get_torch_model()(*self.get_inputs())
        logits = output.logits
        past_key_value = output.past_key_values[0]
        key = past_key_value[0]
        value = past_key_value[1]
        return (logits.detach().numpy(), key.detach().numpy(), value.detach().numpy())

    def get_ort_inputs(self):
        """Return the inputs as a {name: numpy array} feed dict for onnxruntime."""
        inputs = self.get_inputs()
        return {
            f"input{i}": input.numpy() for i, input in enumerate(inputs) if input is not None
        }
99+
100+
101+
def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    """Run `model` under onnxruntime (CPU) and compare against expected outputs.

    The model is serialized to a temporary directory, executed with the given
    feed dict, and every output is checked for matching shape and numerically
    close values.

    Args:
        model_name: Label used for the temporary file name and error messages.
        model: IR model to serialize and run.
        inputs: Feed dict of {input_name: numpy array}.
        expected_outputs: Sequence of baseline numpy arrays, in output order.
        rtol: Relative tolerance for the value comparison.
        atol: Absolute tolerance for the value comparison.

    Raises:
        AssertionError: If the output count, a shape, or values mismatch.
    """
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        # zip() would silently truncate to the shorter sequence, hiding a
        # missing or extra model output — check the counts explicitly first.
        if len(ort_outputs) != len(expected_outputs):
            raise AssertionError(
                f"Model {model_name}: expected {len(expected_outputs)} outputs, "
                f"got {len(ort_outputs)}"
            )

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise

0 commit comments

Comments
 (0)