Commit 44ee4f5

Rename gelus and use OpenLibm_jll instead of SpecialFunctions.jl for erf
1 parent bdfa0f0 commit 44ee4f5

File tree

5 files changed: +47 −23 lines

Project.toml

Lines changed: 2 additions & 2 deletions

@@ -9,8 +9,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [weakdeps]
@@ -41,8 +41,8 @@ ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1, 0.2"
 KernelAbstractions = "0.9.2"
 LinearAlgebra = "<0.0.1, 1"
+OpenLibm_jll = "0.8.1"
 Random = "<0.0.1, 1"
-SpecialFunctions = "2.5.0"
 Statistics = "1"
 cuDNN = "1"
 julia = "1.9"

docs/src/reference.md

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,8 @@ Non-linearities that go between layers of your model. Note that, unless otherwis
 celu
 elu
 gelu
-gelu_full
+gelu_tanh
+gelu_erf
 hardsigmoid
 sigmoid_fast
 hardtanh

src/NNlib.jl

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
 using Random
 using Statistics
 using Statistics: mean
-using SpecialFunctions
+using OpenLibm_jll

 const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

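With `using OpenLibm_jll` in place, the `libopenlibm` library handle becomes available to the `_erf` wrappers added in `src/activations.jl` (next file). A minimal sketch of that call pattern, assuming OpenLibm_jll resolves in the active environment:

```julia
# Sketch only (not part of the commit): the ccall pattern the new _erf
# wrappers rely on, hitting OpenLibm's C erf/erff symbols directly.
using OpenLibm_jll  # exports the `libopenlibm` library handle

x = 0.5
erf64 = ccall((:erf,  libopenlibm), Float64, (Float64,), x)
erf32 = ccall((:erff, libopenlibm), Float32, (Float32,), Float32(x))
@show erf64 erf32   # both ≈ 0.5205
```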
src/activations.jl

Lines changed: 35 additions & 15 deletions

@@ -5,7 +5,7 @@

 ACTIVATIONS = [
     :σ, :hardσ, :hardtanh, :relu,
-    :leakyrelu, :relu6, :rrelu, :elu, :gelu, :gelu_full, :swish, :hardswish, :selu,
+    :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_erf, :swish, :hardswish, :selu,
     :celu, :softplus, :softsign, :logσ, :logcosh,
     :mish, :tanhshrink, :softshrink, :trelu, :lisht,
     :tanh_fast, :sigmoid_fast,
@@ -301,14 +301,14 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))
 deriv_elu(Ω, α=1) = ifelse(Ω > 0, one(Ω), Ω + oftype(Ω, α))

 """
-    gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
+    gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))

-Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) (see also [`gelu_full`](@ref)).
+Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using the tanh approximation.

 ```julia-repl
-julia> lineplot(gelu, -2, 2, height=7)
+julia> lineplot(gelu_tanh, -2, 2, height=7)
 ┌────────────────────────────────────────┐
-2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu(x)
+2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x)
 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│
 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│
 f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
@@ -319,11 +319,11 @@ julia> lineplot(gelu, -2, 2, height=7)
 ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀

-julia> lineplot(gelu, -5, 0, height=7);
+julia> lineplot(gelu_tanh, -5, 0, height=7);

 julia> lineplot!(ans, swish)
 ┌────────────────────────────────────────┐
-0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu(x)
+0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x)
 │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
 │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│
 f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│
@@ -335,7 +335,7 @@ julia> lineplot!(ans, swish)
 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
 ```
 """
-function gelu(x)
+function gelu_tanh(x)
     α = oftf(x, 0.044715)
     # λ = oftf(x, gelu_λ)
     # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference
@@ -346,7 +346,7 @@ end
 const gelu_λ = √(2 / π)
 const gelu_2λ = √(8 / π)

-function deriv_gelu(x)
+function deriv_gelu_tanh(x)
     α = oftf(x, 0.044715)
     α2 = oftf(x, 0.08943)
     λλ = oftf(x, gelu_2λ)
@@ -358,18 +358,38 @@ function deriv_gelu(x)
 end

 """
-    gelu_full(x) = xΦ(x) = 0.5x * (1 + erf(x/√2))
+    gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2))

 Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation.
 """
-gelu_full(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))))
+gelu_erf(x) = x/2*(1 + _erf(x/sqrt(oftf(x,2))))

-function deriv_gelu_full(x)
+function deriv_gelu_erf(x)
     SQRT2 = sqrt(oftf(x,2))
-    Φ = (1 + erf(x/SQRT2))/2
+    Φ = (1 + _erf(x/SQRT2))/2
     Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))
 end

+_erf(x::Number) = _erf(float(x))
+_erf(x::Float64) = ccall((:erf, libopenlibm), Float64, (Float64,), x)
+_erf(x::Float32) = ccall((:erff, libopenlibm), Float32, (Float32,), x)
+_erf(x::Float16) = Float16(_erf(Float32(x)))
+_erf(x::BigFloat) = begin
+    z = BigFloat(x)
+    ccall((:mpfr_erf, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, Base.MPFR.ROUNDING_MODE[])
+    return z
+end
+
+"""
+    gelu(x) = gelu_tanh(x)
+
+Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
+See [`gelu_tanh`](@ref).
+"""
+const gelu = gelu_tanh
+export gelu
+const deriv_gelu = deriv_gelu_tanh
+
 """
     swish(x) = x * σ(x)

@@ -887,8 +907,8 @@ UNARY_ACTS = [ # f, dfdx
     (:relu6, :((Ω>0) & (Ω<6))),
     # rrelu is random, can't write a rule.
     (:elu, :(deriv_elu(Ω))),
-    (:gelu, :(deriv_gelu(x))),
-    (:gelu_full, :(deriv_gelu_full(x))),
+    (:gelu_tanh, :(deriv_gelu_tanh(x))),
+    (:gelu_erf, :(deriv_gelu_erf(x))),
     (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
     (:hardswish, :(deriv_hardswish(x))),
     # lisht

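Taken together, the renames above keep `gelu` available as an alias for the tanh approximation, while `gelu_erf` exposes the exact erf-based form. A short usage sketch against this commit (values match the updated tests below):

```julia
# Usage sketch for the renamed activations in this commit.
using NNlib

gelu === gelu_tanh              # true: `const gelu = gelu_tanh`
gelu_tanh(1.0)                  # 0.8411919906082768 (tanh approximation)
gelu_erf(1.0)                   # 0.8413447460685429 (exact xΦ(x) via erf)
gelu_tanh(1.0) - gelu_erf(1.0)  # ≈ -1.5e-4: approximation error at x = 1
```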
test/activations.jl

Lines changed: 7 additions & 4 deletions

@@ -12,7 +12,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
 @test rrelu(0.0) == 0.0
 @test elu(0.0) == 0.0
 @test gelu(0.0) == 0.0
-@test gelu_full(0.0) == 0.0
+@test gelu_tanh(0.0) == 0.0
+@test gelu_erf(0.0) == 0.0
 @test swish(0.0) == 0.0
 @test hardswish(0.0) == 0.0
 @test lisht(0.0) == 0.0
@@ -37,7 +38,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
 @test rrelu(1.0) == 1.0
 @test elu(1.0) == 1.0
 @test gelu(1.0) == 0.8411919906082768
-@test gelu_full(1.0) == 0.8413447460685429
+@test gelu_tanh(1.0) == 0.8411919906082768
+@test gelu_erf(1.0) == 0.8413447460685429
 @test swish(1.0) == sigmoid(1.0)
 @test hardswish(1.0) == hardsigmoid(1.0)
 @test lisht(1.0) ≈ 1.0 * tanh(1.0)
@@ -60,7 +62,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0
 @test elu(-1.0) == exp(-1.0) - 1.0
 @test gelu(-1.0) ≈ -0.15880800939172324
-@test gelu_full(-1.0) == -0.15865525393145707
+@test gelu_tanh(-1.0) ≈ -0.15880800939172324
+@test gelu_erf(-1.0) == -0.15865525393145707
 @test swish(-1.0) == -sigmoid(-1.0)
 @test hardswish(-1.0) == -hardsigmoid(-1.0)
 @test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
@@ -117,7 +120,7 @@ end
 a == softsign && continue
 @test !isnan(a(Inf32))

-a in [gelu, gelu_full, swish, hardswish, logcosh, mish] && continue
+a in [gelu, gelu_tanh, gelu_erf, swish, hardswish, logcosh, mish] && continue
 @test !isnan(a(-Inf32))
 end
 end

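The updated tests cover function values only; a quick central-difference check of the renamed derivative helpers wired into `UNARY_ACTS` could look like the sketch below (not part of the commit; `deriv_gelu_tanh` and `deriv_gelu_erf` are internal, hence the `NNlib.` prefix):

```julia
# Sketch: finite-difference sanity check of the analytic gelu derivatives.
using NNlib

fdiff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)  # central difference

for x in (-2.0, -0.5, 0.0, 0.3, 1.7)
    @assert isapprox(fdiff(gelu_tanh, x), NNlib.deriv_gelu_tanh(x); atol=1e-6)
    @assert isapprox(fdiff(gelu_erf,  x), NNlib.deriv_gelu_erf(x);  atol=1e-6)
end
```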