Skip to content

Commit 53e8888

Browse files
authored
Allow Real arguments in R functions and use Julia implementations only (#125)
* Allow arguments of type `Float32` and `Float16` in R functions * Add tests * Bump version * Allow `Real` if no Julia fallback exists and `Int`s, `UInt`s, and `Rational`s otherwise * Do not use R implementation when Julia implementation is available * Add more tests
1 parent fe65c00 commit 53e8888

22 files changed

+279
-130
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StatsFuns"
22
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3-
version = "0.9.10"
3+
version = "0.9.11"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,7 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1313
[compat]
1414
ChainRulesCore = "1"
1515
IrrationalConstants = "0.1"
16-
LogExpFunctions = "0.3"
16+
LogExpFunctions = "0.3.2"
1717
Reexport = "1"
1818
Rmath = "0.4, 0.5, 0.6, 0.7"
1919
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ log4π, # log(4π)
4040
# basicfuns
4141
xlogx, # x * log(x), or 0 when x is zero
4242
xlogy, # x * log(y), or 0 when x is zero
43+
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
4344
logistic, # 1 / (1 + exp(-x))
4445
logit, # log(x / (1 - x))
4546
log1psq, # log(1 + x^2)

src/StatsFuns.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import ChainRulesCore
3535
@reexport using LogExpFunctions:
3636
xlogx, # x * log(x) for x > 0, or 0 when x == 0
3737
xlogy, # x * log(y) for x > 0, or 0 when x == 0
38+
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
3839
logistic, # 1 / (1 + exp(-x))
3940
logit, # log(x / (1 - x))
4041
log1psq, # log(1 + x^2)

src/chainrules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ ChainRulesCore.@scalar_rule(
1212
binomlogpdf(n::Real, p::Real, k::Real),
1313
@setup(z = digamma(n - k + 1)),
1414
(
15-
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
15+
ChainRulesCore.NoTangent(),
1616
(k / p - n) / (1 - p),
17-
z - digamma(k + 1) + logit(p),
17+
ChainRulesCore.NoTangent(),
1818
),
1919
)
2020

@@ -59,7 +59,7 @@ ChainRulesCore.@scalar_rule(
5959

6060
ChainRulesCore.@scalar_rule(
6161
poislogpdf::Number, x::Number),
62-
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
62+
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()),
6363
)
6464

6565
ChainRulesCore.@scalar_rule(

src/distrs/beta.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# functions related to beta distributions
22

3-
import .RFunctions:
4-
betapdf,
5-
betalogpdf,
3+
# R implementations
4+
# For pdf and logpdf we use the Julia implementation
5+
using .RFunctions:
66
betacdf,
77
betaccdf,
88
betalogcdf,
@@ -12,8 +12,13 @@ import .RFunctions:
1212
betainvlogcdf,
1313
betainvlogccdf
1414

15-
# pdf for numbers with generic types
16-
betapdf::Real, β::Real, x::Number) = x^- 1) * (1 - x)^- 1) / beta(α, β)
15+
# Julia implementations
16+
betapdf::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x))
1717

18-
# logpdf for numbers with generic types
19-
betalogpdf::Real, β::Real, x::Number) =- 1) * log(x) +- 1) * log1p(-x) - logbeta(α, β)
18+
betalogpdf::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...)
19+
function betalogpdf::T, β::T, x::T) where {T<:Real}
20+
# we ensure that `log(x)` and `log1p(-x)` do not error
21+
y = clamp(x, 0, 1)
22+
val = xlogy- 1, y) + xlog1py- 1, -y) - logbeta(α, β)
23+
return x < 0 || x > 1 ? oftype(val, -Inf) : val
24+
end

src/distrs/binom.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# functions related to binomial distribution
22

3-
import .RFunctions:
4-
binompdf,
5-
binomlogpdf,
3+
# R implementations
4+
# For pdf and logpdf we use the Julia implementation
5+
using .RFunctions:
66
binomcdf,
77
binomccdf,
88
binomlogcdf,
@@ -12,8 +12,13 @@ import .RFunctions:
1212
binominvlogcdf,
1313
binominvlogccdf
1414

15-
# pdf for numbers with generic types
15+
16+
# Julia implementations
1617
binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k))
1718

18-
# logpdf for numbers with generic types
19-
binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p)
19+
binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...)
20+
function binomlogpdf(n::T, p::T, k::T) where {T<:Real}
21+
m = clamp(k, 0, n)
22+
val = betalogpdf(m + 1, n - m + 1, p) - log(n + 1)
23+
return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf)
24+
end

