55
66ACTIVATIONS = [
77 :σ, :hardσ, :hardtanh, :relu,
8- :leakyrelu, :relu6, :rrelu, :elu, :gelu , :gelu_full , :swish, :hardswish, :selu,
8+ :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh , :gelu_erf , :swish, :hardswish, :selu,
99 :celu, :softplus, :softsign, :logσ, :logcosh,
1010 :mish, :tanhshrink, :softshrink, :trelu, :lisht,
1111 :tanh_fast, :sigmoid_fast,
@@ -301,14 +301,14 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))
301301deriv_elu(Ω, α= 1 ) = ifelse(Ω ≥ 0 , one(Ω), Ω + oftype(Ω, α))
302302
303303"""
304- gelu (x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
304+ gelu_tanh (x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
305305
306- Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) (see also [`gelu_full`](@ref)) .
306+ Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation .
307307
308308```julia-repl
309- julia> lineplot(gelu , -2, 2, height=7)
309+ julia> lineplot(gelu_tanh , -2, 2, height=7)
310310 ┌────────────────────────────────────────┐
311- 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu (x)
311+ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh (x)
312312 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│
313313 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│
314314 f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
@@ -319,11 +319,11 @@ julia> lineplot(gelu, -2, 2, height=7)
319319 ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
320320 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
321321
322- julia> lineplot(gelu , -5, 0, height=7);
322+ julia> lineplot(gelu_tanh , -5, 0, height=7);
323323
324324julia> lineplot!(ans, swish)
325325 ┌────────────────────────────────────────┐
326- 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu (x)
326+ 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh (x)
327327 │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
328328 │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│
329329 f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│
@@ -335,7 +335,7 @@ julia> lineplot!(ans, swish)
335335 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
336336```
337337"""
338- function gelu (x)
338+ function gelu_tanh (x)
339339 α = oftf(x, 0.044715 )
340340 # λ = oftf(x, gelu_λ)
341341 # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference
346346const gelu_λ = √ (2 / π)
347347const gelu_2λ = √ (8 / π)
348348
349- function deriv_gelu (x)
349+ function deriv_gelu_tanh (x)
350350 α = oftf(x, 0.044715 )
351351 α2 = oftf(x, 0.08943 )
352352 λλ = oftf(x, gelu_2λ)
@@ -358,18 +358,38 @@ function deriv_gelu(x)
358358end
359359
360360"""
361- gelu_full (x) = xΦ(x) = 0.5x * (1 + erf(x/√2))
361+ gelu_erf (x) = xΦ(x) = 0.5x * (1 + erf(x/√2))
362362
363363Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation.
364364"""
365- gelu_full (x) = x/ 2 * (1 + erf (x/ sqrt(oftf(x,2 ))))
365+ gelu_erf (x) = x/ 2 * (1 + _erf (x/ sqrt(oftf(x,2 ))))
366366
367- function deriv_gelu_full (x)
367+ function deriv_gelu_erf (x)
368368 SQRT2 = sqrt(oftf(x,2 ))
369- Φ = (1 + erf (x/ SQRT2))/ 2
369+ Φ = (1 + _erf (x/ SQRT2))/ 2
370370 Φ + x/ SQRT2* exp(- (x^ 2 )/ 2 )/ sqrt(oftf(x,π))
371371end
372372
373+ _erf(x:: Number ) = _erf(float(x))
374+ _erf(x:: Float64 ) = ccall((:erf, libopenlibm), Float64, (Float64,), x)
375+ _erf(x:: Float32 ) = ccall((:erff, libopenlibm), Float32, (Float32,), x)
376+ _erf(x:: Float16 ) = Float16(_erf(Float32(x)))
377+ _erf(x:: BigFloat ) = begin
378+ z = BigFloat(x)
379+ ccall((:mpfr_erf, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, Base. MPFR. ROUNDING_MODE[])
380+ return z
381+ end
382+
383+ """
384+ gelu(x) = gelu_tanh(x)
385+
386+ Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
387+ See [`gelu_tanh`](@ref).
388+ """
389+ const gelu = gelu_tanh
390+ export gelu
391+ const deriv_gelu = deriv_gelu_tanh
392+
373393"""
374394 swish(x) = x * σ(x)
375395
@@ -887,8 +907,8 @@ UNARY_ACTS = [ # f, dfdx
887907 (:relu6, :((Ω> 0 ) & (Ω< 6 ))),
888908 # rrelu is random, can't write a rule.
889909 (:elu, :(deriv_elu(Ω))),
890- (:gelu , :(deriv_gelu (x))),
891- (:gelu_full , :(deriv_gelu_full (x))),
910+ (:gelu_tanh , :(deriv_gelu_tanh (x))),
911+ (:gelu_erf , :(deriv_gelu_erf (x))),
892912 (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
893913 (:hardswish, :(deriv_hardswish(x))),
894914 # lisht
0 commit comments