Skip to content

Commit a25a1d1

Browse files
committed
Add @fastmath for telu_fast and reuse telu(x) to compute telu'(x)
1 parent 245de57 commit a25a1d1

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

src/activations.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -791,20 +791,51 @@ telu(x) = x * tanh(exp(x))
791791
This is a faster but less accurate version of `telu`. This function is associated with a hard-coded derivative,
792792
`deriv_telu_fast`, which is faster but less accurate than `deriv_telu`.
793793
"""
794-
telu_fast(x) = x * tanh_fast(exp(x))
794+
function telu_fast(x)
    # `@fastmath` wraps the whole product, so relaxed floating-point semantics apply to
    # the multiplication and to `exp`, on top of the approximate `tanh_fast`.
    return @fastmath x * tanh_fast(exp(x))
end
795795

796796
# Adapted from the Discourse post: <https://discourse.julialang.org/t/how-to-compute-tanhexp-telu-function-accurately/124464/7>
797797
function deriv_telu(x)
    # telu'(x) = tanh(e^x) + x * e^x * sech(e^x)^2. The second term is written as
    # 4x / (exp(e^x - x/2) + exp(-e^x - x/2))^2, which is the same quantity with the
    # factor e^x folded into the exponentials of the denominator.
    ex = exp(x)
    half_x = x / 2
    denom = exp(ex - half_x) + exp(-ex - half_x)
    return tanh(ex) + 4x / denom^2
end
801801

802-
function deriv_telu_fast(x)
803-
tanh_exp_x = tanh(exp(x))
802+
# 0th and 1st order Taylor expansion for telu'(x) around x=0
803+
# Stored as (value, slope) in Float64: telu'(0) = tanh(1), telu''(0) = 8e² / (1 + e²)².
const deriv_telu_taylor_expansion = let e_squared = exp(1)^2
    (tanh(1.0), 8 * e_squared / (1 + e_squared)^2)
end
804+
805+
# Various cutoffs for numerical evaluations of telu'(x)
806+
# Square root of the machine epsilon, precomputed for Float16, Float32 and Float64.
const sqrt_eps_f16 = sqrt(eps(Float16))
const sqrt_eps_f32 = sqrt(eps(Float32))
const sqrt_eps_f64 = sqrt(eps(Float64))
807+
# Positive cutoffs, -log(sqrt(eps(T))) per float type, e.g. to prevent `exp` from overflow
const minus_log_cutoff_f16 = -log(sqrt_eps_f16)
const minus_log_cutoff_f32 = -log(sqrt_eps_f32)
const minus_log_cutoff_f64 = -log(sqrt_eps_f64)
808+
# |x| threshold that `deriv_telu_fast` compares against before switching to the Taylor
# expansion of telu'(x). The argument is used only for dispatch on its type.
@inline function small_x_cutoff_deriv_telu(::Float16)
    return sqrt_eps_f16
end

@inline function small_x_cutoff_deriv_telu(::Float32)
    return sqrt_eps_f32
end

@inline function small_x_cutoff_deriv_telu(::Float64)
    return sqrt_eps_f64
end

# Fallback for any other float type: compute sqrt(eps(T)) on the fly.
@inline function small_x_cutoff_deriv_telu(::T) where {T <: AbstractFloat}
    return sqrt(eps(T))
end
812+
# Upper cutoff on x used by `deriv_telu_fast`. The argument is used only for dispatch
# on its type; the three common float types return precomputed constants.
@inline function minus_log_cutoff(::Float16)
    return minus_log_cutoff_f16
end

@inline function minus_log_cutoff(::Float32)
    return minus_log_cutoff_f32
end

@inline function minus_log_cutoff(::Float64)
    return minus_log_cutoff_f64
end

# Fallback for any other float type: derive the cutoff from the small-x threshold.
@inline function minus_log_cutoff(::T) where {T <: AbstractFloat}
    threshold = small_x_cutoff_deriv_telu(zero(T))
    return -log(threshold)
end
816+
817+
# Linear Taylor expansion of telu'(x) around x = 0 for the hardware float types, using
# the precomputed Float64 coefficients converted to the type of `x`.
@inline function _deriv_telu_taylor_expansion(
    x::T
) where {T <: Union{Float16, Float32, Float64}}
    value_at_zero, slope_at_zero = deriv_telu_taylor_expansion
    return convert(T, value_at_zero) + x * convert(T, slope_at_zero)
end
820+
821+
# Linear Taylor expansion of telu'(x) around x = 0 for arbitrary float types; the
# coefficients tanh(1) and 8e² / (1 + e²)² are evaluated in the precision of `T`.
@inline function _deriv_telu_taylor_expansion(x::T) where {T <: AbstractFloat}
    unit = one(T)
    e_squared = exp(unit)^2
    return tanh(unit) + x * 8 * e_squared / (unit + e_squared)^2
end
824+
825+
function deriv_telu_fast(x, Ω)
    # `Ω` is the already-computed forward value for `x` (callers pass `telu_fast(x)`),
    # so the generic branch can reuse it instead of evaluating `tanh` again.
    near_zero = abs(x) < small_x_cutoff_deriv_telu(x)
    saturated = x >= minus_log_cutoff(x)

    # `ifelse` is an ordinary function: every candidate below is evaluated, and the
    # conditions only select which value is returned.
    taylor_value = _deriv_telu_taylor_expansion(x)  # linear-order expansion near x = 0
    generic_value = _deriv_telu_fast(x, Ω)          # divides by x, hence the guard above

    # Large x is cut off to one(x), e.g. to prevent `exp(x)` from overflowing.
    return ifelse(near_zero, taylor_value, ifelse(saturated, one(x), generic_value))
end
829+
830+
@inline function _deriv_telu_fast(x, Ω)
    # Recover tanh(exp(x)) from the forward value Ω = x * tanh(exp(x)); the caller is
    # responsible for keeping x away from zero.
    t = Ω / x
    # sech^2 = 1 - tanh^2
    sech_squared = 1 - t^2
    return t + x * exp(x) * sech_squared
end
807835

836+
# for testing accuracy
837+
function _deriv_telu_fast(x)
    # Single-argument convenience form: compute the forward value first, then hand it
    # to the two-argument derivative.
    Ω = telu_fast(x)
    return deriv_telu_fast(x, Ω)
end
838+
808839
# Define broadcasts for activation functions on arrays
809840
for f in ACTIVATIONS
810841
@eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
@@ -948,7 +979,7 @@ UNARY_ACTS = [ # f, dfdx
948979
## Fast variants are the same!
949980
(:tanh_fast, :(conj(1 - Ω^2))),
950981
(:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
951-
(:telu_fast, :(deriv_telu_fast(x)))
982+
(:telu_fast, :(deriv_telu_fast(x, Ω)))
952983
]
953984

954985
for (f, dfdx) in UNARY_ACTS

test/activations.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ end
214214

215215
## Faster variants
216216

217-
using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, deriv_telu_fast
217+
using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, _deriv_telu_fast
218218

219219
function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
220220
target = T(xtrue)
@@ -269,8 +269,8 @@ end
269269
mean_eps(telu, telu, x64) # 0.1146
270270
worst_eps(telu, telu, x64) # 2
271271

272-
@test mean_eps(telu_fast, telu, x64) < 0.13 # 0.12204
273-
@test worst_eps(telu_fast, telu, x64) <= 3 # 2
272+
@test mean_eps(telu_fast, telu, x64) < 0.14 # 0.1338
273+
@test worst_eps(telu_fast, telu, x64) <= 4 # 3
274274

275275
@test telu_fast.(xbig[1:end-1]) ≈ telu.(xbig[1:end-1])
276276
@test telu_fast.(-xbig[1:end-1]) ≈ telu.(-xbig[1:end-1])
@@ -279,11 +279,11 @@ end
279279
mean_eps(deriv_telu, deriv_telu, x64) # 0.09304
280280
worst_eps(deriv_telu, deriv_telu, x64) # 2
281281

282-
@test mean_eps(deriv_telu_fast, deriv_telu, x64) < 2.1 # 2.05944
283-
@test worst_eps(deriv_telu_fast, deriv_telu, x64) <= 29 # 28
282+
@test mean_eps(_deriv_telu_fast, deriv_telu, x64) < 4.1 # 4.06396
283+
@test worst_eps(_deriv_telu_fast, deriv_telu, x64) <= 125 # 120
284284

285-
@test deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1])
286-
@test deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1])
285+
@test _deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1])
286+
@test _deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1])
287287
end
288288
end
289289

@@ -335,11 +335,11 @@ end
335335
mean_eps(deriv_telu, deriv_telu, x32) # 0.07228
336336
worst_eps(deriv_telu, deriv_telu, x32) # 1
337337

338-
@test mean_eps(deriv_telu_fast, deriv_telu, x32) < 0.69 # 0.68772
339-
@test worst_eps(deriv_telu_fast, deriv_telu, x32) <= 11 # 10
338+
@test mean_eps(_deriv_telu_fast, deriv_telu, x32) < 2.4 # 2.31772
339+
@test worst_eps(_deriv_telu_fast, deriv_telu, x32) <= 70 # 66
340340

341-
@test deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1])
342-
@test deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1])
341+
@test _deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1])
342+
@test _deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1])
343343
end
344344
end
345345

0 commit comments

Comments
 (0)