Skip to content

Commit a0774d9

Browse files
jobovy and claude authored
Pspecial Tiers 3-4: Bessel K + associated Legendre + Gegenbauer (#917)
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent b87bbce commit a0774d9

6 files changed

Lines changed: 443 additions & 4 deletions

File tree

galpy/backend/special/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# capability test asserts native-vs-fallback agreement so it can be deleted).
1414
###############################################################################
1515
from ._router import (
16+
assoc_legendre,
1617
ellipe,
1718
ellipk,
1819
erf,
@@ -21,10 +22,14 @@
2122
gammainc,
2223
gammaincc,
2324
gammaln,
25+
gegenbauer,
2426
hyp1f1,
2527
hyp2f1,
2628
i0,
2729
i1,
30+
k0,
31+
k1,
32+
kn,
2833
xlogy,
2934
)
3035

@@ -42,4 +47,9 @@
4247
"hyp1f1",
4348
"ellipk",
4449
"ellipe",
50+
"k0",
51+
"k1",
52+
"kn",
53+
"assoc_legendre",
54+
"gegenbauer",
4555
]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
###############################################################################
2+
# Backend-agnostic associated Legendre functions P_l^m(x) for all degrees
3+
# l < L and orders 0 <= m < M, with the Condon-Shortley phase (matching
4+
# scipy.special.assoc_legendre_p_all(..., branch_cut=2)). Replaces
5+
# galpy.util.special.compute_legendre on the SCF / MultipoleExpansion path so
6+
# those potentials run and differentiate under every backend.
7+
#
8+
# P is built by the standard forward (Bonnet) recurrences:
9+
# P_m^m = (-1)^m (2m-1)!! (1-x^2)^{m/2}
10+
# P_{m+1}^m = x (2m+1) P_m^m
11+
# (l-m) P_l^m = x (2l-1) P_{l-1}^m - (l+m-1) P_{l-2}^m
12+
# The optional first/second x-derivatives use
13+
# (x^2-1) dP/dx = l x P_l^m - (l+m) P_{l-1}^m
14+
# (1-x^2) d2P/dx^2 = 2x dP/dx - l(l+1) P + m^2/(1-x^2) P (Legendre ODE)
15+
# (these diverge at the poles x=+-1 for m>=1, exactly as scipy returns, and are
16+
# multiplied by sin^2(theta) in the physical theta-derivatives downstream).
17+
#
18+
# Everything is pure arithmetic built with lists + xp.stack (no in-place
19+
# mutation), so it differentiates cleanly under jax and torch -- and the
20+
# x-derivatives are also available straight from autodiff.
21+
###############################################################################
22+
23+
24+
def assoc_legendre(xp, L, M, x, deriv=0):
    """Associated Legendre functions P_l^m(x) with the Condon-Shortley phase.

    Parameters
    ----------
    xp : array namespace providing ``asarray``, ``ones_like``, ``zeros_like``,
        ``where``, ``sqrt`` and ``stack``.
    L, M : static ints; degrees ``0 <= l < L`` and orders ``0 <= m < M`` are
        produced (entries with ``m > l`` are zero).
    x : backend array (or scalar) with ``|x| <= 1``.
    deriv : 0 -> P; 1 -> (P, dP/dx); 2 -> (P, dP/dx, d2P/dx2).

    Returns
    -------
    One array, or a tuple of ``deriv + 1`` arrays, each of shape
    ``x.shape + (L, M)``.

    Notes
    -----
    The derivative formulas divide by ``(x^2 - 1)``.  For the m = 0 column the
    true derivatives are finite at the poles, so the exact limits
    ``P_l'(+-1) = (+-1)^(l+1) l(l+1)/2`` and
    ``P_l''(+-1) = (+-1)^l (l-1)l(l+1)(l+2)/8`` are substituted there (through
    ``xp.where`` with a safe denominator) instead of returning 0/0 = NaN.
    Columns with m >= 1 are left exactly as the closed-form expressions give
    them, i.e. non-finite at ``x = +-1``.
    """
    x = xp.asarray(x) * 1.0
    one = xp.ones_like(x)
    zero = xp.zeros_like(x)
    # (1-x^2)^{1/2}; the mask keeps it real at |x| = 1 (interior x unaffected).
    somx2 = xp.sqrt(xp.where(x * x < 1.0, 1.0 - x * x, zero))

    # P[l][m] as a list-of-lists of backend arrays (functional, no mutation).
    P = [[zero for _ in range(M)] for _ in range(L)]
    pmm = one  # running diagonal P_m^m = (-1)^m (2m-1)!! (1-x^2)^{m/2}
    for m in range(M):
        if m > 0:
            pmm = pmm * (-(2 * m - 1)) * somx2
        if m < L:
            P[m][m] = pmm
        if m + 1 < L:
            P[m + 1][m] = x * (2 * m + 1) * pmm
        for l in range(m + 2, L):
            # (l-m) P_l^m = x (2l-1) P_{l-1}^m - (l+m-1) P_{l-2}^m
            P[l][m] = (
                x * (2 * l - 1) * P[l - 1][m] - (l + m - 1) * P[l - 2][m]
            ) / (l - m)

    def _stack(grid):
        return xp.stack([xp.stack(row, axis=-1) for row in grid], axis=-2)

    Parr = _stack(P)
    if deriv == 0:
        return Parr

    at_pole = x * x == 1.0
    sign = xp.where(x < 0.0, -one, one)  # +-1, used for the parity factors
    den = x * x - 1.0  # (x^2-1); zero only at the poles
    den_safe = xp.where(at_pole, -one, den)  # finite everywhere (m = 0 path)

    dP = [[zero for _ in range(M)] for _ in range(L)]
    for m in range(M):
        for l in range(m, L):
            plm1 = P[l - 1][m] if l - 1 >= m else zero
            # (x^2-1) dP/dx = l x P_l^m - (l+m) P_{l-1}^m
            num = l * x * P[l][m] - (l + m) * plm1
            if m == 0:
                parity = sign if l % 2 == 0 else one  # (+-1)^(l+1)
                pole = (0.5 * l * (l + 1)) * parity
                dP[l][m] = xp.where(at_pole, pole, num / den_safe)
            else:
                dP[l][m] = num / den
    dParr = _stack(dP)
    if deriv == 1:
        return Parr, dParr

    om = 1.0 - x * x
    om_safe = xp.where(at_pole, one, om)  # finite everywhere (m = 0 path)
    d2 = [[zero for _ in range(M)] for _ in range(L)]
    for m in range(M):
        for l in range(m, L):
            if m == 0:
                parity = one if l % 2 == 0 else sign  # (+-1)^l
                pole = ((l - 1) * l * (l + 1) * (l + 2) / 8.0) * parity
                # Legendre ODE with m = 0: (1-x^2) P'' = 2x P' - l(l+1) P
                interior = (
                    2.0 * x * dP[l][m] - l * (l + 1) * P[l][m]
                ) / om_safe
                d2[l][m] = xp.where(at_pole, pole, interior)
            else:
                # (1-x^2) P'' = 2x P' - l(l+1) P + m^2/(1-x^2) P
                d2[l][m] = (
                    2.0 * x * dP[l][m]
                    - l * (l + 1) * P[l][m]
                    + (m * m) / om * P[l][m]
                ) / om
    return Parr, dParr, _stack(d2)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
