Skip to content

Commit 949d1ad

Browse files
authored
First pass at easyrules (#2674)
* f * tf * her * fix * f * fix * fix * WIP rrule * fix * Reverse mode * Fix array * fix * fix * fix * @constant * fix * fix * Some special function tests
1 parent 2dd8d87 commit 949d1ad

File tree

10 files changed

+1535
-31
lines changed

10 files changed

+1535
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
4444
CEnum = "0.4, 0.5"
4545
ChainRulesCore = "1"
4646
DynamicPPL = "0.35, 0.36, 0.37, 0.38"
47-
EnzymeCore = "0.8.14"
47+
EnzymeCore = "0.8.15"
4848
Enzyme_jll = "0.0.203"
4949
GPUArraysCore = "0.1.6, 0.2"
5050
GPUCompiler = "1.6.2"

ext/EnzymeSpecialFunctionsExt.jl

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,351 @@ function __init__()
1111
Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselk)] = (:cmplx_kn, 2, nothing)
1212
end
1313

14+
# x/ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/506
15+
## Incomplete beta derivatives via Boik & Robinson-Cox
16+
#
17+
# Reference
18+
# R. J. Boik and J. F. Robinson-Cox (1999).
19+
# "Derivatives of the incomplete beta function."
20+
# Journal of Statistical Software, 3(1).
21+
# URL: https://www.jstatsoft.org/article/view/v003i01
22+
#
23+
# The following implementation computes the regularized incomplete beta
24+
# I_x(a,b) together with its partial derivatives with respect to a, b, and x
25+
# using a continued-fraction representation of ₂F₁ and differentiating through it.
26+
# This is an independent implementation adapted from https://github.com/arzwa/IncBetaDer.jl.
27+
28+
# Generic-typed helpers used by the continued-fraction evaluation of I_x(a,b)
29+
# and its partial derivatives. These implement the scalar prefactor K(x;p,q),
30+
# the auxiliary variable f, the continued-fraction coefficients a_n, b_n, and
31+
# their partial derivatives w.r.t. p (≡ a) and q (≡ b). See Boik & Robinson-Cox (1999).
32+
33+
function _Kfun(x::T, p::T, q::T) where {T<:AbstractFloat}
34+
# K(x;p,q) = x^p (1-x)^{q-1} / (p * B(p,q)) computed in log-space for stability
35+
return exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - logbeta(p, q))
36+
end
37+
38+
function _ffun(x::T, p::T, q::T) where {T<:AbstractFloat}
39+
# f = q x / (p (1-x)) — convenience variable appearing in CF coefficients
40+
return q * x / (p * (1 - x))
41+
end
42+
43+
function _a1fun(p::T, q::T, f::T) where {T<:AbstractFloat}
44+
# a₁ coefficient of the continued fraction for ₂F₁ representation
45+
return p * f * (q - 1) / (q * (p + 1))
46+
end
47+
48+
function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
49+
# a_n coefficient (n ≥ 1) of the continued fraction for ₂F₁ in terms of p=a, q=b, f.
50+
# For n=1, falls back to a₁; for n≥2 uses the closed-form product from the Gauss CF.
51+
n == 1 && return _a1fun(p, q, f)
52+
return p^2 * f^2 * (n - 1) * (p + q + n - 2) * (p + n - 1) * (q - n) /
53+
(q^2 * (p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1))
54+
end
55+
56+
function _bnfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
57+
# b_n coefficient (n ≥ 1) of the continued fraction. Derived for the same CF.
58+
x = 2 * (p * f + 2 * q) * n^2 + 2 * (p * f + 2 * q) * (p - 1) * n + p * q * (p - 2 - p * f)
59+
y = q * (p + 2*n - 2) * (p + 2*n)
60+
return x / y
61+
end
62+
63+
function _dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) where {T<:AbstractFloat}
64+
# ∂K/∂p using digamma identities: d/dp log B(p,q) = ψ(p) - ψ(p+q)
65+
return K * (log(x) - inv(p) + ψpq - ψp)
66+
end
67+
68+
function _dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) where {T<:AbstractFloat}
69+
# ∂K/∂q using identical pattern
70+
K * (log1p(-x) + ψpq - ψq)
71+
end
72+
73+
function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat}
74+
# Convenience: compute (∂K/∂p, ∂K/∂q) together with shared ψ(p+q)
75+
ψ = digamma(p + q)
76+
Kf = _Kfun(x, p, q)
77+
dKdp = _dK_dp(x, p, q, Kf, ψ, digamma(p))
78+
dKdq = _dK_dq(x, p, q, Kf, ψ, digamma(q))
79+
return dKdp, dKdq
80+
end
81+
82+
function _da1_dp(p::T, q::T, f::T) where {T<:AbstractFloat}
83+
# ∂a₁/∂p from the closed form of a₁
84+
return - _a1fun(p, q, f) / (p + 1)
85+
end
86+
87+
function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
88+
# ∂a_n/∂p via log-derivative: d a_n = a_n * d log a_n; for n=1, uses ∂a₁/∂p
89+
if n == 1
90+
return _da1_dp(p, q, f)
91+
end
92+
an = _anfun(p, q, f, n)
93+
dlog = inv(p + q + n - 2) + inv(p + n - 1) - inv(p + 2*n - 3) - 2 * inv(p + 2*n - 2) - inv(p + 2*n - 1)
94+
return an * dlog
95+
end
96+
97+
function _da1_dq(p::T, q::T, f::T) where {T<:AbstractFloat}
98+
# ∂a₁/∂q
99+
return _a1fun(p, q, f) / (q - 1)
100+
end
101+
102+
103+
function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
104+
# ∂a_n/∂q avoiding the removable singularity at q ≈ n for integer q.
105+
# For n=1, defer to the specific a₁ derivative.
106+
if n == 1
107+
return _da1_dq(p, q, f)
108+
end
109+
# Use the simplified closed-form of a_n that eliminates explicit q^2 via f:
110+
# a_n = (x/(1-x))^2 * (n-1) * (p+n-1) * (p+q+n-2) * (q-n) / D(p,n)
111+
# where D(p,n) = (p+2n-3)*(p+2n-2)^2*(p+2n-1) and (x/(1-x)) = p*f/q.
112+
# Differentiate only the q-dependent factor G(q) = (p+q+n-2)*(q-n):
113+
# dG/dq = (q-n) + (p+q+n-2) = p + 2q - 2.
114+
115+
# This is equivalent to
116+
# return _anfun(p,q,f,n) * (inv(p+q+n-2) + inv(q-n))
117+
# but more precise.
118+
119+
pfq = (p * f) / q
120+
C = (pfq * pfq) * (n - 1) * (p + n - 1) /
121+
((p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1))
122+
return C * (p + 2*q - 2)
123+
end
124+
125+
function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
126+
# ∂b_n/∂p via quotient rule on b_n = N/D.
127+
# Note the internal dependence f(p,q)=q x/(p(1-x)) — terms cancel in N as per derivation.
128+
g = p * f + 2 * q
129+
A = 2 * n^2 + 2 * (p - 1) * n
130+
N1 = g * A
131+
N2 = p * q * (p - 2 - p * f)
132+
N = N1 + N2
133+
D = q * (p + 2*n - 2) * (p + 2*n)
134+
dN1_dp = 2 * n * g
135+
dN2_dp = q * (2 * p - 2) - p * q * f
136+
dN_dp = dN1_dp + dN2_dp
137+
dD_dp = q * (2 * p + 4 * n - 2)
138+
return (dN_dp * D - N * dD_dp) / (D^2)
139+
end
140+
141+
function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
142+
# ∂b_n/∂q similarly via quotient rule
143+
g = p * f + 2 * q
144+
A = 2 * n^2 + 2 * (p - 1) * n
145+
N1 = g * A
146+
N2 = p * q * (p - 2 - p * f)
147+
N = N1 + N2
148+
D = q * (p + 2*n - 2) * (p + 2*n)
149+
g_q = p * (f / q) + 2
150+
dN1_dq = g_q * A
151+
dN2_dq = p * (p - 2 - p * f) - p^2 * f
152+
dN_dq = dN1_dq + dN2_dq
153+
dD_dq = (p + 2*n - 2) * (p + 2*n)
154+
return (dN_dq * D - N * dD_dq) / (D^2)
155+
end
156+
157+
function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where {T<:AbstractFloat}
158+
# One step of the continuant recurrences:
159+
# A_n = a_n A_{n-2} + b_n A_{n-1}
160+
# B_n = a_n B_{n-2} + b_n B_{n-1}
161+
an = _anfun(p, q, f, n)
162+
bn = _bnfun(p, q, f, n)
163+
An = an * App + bn * Ap
164+
Bn = an * Bpp + bn * Bp
165+
return An, Bn, an, bn
166+
end
167+
168+
function _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) where {T<:AbstractFloat}
169+
# Derivative propagation for the same recurrences (X∈{A,B})
170+
return dan * Xpp + an * dXpp + dbn * Xp + bn * dXp
171+
end
172+
173+
function _beta_inc_grad(a, b, x; maxapp::Int=200, minapp::Int=3)
174+
T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)));
175+
err::T=eps(T)*T(1e4)
176+
a = T(a)
177+
b = T(b)
178+
x = T(x)
179+
# Compute I_x(a,b) and partial derivatives (∂I/∂a, ∂I/∂b, ∂I/∂x)
180+
# using a differentiated continued fraction with convergence control.
181+
oneT = one(T)
182+
zeroT = zero(T)
183+
184+
# 1) Boundary cases for x
185+
x == oneT && return oneT, zeroT, zeroT, zeroT
186+
x == zeroT && return zeroT, zeroT, zeroT, zeroT
187+
188+
# 2) Clamp iteration/tolerance parameters to robust defaults
189+
ϵ = min(err, T(1e-14))
190+
maxapp = max(1000, maxapp)
191+
minapp = max(5, minapp)
192+
193+
# 3) Non-boundary path: precompute ∂I/∂x at original (a,b,x) via stable log form
194+
dx = exp((a - oneT) * log(x) + (b - oneT) * log1p(-x) - logbeta(a,b))
195+
196+
# 4) Optional tail-swap for symmetry and improved CF convergence:
197+
# if x > a/(a+b), evaluate at (p,q,x₀) = (b,a,1-x) and swap back at the end.
198+
p = a
199+
q = b
200+
x₀ = x
201+
swap = false
202+
if x > a / (a + b)
203+
x₀ = oneT - x
204+
p = b
205+
q = a
206+
swap = true
207+
end
208+
209+
# 5) Initialize CF state and derivatives
210+
K = _Kfun(x₀, p, q)
211+
dK_dp_val, dK_dq_val = _dK_dpdq(x₀, p, q)
212+
f = _ffun(x₀, p, q)
213+
App = oneT
214+
Ap = oneT
215+
Bpp = zeroT
216+
Bp = oneT
217+
dApp_dp = zeroT
218+
dBpp_dp = zeroT
219+
dAp_dp = zeroT
220+
dBp_dp = zeroT
221+
dApp_dq = zeroT
222+
dBpp_dq = zeroT
223+
dAp_dq = zeroT
224+
dBp_dq = zeroT
225+
dI_dp = T(NaN)
226+
dI_dq = T(NaN)
227+
Ixpq = T(NaN)
228+
Ixpqn = T(NaN)
229+
dI_dp_prev = T(NaN)
230+
dI_dq_prev = T(NaN)
231+
232+
# 6) Main CF loop (n from 1): update continuants, scale, form current approximant Cn=A_n/B_n
233+
# and its derivatives to update I and ∂I/∂(p,q). Stop on relative convergence of all.
234+
for n=1:maxapp
235+
236+
# Update continuants.
237+
An, Bn, an, bn = _nextapp(f, p, q, n, App, Ap, Bpp, Bp)
238+
dan = _dan_dp(p, q, f, n)
239+
dbn = _dbn_dp(p, q, f, n)
240+
dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp)
241+
dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp)
242+
dan = _dan_dq(p, q, f, n)
243+
dbn = _dbn_dq(p, q, f, n)
244+
dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq)
245+
dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq)
246+
247+
# Normalize states to control growth/underflow (scale-invariant transform)
248+
s = maximum((abs(An), abs(Bn), abs(Ap), abs(Bp), abs(App), abs(Bpp)))
249+
if isfinite(s) && s > zeroT
250+
invs = inv(s)
251+
An *= invs
252+
Bn *= invs
253+
Ap *= invs
254+
Bp *= invs
255+
App *= invs
256+
Bpp *= invs
257+
dAn_dp *= invs
258+
dBn_dp *= invs
259+
dAn_dq *= invs
260+
dBn_dq *= invs
261+
dAp_dp *= invs
262+
dBp_dp *= invs
263+
dApp_dp *= invs
264+
dBpp_dp *= invs
265+
dAp_dq *= invs
266+
dBp_dq *= invs
267+
dApp_dq *= invs
268+
dBpp_dq *= invs
269+
end
270+
271+
# Form current approximant Cn=A_n/B_n and its derivatives.
272+
# Guard against tiny/zero Bn to avoid NaNs/Inf in divisions.
273+
tiny = sqrt(eps(T))
274+
absBn = abs(Bn)
275+
sgnBn = ifelse(Bn >= zeroT, oneT, -oneT)
276+
invBn = absBn > tiny && isfinite(absBn) ? inv(Bn) : inv(sgnBn * tiny)
277+
Cn = An * invBn
278+
invBn2 = invBn * invBn
279+
dI_dp = dK_dp_val * Cn + K * (invBn * dAn_dp - (An * invBn2) * dBn_dp)
280+
dI_dq = dK_dq_val * Cn + K * (invBn * dAn_dq - (An * invBn2) * dBn_dq)
281+
Ixpqn = K * Cn
282+
283+
# Decide convergence:
284+
if n >= minapp
285+
# Relative convergence for I, ∂I/∂p, ∂I/∂q (guards against tiny denominators)
286+
denomI = max(abs(Ixpqn), abs(Ixpq), eps(T))
287+
denomp = max(abs(dI_dp), abs(dI_dp_prev), eps(T))
288+
denomq = max(abs(dI_dq), abs(dI_dq_prev), eps(T))
289+
rI = abs(Ixpqn - Ixpq) / denomI
290+
rp = abs(dI_dp - dI_dp_prev) / denomp
291+
rq = abs(dI_dq - dI_dq_prev) / denomq
292+
if max(rI, rp, rq) < ϵ
293+
break
294+
end
295+
end
296+
Ixpq = Ixpqn
297+
dI_dp_prev = dI_dp
298+
dI_dq_prev = dI_dq
299+
300+
# Shift CF state for next iteration
301+
App = Ap
302+
Bpp = Bp
303+
Ap = An
304+
Bp = Bn
305+
dApp_dp = dAp_dp
306+
dApp_dq = dAp_dq
307+
dBpp_dp = dBp_dp
308+
dBpp_dq = dBp_dq
309+
dAp_dp = dAn_dp
310+
dAp_dq = dAn_dq
311+
dBp_dp = dBn_dp
312+
dBp_dq = dBn_dq
313+
end
314+
315+
# 7) Undo tail-swap if applied; ∂I/∂x is the pdf at original (a,b,x)
316+
if swap
317+
return oneT - Ixpqn, -dI_dq, -dI_dp, dx
318+
else
319+
return Ixpqn, dI_dp, dI_dq, dx
320+
end
321+
end
322+
323+
EnzymeRules.@easy_rule(
324+
SpecialFunctions.beta_inc(a, b, x),
325+
@setup(
326+
(_, dIa, dIb, dIx) = _beta_inc_grad(a, b, x)
327+
),
328+
(dIa, dIb, dIx),
329+
(-dIa, -dIb, -dIx),
330+
)
331+
332+
Enzyme.EnzymeRules.@easy_rule(
333+
SpecialFunctions.beta_inc(a, b, x, y),
334+
@setup(
335+
(_, dIa, dIb, dIx) = _beta_inc_grad(a, b, x)
336+
),
337+
(dIa, dIb, dIx, -dIx),
338+
(-dIa, -dIb, -dIx, dIx),
339+
)
340+
341+
Enzyme.EnzymeRules.@easy_rule(
342+
SpecialFunctions.beta_inc_inv(a, b, p),
343+
@setup(
344+
345+
(x, y) = Ω,
346+
347+
# Implicit differentiation at solved x: I_x(a,b) = p
348+
(_, dIa, dIb, _) = _beta_inc_grad(a, b, x),
349+
350+
# ∂I/∂x at solved x via stable log-space expression
351+
dIx_acc = exp(muladd(a - one(a), log(x), muladd(b - one(b), log1p(-x), -logbeta(a, b)))),
352+
inv_dIx = inv(dIx_acc),
353+
dx_da = -dIa * inv_dIx,
354+
dx_db = -dIb * inv_dIx,
355+
dx_dp = inv_dIx,
356+
),
357+
(dx_da, dx_db, dx_dp),
358+
(-dx_da, -dx_db, -dx_dp)
359+
)
360+
14361
end

lib/EnzymeCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeCore"
22
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.8.14"
4+
version = "0.8.15"
55

66
[compat]
77
Adapt = "3, 4"

0 commit comments

Comments
 (0)