
Commit dacb42f

yufenglee, tiagoshibata, KeDengMS, and tianleiwu authored
Cherry pick 3 fixes to rel-1.2.0 (#3158)
* Publish release symbols (#3152)

  * Publish release symbols
  * Publish symbols if IsReleaseBuild

* Disable delayload for cuda dlls (#3147)

  This change fixes #3129. When running onnxruntime as a dll on Windows, CUDA performs internal cleanup when the process exits; after that, any call into CUDA crashes. Delayload caused a thread_local destructor to run after the CUDA cleanup, hence the crash.

* Update Gelu Fusion to support new graph pattern from PyTorch 1.4 (#3148)

  * Update GeluFusion to support the pattern from PyTorch 1.4
  * Fix a bug where the check of an edge between mul2 and root was missing
  * Update script to fuse gelu from PyTorch 1.4
  * Add test for python optimizer

Co-authored-by: Tiago Koji Castro Shibata <[email protected]>
Co-authored-by: KeDengMS <[email protected]>
Co-authored-by: Tianlei Wu <[email protected]>
1 parent b71554c commit dacb42f

File tree

18 files changed: +372 −62 lines changed


cmake/CMakeLists.txt

+7-6
@@ -724,12 +724,13 @@ if (onnxruntime_USE_CUDA)
   if (WIN32)
     link_directories(${onnxruntime_CUDNN_HOME}/lib/x64)
-    file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
-    set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
-    foreach(cuda_dll_path ${cuda_dll_paths})
-      get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
-      set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
-    endforeach(cuda_dll_path)
+    # delayload causes crash on exit, so disable for now
+    #file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
+    #set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
+    #foreach(cuda_dll_path ${cuda_dll_paths})
+    #  get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
+    #  set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
+    #endforeach(cuda_dll_path)
 
   else()
     link_directories(${onnxruntime_CUDNN_HOME}/lib64)

onnxruntime/core/optimizer/gelu_fusion.cc

+78-34
@@ -24,7 +24,25 @@ static bool IsSupportedDataType(const Node& node) {
   }
   return true;
 }
-
+/*
+This function fuses subgraph like the following into one Gelu node.
+Subgraph pattern 1:
+              +-------Mul(0.5)---------------------+
+              |                                     |
+              |                                     v
+[root] --> Div -----> Erf --> Add --> Mul ==>
+          (B=1.4142...)       (1)
+
+Subgraph pattern 2:
+              +------------------------------------+
+              |                                    |
+              |                                    v
+[root] --> Div -----> Erf --> Add --> Mul -->Mul ==>
+          (B=1.4142...)       (1)    (0.5)
+
+After Fusion:
+[root]--> Gelu ==>
+*/
 Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@@ -68,13 +86,9 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
       continue;
     }
 
-    // Check the other input node(e.g. not of type Erf) is 1.0f.
-    const Node& add_first_input_node = *(add_node.InputNodesBegin());
-    int add_const_input_index = 0;
-    if (add_first_input_node.OpType().compare("Erf") == 0) {
-      add_const_input_index = 1;
-    }
-    const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index];
+    // Check the other input node (e.g. not the Erf) is 1.0f.
+    bool is_erf_first_input = (add_node.InputDefs()[0]->Name() == erf_node.MutableOutputDefs()[0]->Name());
+    const auto& add_const_input_arg = add_node.InputDefs()[is_erf_first_input ? 1 : 0];
     if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *add_const_input_arg, 1.0f, true)) {
       continue;
     }
@@ -87,35 +101,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
       continue;
     }
 
-    const Node* p_mul2_node = nullptr;
-    for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
-      if ((*iter).OpType().compare("Mul") == 0) {
-        // find the other input node of Mul
-        p_mul2_node = &(*iter);
-        break;
-      }
-    }
-    if (p_mul2_node == nullptr) {
-      continue;
-    }
-
-    Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
-        mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
-        mul2_node.GetOutputEdgesCount() != 1 ||
-        !IsSupportedDataType(mul2_node)) {
-      continue;
-    }
-
-    // Check the other input node(e.g. not of type Add) is 0.5f.
-    int mul_const_input_index = 0;
-    if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
-      mul_const_input_index = 1;
-    }
-
-    const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
-    if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
-      continue;
+    bool is_pattern_1 = true;
+    const Node* p_mul2_node = graph_utils::FirstParentByType(mul_node, "Mul");
+    if (p_mul2_node != nullptr) {
+      // Match subgraph pattern 1
+      Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
+      if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
+          mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
+          mul2_node.GetOutputEdgesCount() != 1 ||
+          !IsSupportedDataType(mul2_node)) {
+        continue;
+      }
+
+      // One input of mul2_node shall be the subgraph input
+      auto root_index = optimizer_utils::IndexOfNodeInput(*p_mul2_node, *div.InputDefs()[0]);
+      if (root_index < 0)
+        continue;
+
+      // Check the other input node is 0.5f.
+      int mul_const_input_index = (root_index == 0 ? 1 : 0);
+
+      const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
+      if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
+        continue;
+      }
+    } else {
+      is_pattern_1 = false;
+
+      // Match subgraph pattern 2
+      if (mul_node.GetOutputEdgesCount() != 1) {
+        continue;
+      }
+
+      // Another input of Mul node shall be the subgraph input.
+      auto root_index = optimizer_utils::IndexOfNodeInput(mul_node, *div.InputDefs()[0]);
+      if (root_index < 0)
+        continue;
+
+      Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
+      if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
+          mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
+          !IsSupportedDataType(mul_node)) {
+        continue;
+      }
+
+      int mul_const_input_index = 0;
+      if (mul2_node.InputDefs()[0]->Name() == mul_node.MutableOutputDefs()[0]->Name()) {
+        mul_const_input_index = 1;
+      }
+      const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
+      if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
+        continue;
+      }
+
+      p_mul2_node = &mul2_node;
     }
 
     const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
@@ -131,7 +170,12 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
     // move input edges to div (first in list) across to the gelu_node.
     // move output definitions and output edges from mul_node (last in list) to gelu_node.
     // remove all the other nodes.
-    graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
+    Node& mul2 = *graph.GetNode(p_mul2_node->Index());
+    if (is_pattern_1) {
+      graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2, mul_node}, gelu_node);
+    } else {
+      graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul_node, mul2}, gelu_node);
+    }
 
     modified = true;
   }

onnxruntime/python/tools/bert/BertOnnxModel.py

+35-13
@@ -234,15 +234,24 @@ def fuse_gelu(self, gelu_op_name):
 
     """
      Fuse Gelu with Erf into one node:
-                  +-------Mul(B=0.5)-------------------+
+     Pattern 1:
+                  +-------Mul(0.5)---------------------+
                   |                                     |
                   |                                     v
      [root] --> Div -----> Erf --> Add --> Mul -->
-               (B=1.4142...)       (B=1)
+               (B=1.4142...)       (1)
+
+     Pattern 2:
+                  +------------------------------------+
+                  |                                    |
+                  |                                    v
+     [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
+               (B=1.4142...)       (1)    (0.5)
 
      Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
     """
     def fuse_gelu_with_elf(self, gelu_op_name):
+        logger.debug(f"start fuse_gelu_with_elf({gelu_op_name})")
         input_name_to_nodes = self.input_name_to_nodes()
         output_name_to_node = self.output_name_to_node()
 
@@ -276,25 +285,38 @@ def fuse_gelu_with_elf(self, gelu_op_name):
             if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
                 continue
 
-            root_node = self.get_parent(div, 0, output_name_to_node)
-            if root_node is None:
-                continue
+            subgraph_input = div.input[0]
 
-            mul_half = self.match_parent(mul_after_erf, 'Mul', None, output_name_to_node)
-            if mul_half is None:
-                continue
+            another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
+            if subgraph_input == mul_after_erf.input[another]:  # pattern 2
+                children = input_name_to_nodes[mul_after_erf.output[0]]
+                if len(children) != 1 or children[0].op_type != 'Mul':
+                    continue
+                mul_half = children[0]
+                if not self.has_constant_input(mul_half, 0.5):
+                    continue
+                subgraph_output = mul_half.output[0]
+            else:  # pattern 1
+                mul_half = self.match_parent(mul_after_erf, 'Mul', another, output_name_to_node)
+                if mul_half is None:
+                    continue
 
-            if not self.has_constant_input(mul_half, 0.5):
-                continue
+                if not self.has_constant_input(mul_half, 0.5):
+                    continue
+
+                if subgraph_input not in mul_half.input:
+                    continue
+
+                subgraph_output = mul_after_erf.output[0]
 
             subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
-            if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul_after_erf.output[0]], input_name_to_nodes, output_name_to_node):
+            if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
                 continue
 
             nodes_to_remove.extend(subgraph_nodes)
             gelu_node = onnx.helper.make_node(gelu_op_name,
-                inputs=[root_node.output[0]],
-                outputs=[mul_after_erf.output[0]])
+                inputs=[subgraph_input],
+                outputs=[subgraph_output])
             gelu_node.domain = "com.microsoft"
             nodes_to_add.append(gelu_node)
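To see what the optimizer is matching, here is a hedged sketch (not from the commit; all tensor, node, and file names are illustrative) that builds the pattern-2 subgraph with onnx.helper, similar in spirit to the gelu_format2_*.onnx models the new tests load:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Scalar initializers for the constants appearing in the pattern.
sqrt2 = numpy_helper.from_array(np.array(1.4142135, dtype=np.float32), 'sqrt2')
one = numpy_helper.from_array(np.array(1.0, dtype=np.float32), 'one')
half = numpy_helper.from_array(np.array(0.5, dtype=np.float32), 'half')

nodes = [
    helper.make_node('Div', ['root', 'sqrt2'], ['div_out']),
    helper.make_node('Erf', ['div_out'], ['erf_out']),
    helper.make_node('Add', ['erf_out', 'one'], ['add_out']),
    helper.make_node('Mul', ['root', 'add_out'], ['mul1_out']),  # root re-enters here (pattern 2)
    helper.make_node('Mul', ['mul1_out', 'half'], ['output']),   # trailing Mul(0.5)
]

graph = helper.make_graph(
    nodes, 'gelu_pattern2',
    [helper.make_tensor_value_info('root', TensorProto.FLOAT, [1, 8])],
    [helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 8])],
    initializer=[sqrt2, one, half])

model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 12)])
onnx.checker.check_model(model)
onnx.save(model, 'gelu_pattern2.onnx')
```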

onnxruntime/python/tools/bert/test_bert_optimization.py

+8
@@ -22,6 +22,7 @@
 BERT_TEST_MODELS = {
     "bert_pytorch_0": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_0.onnx',
     "bert_pytorch_1": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_1.onnx',
+    "bert_squad_pytorch1.4_opset10_fp32": 'test_data\\bert_squad_pytorch1.4_opset10_fp32\\BertForQuestionAnswering.onnx',
     "bert_keras_0": 'test_data\\bert_mrpc_tensorflow2.1_opset10\\TFBertForSequenceClassification_1.onnx'
 }
 
@@ -155,6 +156,13 @@ def test_pytorch_model_0_gpu(self):
         }
         self.verify_node_count(bert_model, expected_node_count)
 
+    def test_pytorch_model_2_cpu(self):
+        input = BERT_TEST_MODELS['bert_squad_pytorch1.4_opset10_fp32']
+        bert_model = optimize_model(input, 'bert', gpu_only=False,
+                                    num_heads=2, hidden_size=8, sequence_length=10,
+                                    input_int32=False, float16=False)
+        self.assertTrue(bert_model.is_fully_optimized())
+
     def test_keras_model_1_cpu(self):
         input = BERT_TEST_MODELS['bert_keras_0']

Two new binary test-data files added (raw tensor protobufs; contents are not human-readable and are omitted).

onnxruntime/test/optimizer/graph_transform_test.cc

+57
@@ -1130,6 +1130,63 @@ TEST(GraphTransformationTests, GeluFusionTest) {
   ASSERT_TRUE(op_to_count["Gelu"] == 1);
 }
 
+TEST(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
+  auto model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
+  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
+  ASSERT_TRUE(ret.IsOK());
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Div"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Erf"] == 0);
+  ASSERT_TRUE(op_to_count["Mul"] == 0);
+  ASSERT_TRUE(op_to_count["Gelu"] == 1);
+}
+
+TEST(GraphTransformationTests, GeluFusionTestFormat2) {
+  auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
+  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
+  ASSERT_TRUE(ret.IsOK());
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Div"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Erf"] == 0);
+  ASSERT_TRUE(op_to_count["Mul"] == 0);
+  ASSERT_TRUE(op_to_count["Gelu"] == 1);
+}
+
+TEST(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
+  auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1_use_graph_input.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
+  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
+  ASSERT_TRUE(ret.IsOK());
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Div"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Erf"] == 0);
+  ASSERT_TRUE(op_to_count["Mul"] == 0);
+  ASSERT_TRUE(op_to_count["Gelu"] == 1);
+}
+
 TEST(GraphTransformationTests, BiasGeluTest) {
   auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion.onnx";
   std::shared_ptr<Model> p_model;
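The C++ tests above only count operators after the transform. A quick numeric sanity check from Python is to run a pattern model through onnxruntime (which applies Gelu fusion among its default graph optimizations) and compare against the analytic Gelu; this sketch reuses the hypothetical gelu_pattern2.onnx built in the earlier snippet, so the file and input names are assumptions:

```python
import numpy as np
import onnxruntime as ort
from scipy.special import erf

sess = ort.InferenceSession('gelu_pattern2.onnx', providers=['CPUExecutionProvider'])
x = np.random.randn(1, 8).astype(np.float32)
(y,) = sess.run(None, {'root': x})

# The session output should match Gelu regardless of whether fusion fired.
expected = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
assert np.allclose(y, expected, atol=1e-5)
```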

onnxruntime/test/shared_lib/test_model_loading.cc

+9
@@ -2,6 +2,9 @@
 // Licensed under the MIT License.
 
 #include "core/session/onnxruntime_cxx_api.h"
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_factory.h"
+#endif
 #include <fstream>
 #include "test_fixture.h"
 #include "file_util.h"
@@ -25,6 +28,12 @@ TEST(CApiTest, model_from_array) {
 
   Ort::SessionOptions so;
   Ort::Session session(*ort_env.get(), buffer.data(), buffer.size(), so);
+
+#ifdef USE_CUDA
+  // test with CUDA provider when using onnxruntime as dll
+  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0));
+  Ort::Session session_cuda(*ort_env.get(), buffer.data(), buffer.size(), so);
+#endif
 }
 }  // namespace test
 }  // namespace onnxruntime
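For the delayload fix, a rough Python analogue of this new regression scenario (the model path is an assumption, and the providers keyword is how current onnxruntime-gpu wheels request CUDA) looks like this; the point is simply that a clean process exit after creating a CUDA session must not crash:

```python
import onnxruntime as ort

with open('testdata/mul_1.onnx', 'rb') as f:  # assumed test model path
    model_bytes = f.read()

so = ort.SessionOptions()
sess = ort.InferenceSession(model_bytes, so, providers=['CUDAExecutionProvider'])
# Before this fix (#3129), process exit after this point could crash on Windows:
# delay-loaded CUDA dlls ran a thread_local destructor after CUDA's own
# at-exit cleanup had already torn down driver state.
```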
Binary file not shown. (3 files)
