Skip to content

Commit 03fbbfd

Browse files
committed
rft: make gamma(a,x) GPU compatible
- improve type stability for En_expand_origin_general, En_safe_expfact, En_expand_origin - this improves type stability for expint(a, x) and thus gamma(a, x) - make gamma(n::Int) for n>20 call gamma(Float64(n)) for compatibility with CUDA.jl's device override of the gamma function - manually inline single recursion in _zeta, _trigamma - modify while loop in _expint to guarantee termination - add related tests
1 parent d71eddb commit 03fbbfd

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

src/expint.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,9 @@ end
284284
# series about origin, general ν
285285
# https://functions.wolfram.com/GammaBetaErf/ExpIntegralE/06/01/04/01/01/0003/
286286
function En_expand_origin_general::Number, z::Number, niter::Integer)
287+
FT = promote_type(typeof(ν), typeof(real(z))) # ensure return type isn't more precise than inputs
287288
# gammaterm = En_safe_gamma_term(ν, z)
288-
gammaterm = gamma(1-ν)*z^-1)
289+
gammaterm = FT(gamma(1-ν))*z^-1)
289290
frac = one(z)
290291
blowup = abs(1 - ν) < 0.5 ? frac / (1 - ν) : zero(z)
291292
sumterm = abs(1 - ν) < 0.5 ? zero(z) : frac / (1 - ν)
@@ -321,13 +322,14 @@ function En_expand_origin_general(ν::Number, z::Number, niter::Integer)
321322
series2 += (7π^4 + 15*(ψ₀^4 + 2ψ₀^2 *^2 - 3ψ₁) + ψ₁*(-2π^2 + 3ψ₁) + 4ψ₀*ψ₂) - 15ψ₃)*δ^3/360
322323
series2 += (3ψ₀^5 + ψ₀^3*(10π^2 - 30ψ₁) + 30ψ₀^2*ψ₂ + ψ₀*(45ψ₁^2 - 30π^2*ψ₁ - 15ψ₃ + 7π^4) - 30ψ₁*ψ₂ + 10π^2*ψ₂ + 3ψ₄)*δ^4/360
323324

324-
return (series1 + series2) * En_safe_expfact(n, z) * z^-n-1) - sumterm
325+
return (series1 + FT(series2)) * En_safe_expfact(n, z) * z^-n-1) - sumterm
325326
end
326327
return gammaterm - (blowup + sumterm)
327328
end
328329

329330
# compute (-z)^n / n!, avoiding overflow if possible, where n is an integer ≥ 0 (but not necessarily an Integer)
330331
function En_safe_expfact(n::Real, z::Number)
332+
FT = eltype(real(z)) # get the floating point type of the input
331333
if n < 100
332334
powerterm = one(z)
333335
for i = 1:Int(n)
@@ -337,9 +339,9 @@ function En_safe_expfact(n::Real, z::Number)
337339
else
338340
if z isa Real
339341
sgn = z 0 ? one(n) : (n <= typemax(Int) ? (isodd(Int(n)) ? -one(n) : one(n)) : (-1)^n)
340-
return sgn * exp(n * log(abs(z)) - loggamma(n+1))
342+
return sgn * exp(n * log(abs(z)) - loggamma(FT(n+1)))
341343
else
342-
return exp(n * log(-z) - loggamma(n+1))
344+
return exp(n * log(-z) - loggamma(FT(n+1)))
343345
end
344346
end
345347
end
@@ -379,10 +381,11 @@ end
379381
# can find imaginary part of E_ν(x) for x on negative real axis analytically
380382
# https://functions.wolfram.com/GammaBetaErf/ExpIntegralE/04/05/01/0003/
381383
function En_imagbranchcut::Number, z::Number)
384+
FT = promote_type(typeof(real(ν)), typeof(real(z)))
382385
a = real(z)
383386
e1 = exp(oftype(a, π) * imag(ν))
384-
e2 = Complex(cospi(real(ν)), -sinpi(real(ν)))
385-
lgamma, lgammasign = ν isa Real ? logabsgamma(ν) : (loggamma(ν), 1)
387+
e2 = Complex(cospi(FT(real(ν))), -sinpi(FT(real))))
388+
lgamma, lgammasign = ν isa Real ? logabsgamma(FT(ν)) : (loggamma(ν), 1)
386389
return -2 * lgammasign * e1 * π * e2 * exp((ν-1)*log(complex(a)) - lgamma) * im
387390
end
388391

