Skip to content

Commit 31f801d

Browse files
authored
fix (#75605)
1 parent a02d1aa commit 31f801d

2 files changed

Lines changed: 16 additions & 14 deletions

File tree

paddle/phi/kernels/gpu/gelu_grad_kernel.cu

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -33,15 +33,14 @@ struct GeluWithApproximateGradFunctor {
3333
MPType dout = static_cast<MPType>(arg_dout);
3434
MPType one = static_cast<MPType>(1);
3535
MPType half = static_cast<MPType>(0.5);
36-
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
37-
MPType kBeta =
38-
kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
36+
MPType kAlpha = M_SQRT2 * M_2_SQRTPI * static_cast<MPType>(0.5);
37+
MPType kBeta = static_cast<MPType>(GELU_CONSTANT);
38+
auto x_seq = x * x;
3939
auto cube_x = x * x * x;
40-
auto tanh_out =
41-
tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
42-
auto ans =
43-
half * (one + tanh_out +
44-
(one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
40+
auto tanh_out = tanh(kAlpha * ((kBeta * cube_x) + x));
41+
auto ans = half * (one + tanh_out) +
42+
half * x * (one - tanh_out * tanh_out) *
43+
(kAlpha * (one + static_cast<MPType>(3) * kBeta * x_seq));
4544
return static_cast<T>(ans * dout);
4645
}
4746
};
@@ -52,8 +51,9 @@ struct GeluWithoutApproximateGradFunctor {
5251
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
5352
MPType x = static_cast<MPType>(arg_x);
5453
MPType dout = static_cast<MPType>(arg_dout);
55-
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
56-
const MPType cdf = normcdf(x);
54+
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * MPType(0.5);
55+
constexpr MPType kAlpha = M_SQRT1_2;
56+
const MPType cdf = MPType(0.5) * (MPType(1) + std::erf(x * kAlpha));
5757
const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
5858
return static_cast<T>(dout * (cdf + x * pdf));
5959
}

paddle/phi/kernels/gpu/gelu_kernel.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ struct GeluWithApproximateFunctor {
3737
MPType x = static_cast<MPType>(arg_x);
3838
MPType one = static_cast<MPType>(1);
3939
MPType half = static_cast<MPType>(0.5);
40-
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
40+
MPType kAlpha = M_SQRT2 * M_2_SQRTPI * MPType(0.5);
4141
auto tanh_out =
42-
tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
43-
MPType out = x * half * (one + tanh_out);
42+
tanh(kAlpha * (x + static_cast<MPType>(GELU_CONSTANT) * (x * x * x)));
43+
MPType out = half * x * (one + tanh_out);
4444
return static_cast<T>(out);
4545
}
4646
};
@@ -51,7 +51,9 @@ struct GeluWithoutApproximateFunctor {
5151
inline HOSTDEVICE T operator()(T arg_x) {
5252
// actual gelu with approximation = false
5353
MPType x = static_cast<MPType>(arg_x);
54-
return static_cast<T>(x * normcdf(x));
54+
// return static_cast<T>(x * normcdf(x));
55+
constexpr MPType kAlpha = M_SQRT1_2;
56+
return static_cast<T>(x * MPType(0.5) * (MPType(1) + std::erf(x * kAlpha)));
5557
}
5658
};
5759

0 commit comments

Comments
 (0)