@@ -33,15 +33,14 @@ struct GeluWithApproximateGradFunctor {
3333 MPType dout = static_cast <MPType>(arg_dout);
3434 MPType one = static_cast <MPType>(1 );
3535 MPType half = static_cast <MPType>(0.5 );
36- MPType kAlpha = static_cast <MPType>(M_2_SQRTPI * M_SQRT1_2 );
37- MPType kBeta =
38- kAlpha * static_cast <MPType>(GELU_CONSTANT) * static_cast <MPType>( 3 ) ;
36+ MPType kAlpha = M_SQRT2 * M_2_SQRTPI * static_cast <MPType>(0.5 );
37+ MPType kBeta = static_cast <MPType>(GELU_CONSTANT);
38+ auto x_seq = x * x ;
3939 auto cube_x = x * x * x;
40- auto tanh_out =
41- tanh (kAlpha * ((static_cast <MPType>(GELU_CONSTANT) * cube_x) + x));
42- auto ans =
43- half * (one + tanh_out +
44- (one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
40+ auto tanh_out = tanh (kAlpha * ((kBeta * cube_x) + x));
41+ auto ans = half * (one + tanh_out) +
42+ half * x * (one - tanh_out * tanh_out) *
43+ (kAlpha * (one + static_cast <MPType>(3 ) * kBeta * x_seq));
4544 return static_cast <T>(ans * dout);
4645 }
4746};
@@ -52,8 +51,9 @@ struct GeluWithoutApproximateGradFunctor {
5251 inline HOSTDEVICE T operator ()(T arg_x, T arg_dout) {
5352 MPType x = static_cast <MPType>(arg_x);
5453 MPType dout = static_cast <MPType>(arg_dout);
55- constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast <MPType>(0.5 );
56- const MPType cdf = normcdf (x);
54+ constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * MPType (0.5 );
55+ constexpr MPType kAlpha = M_SQRT1_2;
56+ const MPType cdf = MPType (0.5 ) * (MPType (1 ) + std::erf (x * kAlpha ));
5757 const MPType pdf = exp (static_cast <MPType>(-0.5 ) * x * x) * kBeta ;
5858 return static_cast <T>(dout * (cdf + x * pdf));
5959 }
0 commit comments