@@ -11,4 +11,351 @@ function __init__()
1111 Enzyme. Compiler. cmplx_known_ops[typeof (SpecialFunctions. besselk)] = (:cmplx_kn , 2 , nothing )
1212end
1313
14+ # x/ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/506
15+ # # Incomplete beta derivatives via Boik & Robinson-Cox
16+ #
17+ # Reference
18+ # R. J. Boik and J. F. Robinson-Cox (1999).
19+ # "Derivatives of the incomplete beta function."
20+ # Journal of Statistical Software, 3(1).
21+ # URL: https://www.jstatsoft.org/article/view/v003i01
22+ #
23+ # The following implementation computes the regularized incomplete beta
24+ # I_x(a,b) together with its partial derivatives with respect to a, b, and x
25+ # using a continued-fraction representation of ₂F₁ and differentiating through it.
26+ # This is an independent implementation adapted from https://github.com/arzwa/IncBetaDer.jl.
27+
28+ # Generic-typed helpers used by the continued-fraction evaluation of I_x(a,b)
29+ # and its partial derivatives. These implement the scalar prefactor K(x;p,q),
30+ # the auxiliary variable f, the continued-fraction coefficients a_n, b_n, and
31+ # their partial derivatives w.r.t. p (≡ a) and q (≡ b). See Boik & Robinson-Cox (1999).
32+
33+ function _Kfun (x:: T , p:: T , q:: T ) where {T<: AbstractFloat }
34+ # K(x;p,q) = x^p (1-x)^{q-1} / (p * B(p,q)) computed in log-space for stability
35+ return exp (p * log (x) + (q - 1 ) * log1p (- x) - log (p) - logbeta (p, q))
36+ end
37+
38+ function _ffun (x:: T , p:: T , q:: T ) where {T<: AbstractFloat }
39+ # f = q x / (p (1-x)) — convenience variable appearing in CF coefficients
40+ return q * x / (p * (1 - x))
41+ end
42+
43+ function _a1fun (p:: T , q:: T , f:: T ) where {T<: AbstractFloat }
44+ # a₁ coefficient of the continued fraction for ₂F₁ representation
45+ return p * f * (q - 1 ) / (q * (p + 1 ))
46+ end
47+
48+ function _anfun (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
49+ # a_n coefficient (n ≥ 1) of the continued fraction for ₂F₁ in terms of p=a, q=b, f.
50+ # For n=1, falls back to a₁; for n≥2 uses the closed-form product from the Gauss CF.
51+ n == 1 && return _a1fun (p, q, f)
52+ return p^ 2 * f^ 2 * (n - 1 ) * (p + q + n - 2 ) * (p + n - 1 ) * (q - n) /
53+ (q^ 2 * (p + 2 * n - 3 ) * (p + 2 * n - 2 )^ 2 * (p + 2 * n - 1 ))
54+ end
55+
56+ function _bnfun (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
57+ # b_n coefficient (n ≥ 1) of the continued fraction. Derived for the same CF.
58+ x = 2 * (p * f + 2 * q) * n^ 2 + 2 * (p * f + 2 * q) * (p - 1 ) * n + p * q * (p - 2 - p * f)
59+ y = q * (p + 2 * n - 2 ) * (p + 2 * n)
60+ return x / y
61+ end
62+
63+ function _dK_dp (x:: T , p:: T , q:: T , K:: T , ψpq:: T , ψp:: T ) where {T<: AbstractFloat }
64+ # ∂K/∂p using digamma identities: d/dp log B(p,q) = ψ(p) - ψ(p+q)
65+ return K * (log (x) - inv (p) + ψpq - ψp)
66+ end
67+
68+ function _dK_dq (x:: T , p:: T , q:: T , K:: T , ψpq:: T , ψq:: T ) where {T<: AbstractFloat }
69+ # ∂K/∂q using identical pattern
70+ K * (log1p (- x) + ψpq - ψq)
71+ end
72+
73+ function _dK_dpdq (x:: T , p:: T , q:: T ) where {T<: AbstractFloat }
74+ # Convenience: compute (∂K/∂p, ∂K/∂q) together with shared ψ(p+q)
75+ ψ = digamma (p + q)
76+ Kf = _Kfun (x, p, q)
77+ dKdp = _dK_dp (x, p, q, Kf, ψ, digamma (p))
78+ dKdq = _dK_dq (x, p, q, Kf, ψ, digamma (q))
79+ return dKdp, dKdq
80+ end
81+
82+ function _da1_dp (p:: T , q:: T , f:: T ) where {T<: AbstractFloat }
83+ # ∂a₁/∂p from the closed form of a₁
84+ return - _a1fun (p, q, f) / (p + 1 )
85+ end
86+
87+ function _dan_dp (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
88+ # ∂a_n/∂p via log-derivative: d a_n = a_n * d log a_n; for n=1, uses ∂a₁/∂p
89+ if n == 1
90+ return _da1_dp (p, q, f)
91+ end
92+ an = _anfun (p, q, f, n)
93+ dlog = inv (p + q + n - 2 ) + inv (p + n - 1 ) - inv (p + 2 * n - 3 ) - 2 * inv (p + 2 * n - 2 ) - inv (p + 2 * n - 1 )
94+ return an * dlog
95+ end
96+
97+ function _da1_dq (p:: T , q:: T , f:: T ) where {T<: AbstractFloat }
98+ # ∂a₁/∂q
99+ return _a1fun (p, q, f) / (q - 1 )
100+ end
101+
102+
103+ function _dan_dq (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
104+ # ∂a_n/∂q avoiding the removable singularity at q ≈ n for integer q.
105+ # For n=1, defer to the specific a₁ derivative.
106+ if n == 1
107+ return _da1_dq (p, q, f)
108+ end
109+ # Use the simplified closed-form of a_n that eliminates explicit q^2 via f:
110+ # a_n = (x/(1-x))^2 * (n-1) * (p+n-1) * (p+q+n-2) * (q-n) / D(p,n)
111+ # where D(p,n) = (p+2n-3)*(p+2n-2)^2*(p+2n-1) and (x/(1-x)) = p*f/q.
112+ # Differentiate only the q-dependent factor G(q) = (p+q+n-2)*(q-n):
113+ # dG/dq = (q-n) + (p+q+n-2) = p + 2q - 2.
114+
115+ # This is equivalent to
116+ # return _anfun(p,q,f,n) * (inv(p+q+n-2) + inv(q-n))
117+ # but more precise.
118+
119+ pfq = (p * f) / q
120+ C = (pfq * pfq) * (n - 1 ) * (p + n - 1 ) /
121+ ((p + 2 * n - 3 ) * (p + 2 * n - 2 )^ 2 * (p + 2 * n - 1 ))
122+ return C * (p + 2 * q - 2 )
123+ end
124+
125+ function _dbn_dp (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
126+ # ∂b_n/∂p via quotient rule on b_n = N/D.
127+ # Note the internal dependence f(p,q)=q x/(p(1-x)) — terms cancel in N as per derivation.
128+ g = p * f + 2 * q
129+ A = 2 * n^ 2 + 2 * (p - 1 ) * n
130+ N1 = g * A
131+ N2 = p * q * (p - 2 - p * f)
132+ N = N1 + N2
133+ D = q * (p + 2 * n - 2 ) * (p + 2 * n)
134+ dN1_dp = 2 * n * g
135+ dN2_dp = q * (2 * p - 2 ) - p * q * f
136+ dN_dp = dN1_dp + dN2_dp
137+ dD_dp = q * (2 * p + 4 * n - 2 )
138+ return (dN_dp * D - N * dD_dp) / (D^ 2 )
139+ end
140+
141+ function _dbn_dq (p:: T , q:: T , f:: T , n:: Int ) where {T<: AbstractFloat }
142+ # ∂b_n/∂q similarly via quotient rule
143+ g = p * f + 2 * q
144+ A = 2 * n^ 2 + 2 * (p - 1 ) * n
145+ N1 = g * A
146+ N2 = p * q * (p - 2 - p * f)
147+ N = N1 + N2
148+ D = q * (p + 2 * n - 2 ) * (p + 2 * n)
149+ g_q = p * (f / q) + 2
150+ dN1_dq = g_q * A
151+ dN2_dq = p * (p - 2 - p * f) - p^ 2 * f
152+ dN_dq = dN1_dq + dN2_dq
153+ dD_dq = (p + 2 * n - 2 ) * (p + 2 * n)
154+ return (dN_dq * D - N * dD_dq) / (D^ 2 )
155+ end
156+
157+ function _nextapp (f:: T , p:: T , q:: T , n:: Int , App:: T , Ap:: T , Bpp:: T , Bp:: T ) where {T<: AbstractFloat }
158+ # One step of the continuant recurrences:
159+ # A_n = a_n A_{n-2} + b_n A_{n-1}
160+ # B_n = a_n B_{n-2} + b_n B_{n-1}
161+ an = _anfun (p, q, f, n)
162+ bn = _bnfun (p, q, f, n)
163+ An = an * App + bn * Ap
164+ Bn = an * Bpp + bn * Bp
165+ return An, Bn, an, bn
166+ end
167+
168+ function _dnextapp (an:: T , bn:: T , dan:: T , dbn:: T , Xpp:: T , Xp:: T , dXpp:: T , dXp:: T ) where {T<: AbstractFloat }
169+ # Derivative propagation for the same recurrences (X∈{A,B})
170+ return dan * Xpp + an * dXpp + dbn * Xp + bn * dXp
171+ end
172+
173+ function _beta_inc_grad (a, b, x; maxapp:: Int = 200 , minapp:: Int = 3 )
174+ T = promote_type (float (typeof (a)), float (typeof (b)), float (typeof (x)));
175+ err:: T = eps (T)* T (1e4 )
176+ a = T (a)
177+ b = T (b)
178+ x = T (x)
179+ # Compute I_x(a,b) and partial derivatives (∂I/∂a, ∂I/∂b, ∂I/∂x)
180+ # using a differentiated continued fraction with convergence control.
181+ oneT = one (T)
182+ zeroT = zero (T)
183+
184+ # 1) Boundary cases for x
185+ x == oneT && return oneT, zeroT, zeroT, zeroT
186+ x == zeroT && return zeroT, zeroT, zeroT, zeroT
187+
188+ # 2) Clamp iteration/tolerance parameters to robust defaults
189+ ϵ = min (err, T (1e-14 ))
190+ maxapp = max (1000 , maxapp)
191+ minapp = max (5 , minapp)
192+
193+ # 3) Non-boundary path: precompute ∂I/∂x at original (a,b,x) via stable log form
194+ dx = exp ((a - oneT) * log (x) + (b - oneT) * log1p (- x) - logbeta (a,b))
195+
196+ # 4) Optional tail-swap for symmetry and improved CF convergence:
197+ # if x > a/(a+b), evaluate at (p,q,x₀) = (b,a,1-x) and swap back at the end.
198+ p = a
199+ q = b
200+ x₀ = x
201+ swap = false
202+ if x > a / (a + b)
203+ x₀ = oneT - x
204+ p = b
205+ q = a
206+ swap = true
207+ end
208+
209+ # 5) Initialize CF state and derivatives
210+ K = _Kfun (x₀, p, q)
211+ dK_dp_val, dK_dq_val = _dK_dpdq (x₀, p, q)
212+ f = _ffun (x₀, p, q)
213+ App = oneT
214+ Ap = oneT
215+ Bpp = zeroT
216+ Bp = oneT
217+ dApp_dp = zeroT
218+ dBpp_dp = zeroT
219+ dAp_dp = zeroT
220+ dBp_dp = zeroT
221+ dApp_dq = zeroT
222+ dBpp_dq = zeroT
223+ dAp_dq = zeroT
224+ dBp_dq = zeroT
225+ dI_dp = T (NaN )
226+ dI_dq = T (NaN )
227+ Ixpq = T (NaN )
228+ Ixpqn = T (NaN )
229+ dI_dp_prev = T (NaN )
230+ dI_dq_prev = T (NaN )
231+
232+ # 6) Main CF loop (n from 1): update continuants, scale, form current approximant Cn=A_n/B_n
233+ # and its derivatives to update I and ∂I/∂(p,q). Stop on relative convergence of all.
234+ for n= 1 : maxapp
235+
236+ # Update continuants.
237+ An, Bn, an, bn = _nextapp (f, p, q, n, App, Ap, Bpp, Bp)
238+ dan = _dan_dp (p, q, f, n)
239+ dbn = _dbn_dp (p, q, f, n)
240+ dAn_dp = _dnextapp (an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp)
241+ dBn_dp = _dnextapp (an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp)
242+ dan = _dan_dq (p, q, f, n)
243+ dbn = _dbn_dq (p, q, f, n)
244+ dAn_dq = _dnextapp (an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq)
245+ dBn_dq = _dnextapp (an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq)
246+
247+ # Normalize states to control growth/underflow (scale-invariant transform)
248+ s = maximum ((abs (An), abs (Bn), abs (Ap), abs (Bp), abs (App), abs (Bpp)))
249+ if isfinite (s) && s > zeroT
250+ invs = inv (s)
251+ An *= invs
252+ Bn *= invs
253+ Ap *= invs
254+ Bp *= invs
255+ App *= invs
256+ Bpp *= invs
257+ dAn_dp *= invs
258+ dBn_dp *= invs
259+ dAn_dq *= invs
260+ dBn_dq *= invs
261+ dAp_dp *= invs
262+ dBp_dp *= invs
263+ dApp_dp *= invs
264+ dBpp_dp *= invs
265+ dAp_dq *= invs
266+ dBp_dq *= invs
267+ dApp_dq *= invs
268+ dBpp_dq *= invs
269+ end
270+
271+ # Form current approximant Cn=A_n/B_n and its derivatives.
272+ # Guard against tiny/zero Bn to avoid NaNs/Inf in divisions.
273+ tiny = sqrt (eps (T))
274+ absBn = abs (Bn)
275+ sgnBn = ifelse (Bn >= zeroT, oneT, - oneT)
276+ invBn = absBn > tiny && isfinite (absBn) ? inv (Bn) : inv (sgnBn * tiny)
277+ Cn = An * invBn
278+ invBn2 = invBn * invBn
279+ dI_dp = dK_dp_val * Cn + K * (invBn * dAn_dp - (An * invBn2) * dBn_dp)
280+ dI_dq = dK_dq_val * Cn + K * (invBn * dAn_dq - (An * invBn2) * dBn_dq)
281+ Ixpqn = K * Cn
282+
283+ # Decide convergence:
284+ if n >= minapp
285+ # Relative convergence for I, ∂I/∂p, ∂I/∂q (guards against tiny denominators)
286+ denomI = max (abs (Ixpqn), abs (Ixpq), eps (T))
287+ denomp = max (abs (dI_dp), abs (dI_dp_prev), eps (T))
288+ denomq = max (abs (dI_dq), abs (dI_dq_prev), eps (T))
289+ rI = abs (Ixpqn - Ixpq) / denomI
290+ rp = abs (dI_dp - dI_dp_prev) / denomp
291+ rq = abs (dI_dq - dI_dq_prev) / denomq
292+ if max (rI, rp, rq) < ϵ
293+ break
294+ end
295+ end
296+ Ixpq = Ixpqn
297+ dI_dp_prev = dI_dp
298+ dI_dq_prev = dI_dq
299+
300+ # Shift CF state for next iteration
301+ App = Ap
302+ Bpp = Bp
303+ Ap = An
304+ Bp = Bn
305+ dApp_dp = dAp_dp
306+ dApp_dq = dAp_dq
307+ dBpp_dp = dBp_dp
308+ dBpp_dq = dBp_dq
309+ dAp_dp = dAn_dp
310+ dAp_dq = dAn_dq
311+ dBp_dp = dBn_dp
312+ dBp_dq = dBn_dq
313+ end
314+
315+ # 7) Undo tail-swap if applied; ∂I/∂x is the pdf at original (a,b,x)
316+ if swap
317+ return oneT - Ixpqn, - dI_dq, - dI_dp, dx
318+ else
319+ return Ixpqn, dI_dp, dI_dq, dx
320+ end
321+ end
322+
323+ EnzymeRules. @easy_rule (
324+ SpecialFunctions. beta_inc (a, b, x),
325+ @setup (
326+ (_, dIa, dIb, dIx) = _beta_inc_grad (a, b, x)
327+ ),
328+ (dIa, dIb, dIx),
329+ (- dIa, - dIb, - dIx),
330+ )
331+
332+ Enzyme. EnzymeRules. @easy_rule (
333+ SpecialFunctions. beta_inc (a, b, x, y),
334+ @setup (
335+ (_, dIa, dIb, dIx) = _beta_inc_grad (a, b, x)
336+ ),
337+ (dIa, dIb, dIx, - dIx),
338+ (- dIa, - dIb, - dIx, dIx),
339+ )
340+
341+ Enzyme. EnzymeRules. @easy_rule (
342+ SpecialFunctions. beta_inc_inv (a, b, p),
343+ @setup (
344+
345+ (x, y) = Ω,
346+
347+ # Implicit differentiation at solved x: I_x(a,b) = p
348+ (_, dIa, dIb, _) = _beta_inc_grad (a, b, x),
349+
350+ # ∂I/∂x at solved x via stable log-space expression
351+ dIx_acc = exp (muladd (a - one (a), log (x), muladd (b - one (b), log1p (- x), - logbeta (a, b)))),
352+ inv_dIx = inv (dIx_acc),
353+ dx_da = - dIa * inv_dIx,
354+ dx_db = - dIb * inv_dIx,
355+ dx_dp = inv_dIx,
356+ ),
357+ (dx_da, dx_db, dx_dp),
358+ (- dx_da, - dx_db, - dx_dp)
359+ )
360+
14361end
0 commit comments