Skip to content

Commit 3cbdb3a

Browse files
committed
Better runtime with a few @inline
1 parent b962b26 commit 3cbdb3a

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

ext/SpecialFunctionsChainRulesCoreExt.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -320,22 +320,22 @@ end
320320
# the auxiliary variable f, the continued-fraction coefficients a_n, b_n, and
321321
# their partial derivatives w.r.t. p (≡ a) and q (≡ b). See Boik & Robinson-Cox (1999).
322322

323-
function _Kfun(x::T, p::T, q::T) where {T<:AbstractFloat}
323+
@inline function _Kfun(x::T, p::T, q::T) where {T<:AbstractFloat}
324324
# K(x;p,q) = x^p (1-x)^{q-1} / (p * B(p,q)) computed in log-space for stability
325325
return exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - logbeta(p, q))
326326
end
327327

328-
function _ffun(x::T, p::T, q::T) where {T<:AbstractFloat}
328+
@inline function _ffun(x::T, p::T, q::T) where {T<:AbstractFloat}
329329
# f = q x / (p (1-x)) — convenience variable appearing in CF coefficients
330330
return q * x / (p * (1 - x))
331331
end
332332

333-
function _a1fun(p::T, q::T, f::T) where {T<:AbstractFloat}
333+
@inline function _a1fun(p::T, q::T, f::T) where {T<:AbstractFloat}
334334
# a₁ coefficient of the continued fraction for ₂F₁ representation
335335
return p * f * (q - 1) / (q * (p + 1))
336336
end
337337

338-
function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
338+
@inline function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
339339
# a_n coefficient (n ≥ 1) of the continued fraction for ₂F₁ in terms of p=a, q=b, f.
340340
# For n=1, falls back to a₁; for n≥2 uses the closed-form product from the Gauss CF.
341341
n == 1 && return _a1fun(p, q, f)
@@ -345,24 +345,24 @@ function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
345345
return r * (n - 1) * (pn + q - 2) * (pn - 1) * (q - n) / ((p2n - 3) * (p2n - 2)^2 * (p2n - 1))
346346
end
347347

348-
function _bnfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
348+
@inline function _bnfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
349349
# b_n coefficient (n ≥ 1) of the continued fraction. Derived for the same CF.
350350
x = 2 * n * (p * f + 2 * q) * (n + p - 1) + p * q * (p - 2 - p * f)
351351
y = q * (p + 2*n - 2) * (p + 2*n)
352352
return x / y
353353
end
354354

355-
function _dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) where {T<:AbstractFloat}
355+
@inline function _dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) where {T<:AbstractFloat}
356356
# ∂K/∂p using digamma identities: d/dp log B(p,q) = ψ(p) - ψ(p+q)
357357
return K * (log(x) - inv(p) + ψpq - ψp)
358358
end
359359

360-
function _dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) where {T<:AbstractFloat}
360+
@inline function _dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) where {T<:AbstractFloat}
361361
# ∂K/∂q using identical pattern
362362
K * (log1p(-x) + ψpq - ψq)
363363
end
364364

365-
function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat}
365+
@inline function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat}
366366
# Convenience: compute (∂K/∂p, ∂K/∂q) together with shared ψ(p+q)
367367
ψ = digamma(p + q)
368368
Kf = _Kfun(x, p, q)
@@ -371,12 +371,12 @@ function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat}
371371
return dKdp, dKdq
372372
end
373373

374-
function _da1_dp(p::T, q::T, f::T) where {T<:AbstractFloat}
374+
@inline function _da1_dp(p::T, q::T, f::T) where {T<:AbstractFloat}
375375
# ∂a₁/∂p from the closed form of a₁
376376
return - _a1fun(p, q, f) / (p + 1)
377377
end
378378

379-
function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
379+
@inline function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
380380
# ∂a_n/∂p via log-derivative: d a_n = a_n * d log a_n; for n=1, uses ∂a₁/∂p
381381
if n == 1
382382
return _da1_dp(p, q, f)
@@ -386,13 +386,13 @@ function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
386386
return an * dlog
387387
end
388388

389-
function _da1_dq(p::T, q::T, f::T) where {T<:AbstractFloat}
389+
@inline function _da1_dq(p::T, q::T, f::T) where {T<:AbstractFloat}
390390
# ∂a₁/∂q
391391
return _a1fun(p, q, f) / (q - 1)
392392
end
393393

394394

395-
function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
395+
@inline function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
396396
# ∂a_n/∂q avoiding the removable singularity at q ≈ n for integer q.
397397
# For n=1, defer to the specific a₁ derivative.
398398
if n == 1
@@ -414,7 +414,7 @@ function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
414414
return C * (p + 2*q - 2)
415415
end
416416

417-
function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
417+
@inline function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
418418
# ∂b_n/∂p via quotient rule on b_n = N/D.
419419
# Note the internal dependence f(p,q)=q x/(p(1-x)) — terms cancel in N as per derivation.
420420
g = p * f + 2 * q
@@ -430,7 +430,7 @@ function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
430430
return (dN_dp * D - N * dD_dp) / (D^2)
431431
end
432432

433-
function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
433+
@inline function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
434434
# ∂b_n/∂q similarly via quotient rule
435435
g = p * f + 2 * q
436436
A = 2 * n^2 + 2 * (p - 1) * n
@@ -446,7 +446,7 @@ function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
446446
return (dN_dq * D - N * dD_dq) / (D^2)
447447
end
448448

449-
function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where {T<:AbstractFloat}
449+
@inline function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where {T<:AbstractFloat}
450450
# One step of the continuant recurrences:
451451
# A_n = a_n A_{n-2} + b_n A_{n-1}
452452
# B_n = a_n B_{n-2} + b_n B_{n-1}
@@ -457,7 +457,7 @@ function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where
457457
return An, Bn, an, bn
458458
end
459459

460-
function _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) where {T<:AbstractFloat}
460+
@inline function _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) where {T<:AbstractFloat}
461461
# Derivative propagation for the same recurrences (X∈{A,B})
462462
return dan * Xpp + an * dXpp + dbn * Xp + bn * dXp
463463
end

0 commit comments

Comments
 (0)