Skip to content

Commit 8701e7e

Browse files
committed
Use eps(T) for float type T to control cutoffs for telu derivative evaluations.
1 parent 72ff5df commit 8701e7e

File tree

1 file changed

+2
-36
lines changed

1 file changed

+2
-36
lines changed

src/activations.jl

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -799,45 +799,11 @@ function deriv_telu(x)
799799
tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
800800
end
801801

802-
# 0th and 1st order Taylor expansion for telu'(x) around x=0
803-
const deriv_telu_taylor_expansion = (tanh(1.0), 8*exp(1)^2 / (1+exp(1)^2)^2)
804-
805-
# Various cutoffs for numerical evaluations of telu'(x)
806-
const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64))
807-
const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = -log(sqrt_eps_f16), -log(sqrt_eps_f32), -log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow
808-
809-
@inline function _deriv_telu_taylor_expansion(x::Float64)
810-
deriv_telu_taylor_expansion[1] + x * deriv_telu_taylor_expansion[2]
811-
end
812-
813-
@inline function _deriv_telu_taylor_expansion(x::Float32)
814-
convert(Float32, deriv_telu_taylor_expansion[1]) + x * convert(Float32, deriv_telu_taylor_expansion[2])
815-
end
816-
817-
@inline function _deriv_telu_taylor_expansion(x::Float16)
818-
convert(Float16, deriv_telu_taylor_expansion[1]) + x * convert(Float16, deriv_telu_taylor_expansion[2])
819-
end
820-
821-
@inline function _deriv_telu_taylor_expansion(x::T) where {T <: AbstractFloat}
802+
@inline function _deriv_telu_taylor_expansion(x::T) where T
822803
tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2
823804
end
824805

825-
function deriv_telu_fast(x::Float64, Ω)
826-
ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827-
ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
828-
end
829-
830-
function deriv_telu_fast(x::Float32, Ω)
831-
ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
832-
ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
833-
end
834-
835-
function deriv_telu_fast(x::Float16, Ω)
836-
ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
837-
ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
838-
end
839-
840-
function deriv_telu_fast(x::T, Ω) where T <: AbstractFloat
806+
function deriv_telu_fast(x::T, Ω) where T
841807
ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
842808
ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
843809
end

0 commit comments

Comments
 (0)