You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@inline minus_log_cutoff(::T) where T <:AbstractFloat=-log(small_x_cutoff_deriv_telu(zero(T)))
816
-
817
-
@inline function _deriv_telu_taylor_expansion(x::T) where {T <:Union{Float16, Float32, Float64}}
818
-
convert(T, deriv_telu_taylor_expansion[1]) + x * convert(T, deriv_telu_taylor_expansion[2])
808
+
809
+
@inline function _deriv_telu_taylor_expansion(x::Float64)
810
+
deriv_telu_taylor_expansion[1] + x * deriv_telu_taylor_expansion[2]
811
+
end
812
+
813
+
@inline function _deriv_telu_taylor_expansion(x::Float32)
814
+
convert(Float32, deriv_telu_taylor_expansion[1]) + x * convert(Float32, deriv_telu_taylor_expansion[2])
815
+
end
816
+
817
+
@inline function _deriv_telu_taylor_expansion(x::Float16)
818
+
convert(Float16, deriv_telu_taylor_expansion[1]) + x * convert(Float16, deriv_telu_taylor_expansion[2])
819
819
end
820
820
821
821
@inline function _deriv_telu_taylor_expansion(x::T) where {T <:AbstractFloat}
822
822
tanh(one(T)) + x *8*exp(one(T))^2/ (one(T)+exp(one(T))^2)^2
823
823
end
824
824
825
-
function deriv_telu_fast(x, Ω)
826
-
ifelse(abs(x) < small_x_cutoff_deriv_telu(x), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827
-
ifelse(x >= minus_log_cutoff(x), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy
825
+
function deriv_telu_fast(x::Float64, Ω)
826
+
ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827
+
ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
828
+
end
829
+
830
+
function deriv_telu_fast(x::Float32, Ω)
831
+
ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
832
+
ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
833
+
end
834
+
835
+
function deriv_telu_fast(x::Float16, Ω)
836
+
ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
837
+
ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
838
+
end
839
+
840
+
function deriv_telu_fast(x::T, Ω) where T <:AbstractFloat
841
+
ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
842
+
ifelse(x >=-log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
0 commit comments