Skip to content

Commit 72ff5df

Browse files
committed
Rewrite code to be friendly with dual numbers.
1 parent a25a1d1 commit 72ff5df

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed

src/activations.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -805,26 +805,41 @@ const deriv_telu_taylor_expansion = (tanh(1.0), 8*exp(1)^2 / (1+exp(1)^2)^2)
805805
# Various cutoffs for numerical evaluations of telu'(x)
806806
const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64))
807807
const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = -log(sqrt_eps_f16), -log(sqrt_eps_f32), -log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow
808-
@inline small_x_cutoff_deriv_telu(::Float16) = sqrt_eps_f16
809-
@inline small_x_cutoff_deriv_telu(::Float32) = sqrt_eps_f32
810-
@inline small_x_cutoff_deriv_telu(::Float64) = sqrt_eps_f64
811-
@inline small_x_cutoff_deriv_telu(::T) where T <: AbstractFloat = sqrt(eps(T))
812-
@inline minus_log_cutoff(::Float16) = minus_log_cutoff_f16
813-
@inline minus_log_cutoff(::Float32) = minus_log_cutoff_f32
814-
@inline minus_log_cutoff(::Float64) = minus_log_cutoff_f64
815-
@inline minus_log_cutoff(::T) where T <: AbstractFloat = -log(small_x_cutoff_deriv_telu(zero(T)))
816-
817-
@inline function _deriv_telu_taylor_expansion(x::T) where {T <: Union{Float16, Float32, Float64}}
818-
convert(T, deriv_telu_taylor_expansion[1]) + x * convert(T, deriv_telu_taylor_expansion[2])
808+
809+
@inline function _deriv_telu_taylor_expansion(x::Float64)
810+
deriv_telu_taylor_expansion[1] + x * deriv_telu_taylor_expansion[2]
811+
end
812+
813+
@inline function _deriv_telu_taylor_expansion(x::Float32)
814+
convert(Float32, deriv_telu_taylor_expansion[1]) + x * convert(Float32, deriv_telu_taylor_expansion[2])
815+
end
816+
817+
@inline function _deriv_telu_taylor_expansion(x::Float16)
818+
convert(Float16, deriv_telu_taylor_expansion[1]) + x * convert(Float16, deriv_telu_taylor_expansion[2])
819819
end
820820

821821
@inline function _deriv_telu_taylor_expansion(x::T) where {T <: AbstractFloat}
822822
tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2
823823
end
824824

825-
function deriv_telu_fast(x, Ω)
826-
ifelse(abs(x) < small_x_cutoff_deriv_telu(x), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827-
ifelse(x >= minus_log_cutoff(x), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy
825+
function deriv_telu_fast(x::Float64, Ω)
826+
ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
827+
ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
828+
end
829+
830+
function deriv_telu_fast(x::Float32, Ω)
831+
ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
832+
ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
833+
end
834+
835+
function deriv_telu_fast(x::Float16, Ω)
836+
ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
837+
ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
838+
end
839+
840+
function deriv_telu_fast(x::T, Ω) where T <: AbstractFloat
841+
ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
842+
ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
828843
end
829844

830845
@inline function _deriv_telu_fast(x, Ω)

0 commit comments

Comments
 (0)