###############################################################################
2+
# Fallbacks for the modified Bessel functions of the second kind K0, K1, Kn
3+
# on real x > 0. Needed on BOTH jax and torch:
4+
# - jax.scipy.special has no k0/k1/kn at all;
5+
# - torch.special has modified_bessel_k0/k1 but they are NOT differentiable
6+
# (no autograd backward) and lack kn entirely, so we use the fallback there
7+
# too (the router sees no torch.special.k0 attribute -> treats it missing).
8+
#
9+
# K0, K1 (``_k01``) use two regimes, each ~1e-15 vs scipy and AD-friendly:
10+
# - x <= 2: the Abramowitz & Stegun ascending series (9.6.13/9.6.11), built
11+
# on the native i0/i1 (Tier 1) plus elementary terms;
12+
# - x > 2: the trapezoidal rule on K_nu(x) = int_0^inf e^{-x cosh t}
13+
# cosh(nu t) dt. The integrand is double-exponentially decaying, so the
14+
# trapezoidal rule converges geometrically; its e^{-x(cosh t-1)} peak has
15+
# width ~1/sqrt(x), so the nodes are scaled by 1/sqrt(x) to resolve it
16+
# uniformly for all large x.
17+
# Each branch's argument is clamped into its valid region wherever the OTHER
18+
# branch is selected, so the unused branch cannot overflow (i0 at large x) or
19+
# NaN-poison reverse-mode gradients.
20+
#
21+
# Kn (``kn_fallback``) uses the upward recurrence K_{m+1}=K_{m-1}+(2m/x)K_m
22+
# from K0, K1 -- the stable direction for K.
23+
###############################################################################
24+
import numpy
25+
26+
_GAMMA = 0.5772156649015328606  # Euler-Mascheroni constant
_NSERIES = 30  # number of ascending-series terms (x <= 2 branch)
_TRAP_H = 0.25  # trapezoidal step, in the 1/sqrt(x)-scaled variable
_TRAP_N = 64  # number of trapezoidal intervals (N + 1 nodes)
# Node positions i*h for i = 0..N, kept as a numpy constant.
_TRAP_NODES = _TRAP_H * numpy.arange(_TRAP_N + 1)
# Trapezoid weights: h/2 at the i = 0 endpoint, h at every other node.
_TRAP_W = numpy.concatenate(([_TRAP_H / 2.0], numpy.full(_TRAP_N, _TRAP_H)))
34+
35+
36+
def _k01(xp, x):
    """Return the pair (K0(x), K1(x)) for real x > 0.

    Two branches are evaluated for every element and merged with ``xp.where``:
    an ascending series for ``x <= 2`` and a scaled trapezoidal rule for
    ``x > 2``.  Each branch is fed a clamped argument (1 for the series, 2 for
    the trapezoid) wherever the other branch is the one selected, so the
    discarded branch is always evaluated at a point inside its own region.
    """
    x = xp.asarray(x) * 1.0
    small = x <= 2.0
    x_series = xp.where(small, x, xp.ones_like(x))  # clamped to 1 where x > 2
    x_trap = xp.where(small, 2.0 * xp.ones_like(x), x)  # clamped to 2 where x <= 2

    # --- ascending series (x <= 2), built on the routed i0 / i1 ---
    from .._router import i0, i1

    quarter_x2 = x_series * x_series / 4.0
    log_half = xp.log(x_series / 2.0)

    # K0 = -(ln(x/2) + gamma) I0(x) + sum_{k>=1} (x^2/4)^k / (k!)^2 * H_k
    k0_series = -(log_half + _GAMMA) * i0(x_series)
    power = xp.ones_like(x_series)
    harmonic = 0.0
    for k in range(1, _NSERIES):
        harmonic += 1.0 / k
        power = power * quarter_x2 / (k * k)
        k0_series = k0_series + power * harmonic

    # K1 = 1/x + ln(x/2) I1(x)
    #      - (x/2) sum_{k>=0} ((H_k + H_{k+1})/2 - gamma) (x^2/4)^k / (k! (k+1)!)
    acc = xp.zeros_like(x_series)
    power = xp.ones_like(x_series)
    h_prev = 0.0
    for k in range(_NSERIES):
        h_next = h_prev + 1.0 / (k + 1)
        acc = acc + power * ((h_prev + h_next) / 2.0 - _GAMMA)
        power = power * quarter_x2 / ((k + 1) * (k + 2))
        h_prev = h_next
    k1_series = 1.0 / x_series + log_half * i1(x_series) - (x_series / 2.0) * acc

    # --- trapezoidal rule on int_0^inf exp(-x cosh t) cosh(nu t) dt (x > 2),
    #     with nodes t_i = (i h) / sqrt(x) ---
    nodes = xp.asarray(_TRAP_NODES)
    weights = xp.asarray(_TRAP_W)
    scale = 1.0 / xp.sqrt(x_trap)
    cosh_t = xp.cosh(scale[..., None] * nodes)  # (..., N+1)
    weighted = xp.exp(-x_trap[..., None] * cosh_t) * weights
    k0_trap = xp.sum(weighted, axis=-1) * scale
    k1_trap = xp.sum(weighted * cosh_t, axis=-1) * scale

    return (
        xp.where(small, k0_series, k0_trap),
        xp.where(small, k1_series, k1_trap),
    )