src/distrs/chisq.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# functions related to chi-square distribution
22

3-
import .RFunctions:
4-
chisqpdf,
5-
chisqlogpdf,
3+
# R implementations
4+
# For pdf and logpdf we use the Julia implementation
5+
using .RFunctions:
66
chisqcdf,
77
chisqccdf,
88
chisqlogcdf,
@@ -12,14 +12,11 @@ import .RFunctions:
1212
chisqinvlogcdf,
1313
chisqinvlogccdf
1414

15-
# pdf for numbers with generic types
16-
function chisqpdf(k::Real, x::Number)
17-
hk = k / 2 # half k
18-
1 / (2^(hk) * gamma(hk)) * x^(hk - 1) * exp(-x / 2)
19-
end
15+
# Julia implementations
16+
# promotion ensures that we do forward e.g. `chisqpdf(::Int, ::Float32)` to
17+
# `gammapdf(::Float32, ::Int, ::Float32)` but not `gammapdf(::Float64, ::Int, ::Float32)`
18+
chisqpdf(k::Real, x::Real) = chisqpdf(promote(k, x)...)
19+
chisqpdf(k::T, x::T) where {T<:Real} = gammapdf(k / 2, 2, x)
2020

21-
# logpdf for numbers with generic types
22-
function chisqlogpdf(k::Real, x::Number)
23-
hk = k / 2 # half k
24-
-hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2
25-
end
21+
chisqlogpdf(k::Real, x::Real) = chisqlogpdf(promote(k, x)...)
22+
chisqlogpdf(k::T, x::T) where {T<:Real} = gammalogpdf(k / 2, 2, x)

src/distrs/fdist.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# functions related to F distribution
22

3-
import .RFunctions:
4-
fdistpdf,
5-
fdistlogpdf,
3+
# R implementations
4+
# For pdf and logpdf we use the Julia implementation
5+
using .RFunctions:
66
fdistcdf,
77
fdistccdf,
88
fdistlogcdf,
@@ -12,8 +12,14 @@ import .RFunctions:
1212
fdistinvlogcdf,
1313
fdistinvlogccdf
1414

15-
# pdf for numbers with generic types
16-
fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 * x + ν2)^(ν1 + ν2)) / (x * beta(ν1 / 2, ν2 / 2))
15+
# Julia implementations
16+
fdistpdf(ν1::Real, ν2::Real, x::Real) = exp(fdistlogpdf(ν1, ν2, x))
1717

18-
# logpdf for numbers with generic types
19-
fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2)
18+
fdistlogpdf(ν1::Real, ν2::Real, x::Real) = fdistlogpdf(promote(ν1, ν2, x)...)
19+
function fdistlogpdf(ν1::T, ν2::T, x::T) where {T<:Real}
20+
# we ensure that `log(x)` does not error if `x < 0`
21+
ν1ν2 = ν1 / ν2
22+
y = max(x, 0)
23+
val = (xlogy(ν1, ν1ν2) + xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, 1 + ν1ν2 * y)) / 2 - logbeta(ν1 / 2, ν2 / 2)
24+
return x < 0 ? oftype(val, -Inf) : val
25+
end

src/distrs/gamma.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# functions related to gamma distribution
22

3-
import .RFunctions:
4-
gammapdf,
5-
gammalogpdf,
3+
# R implementations
4+
# For pdf and logpdf we use the Julia implementation
5+
using .RFunctions:
66
gammacdf,
77
gammaccdf,
88
gammalogcdf,
@@ -12,8 +12,13 @@ import .RFunctions:
1212
gammainvlogcdf,
1313
gammainvlogccdf
1414

15-
# pdf for numbers with generic types
16-
gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp(-x / θ)
15+
# Julia implementations
16+
gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x))
1717

18-
# logpdf for numbers with generic types
19-
gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
18+
gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...)
19+
function gammalogpdf(k::T, θ::T, x::T) where {T<:Real}
20+
# we ensure that `log(x)` does not error if `x < 0`
21+
= max(x, 0) / θ
22+
val = -loggamma(k) + xlogy(k - 1, xθ) - log(θ) -
23+
return x < 0 ? oftype(val, -Inf) : val
24+
end

src/distrs/hyper.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# functions related to hyper-geometric distribution
22

3-
import .RFunctions:
3+
# R implementations
4+
using .RFunctions:
45
hyperpdf,
56
hyperlogpdf,
67
hypercdf,

0 commit comments

Comments
 (0)