You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Upper cutoff for `x`, derived from the small-|x| cutoff: the negated natural
# logarithm of `small_x_cutoff_deriv_telu` evaluated for the float type `T`.
# Only the *type* of the argument is used; its value is ignored.
@inline function minus_log_cutoff(::T) where {T<:AbstractFloat}
    small_cutoff = small_x_cutoff_deriv_telu(zero(T))
    return -log(small_cutoff)
end
816
+
817
+
# Linear-order Taylor expansion around x = 0, specialised for the hardware float
# types. The constant and slope terms are read from the first and second entries
# of `deriv_telu_taylor_expansion` and converted to `T`, so the result has the
# same type as `x`.
# NOTE(review): `deriv_telu_taylor_expansion` is defined outside this view; it is
# presumably a precomputed pair of coefficients -- confirm against its definition.
@inline function _deriv_telu_taylor_expansion(x::T) where {T<:Union{Float16,Float32,Float64}}
    constant_term = convert(T, deriv_telu_taylor_expansion[1])
    slope = convert(T, deriv_telu_taylor_expansion[2])
    return constant_term + x * slope
end
820
+
821
+
# Linear-order Taylor expansion around x = 0 for a generic `AbstractFloat` type,
# with the coefficients evaluated directly in the precision of `T`:
#   constant term: tanh(1)
#   slope:         8 * exp(1)^2 / (1 + exp(1)^2)^2
# The result has the same type as `x`.
@inline function _deriv_telu_taylor_expansion(x::T) where {T<:AbstractFloat}
    unity = one(T)
    e_squared = exp(unity)^2  # exp(1)^2, computed once and reused below
    # Keep the left-to-right order ((x * 8) * e^2) / (1 + e^2)^2 of the original
    # expression so the floating-point rounding is unchanged.
    return tanh(unity) + x * 8 * e_squared / (unity + e_squared)^2
end
824
+
825
+
# Piecewise evaluation of the derivative, selecting one of three candidates:
#   * |x| below `small_x_cutoff_deriv_telu(x)` -> linear-order Taylor expansion
#   * x at or above `minus_log_cutoff(x)`      -> exactly `one(x)`
#   * otherwise                                -> `_deriv_telu_fast(x, Ω)`
# `ifelse` is an ordinary function, so every candidate is evaluated before the
# selection; the function stays branch-free.
# NOTE(review): `Ω` is presumably the already-computed forward value for `x`
# (the helper divides it by `x`) -- confirm against the caller.
function deriv_telu_fast(x, Ω)
    # if x is close to 0, return linear-order Taylor expansion
    near_zero = abs(x) < small_x_cutoff_deriv_telu(x)
    taylor_value = _deriv_telu_taylor_expansion(x)
    # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all
    # types (Float16, 32, 64) in terms of both preventing overflow and
    # maintaining accuracy
    saturated = x >= minus_log_cutoff(x)
    saturated_value = one(x)
    general_value = _deriv_telu_fast(x, Ω)
    return ifelse(near_zero, taylor_value, ifelse(saturated, saturated_value, general_value))
end
829
+
830
+
@inline function _deriv_telu_fast(x, Ω)
831
+
tanh_exp_x = Ω / x
804
832
sech_exp_x_squared =1- tanh_exp_x^2
805
-
ifelse(x >=4, one(x), tanh_exp_x + x * exp(x) * sech_exp_x_squared) # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy
0 commit comments