76+
77+
78+
def k0_fallback(xp, x):
    """Modified Bessel function of the second kind of order 0, K0(x).

    Takes the first element of the ``(K0, K1)`` pair computed by ``_k01``.
    """
    k0_value, _ = _k01(xp, x)
    return k0_value
81+
82+
83+
def k1_fallback(xp, x):
    """Modified Bessel function of the second kind of order 1, K1(x).

    Takes the second element of the ``(K0, K1)`` pair computed by ``_k01``.
    """
    _, k1_value = _k01(xp, x)
    return k1_value
86+
87+
88+
def kn_fallback(xp, n, x):
    """Integer-order K_n(x) via the upward recurrence from K0 and K1.

    Parameters
    ----------
    xp : array namespace used for ``asarray``.
    n : integer order (coerced with ``int``).  Negative orders are accepted:
        K_{-n}(x) = K_n(x), so only ``abs(n)`` matters.
    x : backend array (or scalar), real and > 0.

    Returns
    -------
    K_n(x) with the shape of ``x``.
    """
    # K is even in its order; without abs() a negative n would skip the loop
    # below and silently return K1.
    order = abs(int(n))
    k_prev, k_curr = _k01(xp, x)  # K0, K1
    if order == 0:
        return k_prev
    if order == 1:
        return k_curr
    x = xp.asarray(x) * 1.0
    for m in range(1, order):
        # K_{m+1} = K_{m-1} + (2m/x) K_m
        k_prev, k_curr = k_curr, k_prev + (2.0 * m / x) * k_curr
    return k_curr
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
###############################################################################
2+
# Backend-agnostic Gegenbauer (ultraspherical) polynomials C_n^alpha(x) for
3+
# 0 <= n < N, via the standard three-term recurrence
4+
# C_0 = 1, C_1 = 2 alpha x,
5+
# (n+1) C_{n+1} = 2(n+alpha) x C_n - (n+2 alpha-1) C_{n-1}.
6+
# This is the SCFPotential radial basis (galpy.potential.SCFPotential._C uses
7+
# the same recurrence with alpha = 2l + 3/2). Built with lists + xp.stack (no
8+
# in-place mutation), so it differentiates under jax and torch; the numpy path
9+
# reproduces SCF's existing recurrence value-for-value.
10+
###############################################################################
11+
12+
13+
def gegenbauer(xp, N, alpha, x):
    """Gegenbauer polynomials C_n^alpha(x) for 0 <= n < N, stacked on a new last axis.

    ``N`` is a static int, ``alpha`` a scalar and ``x`` a backend array (or
    scalar).  The result has shape ``x.shape + (N,)`` for N >= 1; C_0 is
    always emitted.  Terms come from the three-term recurrence
    ``(n+1) C_{n+1} = 2(n+alpha) x C_n - (n+2 alpha-1) C_{n-1}`` and are
    collected in a list (no in-place array mutation) before ``xp.stack``.
    """
    x = xp.asarray(x) * 1.0
    prev = xp.ones_like(x)  # C_0 = 1
    series = [prev]
    if N > 1:
        curr = 2.0 * alpha * x  # C_1 = 2 alpha x
        series.append(curr)
        n = 1
        while n < N - 1:
            nxt = (
                2.0 * (n + alpha) * x * curr - (n + 2.0 * alpha - 1.0) * prev
            ) / (n + 1.0)
            series.append(nxt)
            prev, curr = curr, nxt
            n += 1
    return xp.stack(series, axis=-1)

