Skip to content

Commit 8a093ee

Browse files
authored
[onnx] fix fp16 gelu pattern (#503)
1 parent d88bb1e commit 8a093ee

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -573,6 +573,15 @@ bool isSplatFPCloseTo(Attribute attr, double value, double eps = 1e-5) {
573573
return false;
574574
if (!elementsAttr.isSplat())
575575
return false;
576+
577+
auto dtype = elementsAttr.getElementType();
578+
auto bitwidth = dtype.getIntOrFloatBitWidth();
579+
if (bitwidth <= 8) {
580+
eps = std::max(fabs(eps), 1e-2);
581+
} else if (bitwidth == 16) {
582+
eps = std::max(fabs(eps), 1e-3);
583+
}
584+
576585
double diff =
577586
elementsAttr.getSplatValue<FloatAttr>().getValueAsDouble() - value;
578587
return fabs(diff) < eps;

frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,23 @@ func.func @test_gelu(%37: tensor<1x3x5x5xf32>) -> tensor<1x3x5x5xf32> {
235235

236236
// -----
237237

238+
func.func @test_gelu_fp16(%97 : tensor<16x5638x3072xf16>) -> tensor<16x5638x3072xf16> {
239+
%19 = onnx.Constant dense<5.000000e-01> : tensor<f16>
240+
%20 = onnx.Constant dense<1.000000e+00> : tensor<f16>
241+
%21 = "onnx.Constant"() {value = dense<1.414060e+00> : tensor<f16>} : () -> tensor<f16>
242+
%98 = "onnx.Div"(%97, %21) {onnx_node_name = "Div"} : (tensor<16x5638x3072xf16>, tensor<f16>) -> tensor<16x5638x3072xf16>
243+
%99 = "onnx.Erf"(%98) {onnx_node_name = "Erf"} : (tensor<16x5638x3072xf16>) -> tensor<16x5638x3072xf16>
244+
%100 = "onnx.Add"(%99, %20) {onnx_node_name = "Add"} : (tensor<16x5638x3072xf16>, tensor<f16>) -> tensor<16x5638x3072xf16>
245+
%101 = "onnx.Mul"(%97, %100) {onnx_node_name = "Mul"} : (tensor<16x5638x3072xf16>, tensor<16x5638x3072xf16>) -> tensor<16x5638x3072xf16>
246+
%102 = "onnx.Mul"(%101, %19) {onnx_node_name = "Mul_1"} : (tensor<16x5638x3072xf16>, tensor<f16>) -> tensor<16x5638x3072xf16>
247+
return %102 : tensor<16x5638x3072xf16>
248+
}
249+
// CHECK-LABEL: @test_gelu_fp16
250+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x5638x3072xf16>) -> tensor<16x5638x3072xf16> {
251+
// CHECK-NEXT: [[VAR_0_:%.+]] = stablehlo.custom_call @byteir.gelu([[PARAM_0_]]) {byteir_attrs = {approximate = "erf"}} : (tensor<16x5638x3072xf16>) -> tensor<16x5638x3072xf16>
252+
253+
// -----
254+
238255
func.func @test_gelu_without_last_mul(%arg0: tensor<1x3x5x5xf32>, %arg1: tensor<1x3x5x5xf32>) -> tensor<1x3x5x5xf32> {
239256
%38 = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
240257
%39 = "onnx.Add"(%arg0, %38) : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x5x5xf32>

0 commit comments

Comments
 (0)