@@ -799,45 +799,11 @@ function deriv_telu(x)
799799 tanh(exp_x) + 4 x / (exp(exp_x - x/ 2 ) + exp(- exp_x - x/ 2 ))^ 2
800800end
801801
802- # 0th and 1st order Taylor expansion for telu'(x) around x=0
803- const deriv_telu_taylor_expansion = (tanh(1.0 ), 8 * exp(1 )^ 2 / (1 + exp(1 )^ 2 )^ 2 )
804-
805- # Various cutoffs for numerical evaluations of telu'(x)
806- const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64))
807- const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = - log(sqrt_eps_f16), - log(sqrt_eps_f32), - log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow
808-
809- @inline function _deriv_telu_taylor_expansion(x:: Float64 )
810- deriv_telu_taylor_expansion[1 ] + x * deriv_telu_taylor_expansion[2 ]
811- end
812-
813- @inline function _deriv_telu_taylor_expansion(x:: Float32 )
814- convert(Float32, deriv_telu_taylor_expansion[1 ]) + x * convert(Float32, deriv_telu_taylor_expansion[2 ])
815- end
816-
817- @inline function _deriv_telu_taylor_expansion(x:: Float16 )
818- convert(Float16, deriv_telu_taylor_expansion[1 ]) + x * convert(Float16, deriv_telu_taylor_expansion[2 ])
819- end
820-
821- @inline function _deriv_telu_taylor_expansion(x:: T ) where {T <: AbstractFloat }
802+ @inline function _deriv_telu_taylor_expansion(x:: T ) where T
822803 tanh(one(T)) + x * 8 * exp(one(T))^ 2 / (one(T)+ exp(one(T))^ 2 )^ 2
823804end
824805
825- function deriv_telu_fast(x:: Float64 , Ω)
826- ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827- ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
828- end
829-
830- function deriv_telu_fast(x:: Float32 , Ω)
831- ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
832- ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
833- end
834-
835- function deriv_telu_fast(x:: Float16 , Ω)
836- ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
837- ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
838- end
839-
840- function deriv_telu_fast(x:: T , Ω) where T <: AbstractFloat
806+ function deriv_telu_fast(x:: T , Ω) where T
841807 ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
842808 ifelse(x >= - log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
843809end
0 commit comments