galpy/backend/special/_router.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# (hasattr on the backend's special module), so entries are removed as backends
1313
# add the native version. (numpy always has the full scipy.special.)
1414
_NATIVE_MISSING = {
    "jax": frozenset({"ellipk", "ellipe", "k0", "k1", "kn"}),
    # torch.special lacks all of these. (It does have modified_bessel_k0/k1,
    # but they are NOT differentiable -- no autograd backward -- and there is
    # no kn, so the k0/k1/kn fallbacks are used; the router sees no
    # torch.special.k0.)
    "torch": frozenset(
        {"gamma", "ellipk", "ellipe", "hyp2f1", "hyp1f1", "k0", "k1", "kn"}
    ),
}
2023

2124
# Functions whose native implementation EXISTS but is too inaccurate on galpy's
@@ -152,6 +155,66 @@ def ellipe(m):
152155
return _dispatch("ellipe", (m,), ellipe_fallback)
153156

154157

158+
# --- Tier 3: modified Bessel functions of the second kind (disk force paths) --
159+
def k0(x):
    """Modified Bessel function of the second kind, order 0.

    Hands the call to ``_dispatch`` under the name ``"k0"`` together with the
    pure-backend fallback (imported lazily, at call time).
    """
    from ._fallback.bessel_k import k0_fallback as _fallback

    call_args = (x,)
    return _dispatch("k0", call_args, _fallback)
