|
12 | 12 | # (hasattr on the backend's special module), so entries are removed as backends |
13 | 13 | # add the native version. (numpy always has the full scipy.special.) |
14 | 14 | _NATIVE_MISSING = { |
15 | | - "jax": frozenset(("ellipk", "ellipe")), |
| 15 | + "jax": frozenset(("ellipk", "ellipe", "k0", "k1", "kn")), |
| 16 | + # torch.special lacks all of these. (It does have modified_bessel_k0/k1, but |
| 17 | + # they are NOT differentiable -- no autograd backward -- and there is no kn, |
| 18 | + # so the k0/k1/kn fallbacks are used; the router sees no torch.special.k0.) |
16 | 19 | "torch": frozenset( |
17 | | - ("gamma", "ellipk", "ellipe", "hyp2f1", "hyp1f1") |
18 | | - ), # torch.special lacks all of these |
| 20 | + ("gamma", "ellipk", "ellipe", "hyp2f1", "hyp1f1", "k0", "k1", "kn") |
| 21 | + ), |
19 | 22 | } |
20 | 23 |
|
21 | 24 | # Functions whose native implementation EXISTS but is too inaccurate on galpy's |
@@ -152,6 +155,66 @@ def ellipe(m): |
152 | 155 | return _dispatch("ellipe", (m,), ellipe_fallback) |
153 | 156 |
|
154 | 157 |
|
| 158 | +# --- Tier 3: modified Bessel functions of the second kind (disk force paths) -- |
| 159 | +def k0(x): |
| 160 | + from ._fallback.bessel_k import k0_fallback |
| 161 | + |
| 162 | + return _dispatch("k0", (x,), k0_fallback) |
| 163 | + |
| 164 | + |
| 165 | +def k1(x): |
| 166 | + from ._fallback.bessel_k import k1_fallback |
| 167 | + |
| 168 | + return _dispatch("k1", (x,), k1_fallback) |
| 169 | + |
| 170 | + |
| 171 | +def kn(n, x): |
| 172 | + # Integer-order modified Bessel K_n; only the array arg x carries the namespace. |
| 173 | + from ._fallback.bessel_k import kn_fallback |
| 174 | + |
| 175 | + return _dispatch("kn", (n, x), kn_fallback, ns_args=(x,)) |
| 176 | + |
| 177 | + |
| 178 | +# --- Tier 4: associated Legendre P_l^m (SCF / MultipoleExpansion) ------------- |
| 179 | +def _scipy_assoc_legendre(L, M, x, deriv): |
| 180 | + """numpy path: scipy.special.assoc_legendre_p_all reshaped to (...,L,M), |
| 181 | + byte-identical to scipy (the convention used by util.special.compute_legendre).""" |
| 182 | + import scipy.special as sp |
| 183 | + |
| 184 | + arr = numpy.asarray( |
| 185 | + sp.assoc_legendre_p_all( |
| 186 | + L - 1, M - 1, numpy.asarray(x, dtype=float), branch_cut=2, diff_n=deriv |
| 187 | + ) |
| 188 | + ) # (deriv+1, L, 2M-1, *x.shape) -- m=0..M-1 are the first M columns |
| 189 | + out = numpy.moveaxis(arr[:, :, :M], (1, 2), (-2, -1)) # (deriv+1, *x.shape, L, M) |
| 190 | + return out[0] if deriv == 0 else tuple(out[i] for i in range(deriv + 1)) |
| 191 | + |
| 192 | + |
| 193 | +def assoc_legendre(L, M, x, deriv=0): |
| 194 | + """P_l^m(x) for 0<=l<L, 0<=m<M (Condon-Shortley phase), shape x.shape+(L,M). |
| 195 | +
|
| 196 | + deriv: 0 -> P; 1 -> (P, dP/dx); 2 -> (P, dP/dx, d2P/dx2). numpy routes to |
| 197 | + scipy (byte-identical); jax/torch use the pure-backend Bonnet recurrence. |
| 198 | + """ |
| 199 | + name, _ = _backend_special(get_namespace(x)) |
| 200 | + if name == "numpy": |
| 201 | + return _scipy_assoc_legendre(L, M, x, deriv) |
| 202 | + from ._fallback.assoc_legendre import assoc_legendre as _fb |
| 203 | + |
| 204 | + return _fb(get_namespace(x), L, M, x, deriv) |
| 205 | + |
| 206 | + |
| 207 | +def gegenbauer(N, alpha, x): |
| 208 | + """Gegenbauer polynomials C_n^alpha(x) for 0<=n<N, shape x.shape+(N,). |
| 209 | +
|
| 210 | + N static int, alpha scalar, x a backend array. Uses the three-term |
| 211 | + recurrence on every backend (galpy's SCF radial basis never used a scipy |
| 212 | + Gegenbauer, so there is no native to prefer).""" |
| 213 | + from ._fallback.gegenbauer import gegenbauer as _fb |
| 214 | + |
| 215 | + return _fb(get_namespace(x), N, alpha, x) |
| 216 | + |
| 217 | + |
155 | 218 | def xlogy(x, y): |
156 | 219 | # x * log(y), with the scipy/native convention 0 * log(0) = 0. |
157 | 220 | from ._fallback.xlogy import xlogy_fallback |
|
0 commit comments