@@ -468,12 +471,14 @@ function _expint(ν::Number, z::Number, niter::Int=1000, ::Val{expscaled}=Val{fa
468471
end
469472
return doconj ? conj(E_start) : E_start
470473
end
471-
while i == quick_niter
474+
doublings = 0
475+
while i == quick_niter && doublings < 50
472476
# double imaginary part until in region with fast convergence
473477
imstart *= 2
474478
z₀ = rez + imstart*im
475479
g_start, cf_start, i, _ = En_cf(ν, z₀, quick_niter)
476480
E_start = g_start + En_safeexpmult(-z₀, cf_start)
481+
doublings += 1
477482
end
478483

479484
# nsteps chosen so |Δ| ≤ 0.5

src/gamma.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ trigamma(x::Number) = _trigamma(float(x))
6464

6565
function _trigamma(z::ComplexOrReal{Float64})
6666
# via the derivative of the Kölbig digamma formulation
67+
z′ = z # save original value
68+
z = ifelse(real(z′) <= 0, 1 - z′, z′) # reflection formula (unrolled recursion)
6769
x = real(z)
68-
if x <= 0 # reflection formula
69-
return/ sinpi(z))^2 - trigamma(1 - z)
70-
end
7170
ψ = zero(z)
7271
N = 10
7372
if x < N
@@ -84,6 +83,7 @@ function _trigamma(z::ComplexOrReal{Float64})
8483
ψ += t + 0.5*w
8584
# the coefficients here are Float64(bernoulli[2:9])
8685
ψ += t*w * @evalpoly(w,0.16666666666666666,-0.03333333333333333,0.023809523809523808,-0.03333333333333333,0.07575757575757576,-0.2531135531135531,1.1666666666666667,-7.092156862745098)
86+
ifelse(real(z′) <= 0, (π / sinpi(z′))^2 - ψ, ψ) # reflection formula (unrolled recursion)
8787
end
8888

8989
signflip(m::Number, z) = (-1+0im)^m * z
@@ -417,9 +417,6 @@ function _zeta(s::ComplexOrReal{Float64})
417417
# Riemann zeta function; algorithm is based on specializing the Hurwitz
418418
# zeta function above for z==1.
419419

420-
# blows up to ±Inf, but get correct sign of imaginary zero
421-
s == 1 && return NaN + zero(s) * imag(s)
422-
423420
if !isfinite(s) # annoying NaN and Inf cases
424421
isnan(s) && return imag(s) == 0 ? s : s*s
425422
if isfinite(imag(s))
@@ -436,18 +433,28 @@ function _zeta(s::ComplexOrReal{Float64})
436433
-1.00078519447704240796017680222772921424,
437434
-0.9998792995005711649578008136558752359121)
438435
end
436+
ζ1ms = _zeta_core(1 - s)
439437
if absim > 12 # amplitude of sinpi(s/2) ≈ exp(imag(s)*π/2)
440438
# avoid overflow/underflow (issue #128)
441439
lg = loggamma(1 - s)
442440
rehalf = real(s)*0.5
443-
return zeta(1 - s) * exp(lg + absim*halfπ + s*log2π) * inv2π * Complex(
441+
return ζ1ms * exp(lg + absim*halfπ + s*log2π) * inv2π * Complex(
444442
sinpi(rehalf), flipsign(cospi(rehalf), imag(s))
445443
)
446444
else
447-
return zeta(1 - s) * gamma(1 - s) * sinpi(s*0.5) * twoπ^s * invπ
445+
return ζ1ms * gamma(1 - s) * sinpi(s*0.5) * twoπ^s * invπ
448446
end
449447
end
450448

449+
return _zeta_core(s)
450+
end
451+
452+
453+
# Core asymptotic computation of the Riemann zeta function for real(s) >= 0.5.
454+
# Factored out of _zeta to avoid unprovable recursion in the reflection formula.
455+
function _zeta_core(s::ComplexOrReal{Float64})
456+
# blows up to ±Inf, but get correct sign of imaginary zero
457+
s == 1 && return NaN + zero(s) * imag(s)
451458
m = s - 1
452459

453460
# shift using recurrence formula:
@@ -595,7 +602,7 @@ function gamma(n::Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64})
595602
n < 0 && throw(DomainError(n, "`n` must not be negative."))
596603
n == 0 && return Inf
597604
n <= 2 && return 1.0
598-
n > 20 && return _gamma(Float64(n))
605+
n > 20 && return gamma(Float64(n))
599606
@inbounds return Float64(Base._fact_table64[n-1])
600607
end
601608

test/gamma_inc.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,25 @@ end
281281
@testset "GPU compatibility ($FT)" for FT in (Float64, Float32, Float16)
282282
# Note: This test is a proxy for GPU compatibility by checking that the functions
283283
# are type stable and do not allocate memory. It does not launch any GPU kernels.
284+
@testset "gamma type stability" begin
285+
@test @inferred(gamma(FT(2), FT(0.5))) isa FT
286+
@test @inferred(gamma(FT(2), FT(0.5) * im)) isa Complex{FT}
287+
# Test with integer `a`
288+
@test @inferred(gamma(2, FT(0.5))) isa FT
289+
@test @inferred(gamma(2, FT(0.5) * im)) isa Complex{FT}
290+
end
284291
@testset "gamma_inc type stability" begin
285292
@test @inferred(gamma_inc(FT(30.0), FT(29.99999), 0)) isa Tuple{FT,FT}
286293
end
287294
@testset "gamma_inc_inv type stability" begin
288295
@test @inferred(gamma_inc_inv(FT(1.0), FT(0.01), FT(0.99))) isa FT
289296
end
290297

298+
@testset "gamma allocations" begin
299+
@test iszero((FT -> @allocated(gamma(2, FT(0.5))))(FT))
300+
@test iszero((FT -> @allocated(gamma(2, FT(0.5) * im)))(FT))
301+
end
302+
291303
@testset "gamma_inc allocations" begin
292304
# `@allocated` checks for allocations for specific code paths
293305
## a >= 1

0 commit comments

Comments
 (0)