163+
164+
165+
def k1(x):
    """Modified Bessel function of the second kind, order 1.

    Hands the call to ``_dispatch`` under the name ``"k1"`` together with the
    pure-backend fallback (imported lazily, at call time).
    """
    from ._fallback.bessel_k import k1_fallback as _fallback

    call_args = (x,)
    return _dispatch("k1", call_args, _fallback)
169+
170+
171+
def kn(n, x):
    """Integer-order modified Bessel function of the second kind, K_n(x).

    Both ``n`` and ``x`` are forwarded as call arguments, but only the array
    argument ``x`` is passed as ``ns_args`` (the order ``n`` carries no array
    namespace).
    """
    from ._fallback.bessel_k import kn_fallback as _fallback

    call_args = (n, x)
    namespace_args = (x,)
    return _dispatch("kn", call_args, _fallback, ns_args=namespace_args)
176+
177+
178+
# --- Tier 4: associated Legendre P_l^m (SCF / MultipoleExpansion) -------------
179+
def _scipy_assoc_legendre(L, M, x, deriv):
180+
"""numpy path: scipy.special.assoc_legendre_p_all reshaped to (...,L,M),
181+
byte-identical to scipy (the convention used by util.special.compute_legendre)."""
182+
import scipy.special as sp
183+
184+
arr = numpy.asarray(
185+
sp.assoc_legendre_p_all(
186+
L - 1, M - 1, numpy.asarray(x, dtype=float), branch_cut=2, diff_n=deriv
187+
)
188+
) # (deriv+1, L, 2M-1, *x.shape) -- m=0..M-1 are the first M columns
189+
out = numpy.moveaxis(arr[:, :, :M], (1, 2), (-2, -1)) # (deriv+1, *x.shape, L, M)
190+
return out[0] if deriv == 0 else tuple(out[i] for i in range(deriv + 1))
191+
192+
193+
def assoc_legendre(L, M, x, deriv=0):
    """P_l^m(x) for 0<=l<L, 0<=m<M (Condon-Shortley phase), shape x.shape+(L,M).

    deriv: 0 -> P; 1 -> (P, dP/dx); 2 -> (P, dP/dx, d2P/dx2).

    When the namespace of ``x`` resolves to the ``"numpy"`` backend the call
    goes to ``_scipy_assoc_legendre``; any other backend uses the recurrence
    implementation from ``_fallback.assoc_legendre``, imported lazily.
    """
    xp = get_namespace(x)
    backend_name, _ = _backend_special(xp)
    if backend_name != "numpy":
        from ._fallback.assoc_legendre import assoc_legendre as _fb

        return _fb(xp, L, M, x, deriv)
    return _scipy_assoc_legendre(L, M, x, deriv)
205+
206+
207+
def gegenbauer(N, alpha, x):
    """Gegenbauer polynomials C_n^alpha(x) for 0<=n<N, shape x.shape+(N,).

    N is a static int, alpha a scalar, x a backend array.  There is no
    backend-specific branch: every backend goes through the three-term
    recurrence in ``_fallback.gegenbauer``, evaluated in the namespace of
    ``x``.
    """
    from ._fallback.gegenbauer import gegenbauer as _fb

    xp = get_namespace(x)
    return _fb(xp, N, alpha, x)
216+
217+
155218
def xlogy(x, y):
156219
# x * log(y), with the scipy/native convention 0 * log(0) = 0.
157220
from ._fallback.xlogy import xlogy_fallback

0 commit comments

Comments
 (0)