Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
a899775
Multi-backend (JAX / PyTorch / array-API) support for galpy
jobovy Jun 5, 2026
95f1985
Differentiable potential parameters: pass backend arrays through the …
jobovy Jun 6, 2026
4f59dac
P2.1 Spherical potentials: backend-agnostic namespace-swap (#893)
jobovy Jun 6, 2026
c47f6bc
P2.2 Analytic disk potentials: backend-agnostic namespace-swap (#894)
jobovy Jun 6, 2026
f84aac2
Strengthen spherical+disk backend gradient tests with analytic identi…
jobovy Jun 7, 2026
72e2c98
P2.3 Ellipsoidal/triaxial potentials: backend-agnostic namespace-swap…
jobovy Jun 7, 2026
1307ed8
Audit/fix spherical xp.where dead-branch regimes (R=0/seams) (#914)
jobovy Jun 8, 2026
0e80463
P2.4 Halo/bar/non-axisymmetric potentials: backend-agnostic namespace…
jobovy Jun 8, 2026
67e5f6b
Pspecial: native-preferring special-function router (Tiers 1-2) (#916)
jobovy Jun 8, 2026
e139124
In-backend differentiable orbit integration (diffrax / torchdiffeq) (…
jobovy Jun 8, 2026
2e16b8a
Pspecial Tiers 3-4: Bessel K + associated Legendre + Gegenbauer (#917)
jobovy Jun 9, 2026
e4d9dff
PowerSphericalPotentialwCutoff: backend-agnostic / jit-clean (+ torch…
jobovy Jun 9, 2026
c925cdd
P2.5: EinastoPotential backend-agnostic (numpy/jax/torch) (#927)
jobovy Jun 10, 2026
d59a0d7
P2.5: DoubleExponentialDiskPotential backend-agnostic (Ogata quadratu…
jobovy Jun 10, 2026
b5ad070
P2.5: KingPotential / interpSphericalPotential backend-agnostic splin…
jobovy Jun 10, 2026
16ed93e
P2.5: MultipoleExpansionPotential backend-agnostic (static path) (#931)
jobovy Jun 10, 2026
c5315f7
P2.5: SCFPotential backend-agnostic / jit-clean (jax/torch) (#930)
jobovy Jun 10, 2026
d3cccc0
P2.5: DiskSCF/KuijkenDubinski/DiskMultipole backend-agnostic correcti…
jobovy Jun 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ jobs:
REQUIRES_ASTROQUERY: false
REQUIRES_NUMBA: false
REQUIRES_JAX: true
- os: ubuntu-latest
python-version: "3.14"
TEST_FILES: tests/test_backend*.py
REQUIRES_PYNBODY: false
REQUIRES_ASTROPY: false
REQUIRES_ASTROQUERY: false
REQUIRES_NUMBA: false
REQUIRES_JAX: true
REQUIRES_TORCH: true
- os: ubuntu-latest
python-version: "3.14"
TEST_FILES: tests/test_actionAngleTorus.py tests/test_conversion.py tests/test_galpypaper.py tests/test_import.py tests/test_interp_potential.py tests/test_kuzminkutuzov.py tests/test_util.py
Expand Down Expand Up @@ -285,7 +294,14 @@ jobs:
pip install --upgrade --force-reinstall setuptools
- name: Install JAX
if: ${{ matrix.REQUIRES_JAX }}
run: pip install jax jaxlib
# diffrax: backend ODE integrator (jax) used by the test_backend_inbackend_ode
# tests, so galpy.backend._reference/_jax is exercised and covered.
run: pip install jax jaxlib array-api-compat diffrax
- name: Install PyTorch (CPU)
if: ${{ matrix.REQUIRES_TORCH }}
run: |
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install torchdiffeq # backend ODE integrator (torch); covers galpy.backend._torch
- name: Install torus code
env:
TEST_FILES: ${{ matrix.TEST_FILES }}
Expand Down
3 changes: 3 additions & 0 deletions doc/source/examples/galpyrc
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ astropy-coords = True

[warnings]
verbose= False

[backend]
default = numpy
5 changes: 5 additions & 0 deletions doc/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,8 @@ currently used:

* To set options related to whether or not to check for new versions of galpy (``do-check= False`` turns all such checks off; ``check-non-interactive`` sets whether or not to do the version check in non-interactive (script) sessions; ``check-non-interactive`` sets the cadence of how often to check for version updates in non-interactive sessions [in days; interactive sessions always check]; ``last-non-interactive-check`` is an internal variable to store when the last check occurred)

* To set the default array backend (``default = numpy``, the standard numpy/scipy code path). ``galpy`` resolves the backend by following the data — passing JAX or PyTorch arrays makes ``galpy`` compute (and differentiate) with that backend automatically — so this default only applies when there is no array to dispatch on. It can be set to ``jax`` or ``torch`` to make those the default (this requires the corresponding optional dependency, installable with ``pip install galpy[jax]`` or ``pip install galpy[torch]``); at runtime it can be overridden with the ``galpy.backend.use(...)`` context manager. With ``default = numpy`` (the default) there is no change to the numpy behavior and no extra dependency is needed.

The current configuration file therefore looks like this::

[normalization]
Expand All @@ -604,6 +606,9 @@ The current configuration file therefore looks like this::
check-non-interactive-every = 1
last-non-interactive-check = 2000-01-01

[backend]
default = numpy

where ``ro`` is the distance scale specified in kpc, ``vo`` the
velocity scale in km/s, and the setting is to *not* return output as a
Quantity. These are the current default settings.
Expand Down
28 changes: 28 additions & 0 deletions galpy/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
###############################################################################
# galpy.backend: multi-backend (numpy / jax / torch / array-API) dispatch.
#
# The whole of galpy's pure-Python compute layer resolves its array namespace
# through ``get_namespace`` so that the same code runs and differentiates under
# numpy, JAX, and PyTorch. Backend selection follows the data first (the type
# of the array arguments), with an explicit ``xp=`` override and a
# context-manager/global default as fallbacks. See ``_resolver`` for details.
###############################################################################
from ._namespaces import is_backend_array
from ._resolver import (
_seed_from_config,
backend,
get_namespace,
set_default_backend,
use,
)

# Seed the default backend from the [backend] section of the config file.
_seed_from_config()

__all__ = [
"get_namespace",
"backend",
"use",
"set_default_backend",
"is_backend_array",
]
7 changes: 7 additions & 0 deletions galpy/backend/_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
###############################################################################
# galpy.backend._jax: jax-specific backend implementations.
#
# Code here imports jax / diffrax directly (it is genuinely jax-specific, as
# opposed to the data-first xp-dispatch compute layer). The torch counterparts
# live in galpy.backend._torch.
###############################################################################
33 changes: 33 additions & 0 deletions galpy/backend/_jax/orbit_ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
###############################################################################
# galpy.backend._jax.orbit_ode: jax (diffrax) in-backend orbit integration.
#
# The jax-specific half of galpy.backend._reference.integrate_orbit. Integrates
# the shared backend-agnostic EOM (_eom_rhs) with diffrax. The torch counterpart
# is galpy.backend._torch.orbit_ode.
###############################################################################


def integrate(pot, y0, ts, *, rtol, atol, max_steps):
"""Integrate the EOM with diffrax (Dopri8, adaptive). y0/ys in rectangular
EOM variables [x, vx, y, vy, z, vz]. Reverse-mode differentiable (diffrax uses
a custom_vjp -> forward-mode jacfwd is unavailable; use jacrev)."""
import diffrax
import jax.numpy as jnp

from .._reference.inbackend_ode import _eom_rhs

def field(t, y, args):
return jnp.stack(_eom_rhs(y, pot, t, jnp))

sol = diffrax.diffeqsolve(
diffrax.ODETerm(field),
diffrax.Dopri8(),
t0=ts[0],
t1=ts[-1],
dt0=None,
y0=y0,
saveat=diffrax.SaveAt(ts=ts),
stepsize_controller=diffrax.PIDController(rtol=rtol, atol=atol),
max_steps=max_steps,
)
return sol.ys
101 changes: 101 additions & 0 deletions galpy/backend/_namespaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
###############################################################################
# galpy.backend._namespaces: helpers mapping backend names to array
# namespaces and small namespace-agnostic utilities.
###############################################################################
import numpy

from ..util._optional_deps import (
_ARRAY_API_COMPAT_LOADED,
_JAX_LOADED,
_TORCH_LOADED,
)

# Canonical backend names accepted throughout galpy.backend
_NUMPY_NAMES = frozenset(("numpy", "np"))
_JAX_NAMES = frozenset(("jax", "jnp", "jax.numpy"))
_TORCH_NAMES = frozenset(("torch", "pytorch"))


def _is_python_scalar(x):
"""True for plain Python scalars (and None), which carry no backend info."""
return x is None or isinstance(x, (bool, int, float, complex))


def is_backend_array(x):
"""True if ``x`` is a non-numpy backend array (a jax or torch array/tensor).

Plain Python scalars, ``None``, numpy arrays/scalars, astropy Quantities, and
anything the backend layer does not recognise return ``False`` -- so the
numpy/Quantity code paths stay byte-identical and only genuine backend arrays
(including traced ones, so autodiff w.r.t. parameters works) take any
pass-through branch keyed on this. Detection is by direct ``isinstance``
against the public ``jax.Array`` / ``torch.Tensor`` base classes, gated on the
optional-dependency flags so a numpy-only install never imports jax/torch.
"""
if _is_python_scalar(x) or isinstance(x, (numpy.ndarray, numpy.generic)):
return False
if _JAX_LOADED:
import jax

if isinstance(x, jax.Array):
return True
if _TORCH_LOADED:
import torch

if isinstance(x, torch.Tensor):
return True
return False


def namespace_for_name(name):
"""Map a backend name ('numpy'|'jax'|'torch') to its array namespace module.

numpy resolves to the *plain* numpy module (so the numpy code path is
byte-identical to today); jax/torch resolve to their array-API namespaces.
"""
if not isinstance(name, str):
# Already a namespace module; pass through.
return name
lname = name.lower()
if lname in _NUMPY_NAMES:
return numpy
if lname in _JAX_NAMES:
if not _JAX_LOADED: # pragma: no cover - defensive: needs jax absent
raise ImportError("galpy backend 'jax' requested but jax is not installed")
import jax.numpy as jnp

return jnp
if lname in _TORCH_NAMES:
if not _TORCH_LOADED: # pragma: no cover - defensive: needs torch absent
raise ImportError(
"galpy backend 'torch' requested but torch is not installed"
)
import array_api_compat.torch as txp

return txp
raise ValueError(f"unknown galpy backend '{name}'")


def namespace_from_arrays(arrays):
"""Infer the array namespace from the (non-scalar) array arguments.

Returns the plain numpy module when every array-like argument is a numpy
array (byte-identical numpy path), the appropriate jax/torch namespace when
a tracked array is present, or None when there is nothing array-like to
dispatch on (so the caller can fall through to the context/global default).
"""
arrs = [a for a in arrays if not _is_python_scalar(a)]
if not arrs:
return None
if all(isinstance(a, (numpy.ndarray, numpy.generic)) for a in arrs):
return numpy
if not _ARRAY_API_COMPAT_LOADED: # pragma: no cover - backend extra installs it
raise ImportError(
"galpy's non-numpy backends require array-api-compat "
"(pip install array-api-compat, or galpy[jax]/galpy[torch])"
)
import array_api_compat

# Non-numpy arrays only reach here (numpy is handled by the fast path above),
# so this returns the jax / array-api-compat-torch namespace.
return array_api_compat.array_namespace(*arrs)
11 changes: 11 additions & 0 deletions galpy/backend/_reference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
###############################################################################
# galpy.backend._reference: differentiable reference implementations.
#
# In-backend ODE orbit integration (diffrax / torchdiffeq) of galpy's
# backend-agnostic forces -- the fully-differentiable orbit path for jax/torch
# and the independent correctness reference for the fast C state-transition-
# matrix path.
###############################################################################
from .inbackend_ode import integrate_orbit

__all__ = ["integrate_orbit"]
127 changes: 127 additions & 0 deletions galpy/backend/_reference/inbackend_ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
###############################################################################
# galpy.backend._reference.inbackend_ode
#
# Differentiable orbit integration *inside* the backend: galpy's equations of
# motion are evaluated through the backend-agnostic force layer
# (_evaluateRforces / _evaluatephitorques / _evaluatezforces -- the
# underscored, decorator-free evaluators, as the Python integrators use) and
# integrated by the backend's own ODE solver (diffrax for jax, torchdiffeq for
# torch), so gradients of the orbit w.r.t. initial conditions AND potential
# parameters fall out of autodiff -- no hand-coded variational equations.
#
# This is the reference, higher-order-differentiable orbit path for the
# jax/torch backends, and the independent cross-check for the fast C
# state-transition-matrix path (Track B/D of the backend plan).
#
# Convention: phase-space vectors use galpy's Orbit ordering
# ``[R, vR, vT, z, vz, phi]``. Internally the EOM is integrated in
# *rectangular* variables ``[x, vx, y, vy, z, vz]`` -- matching galpy's C
# integrator (integrateFullOrbit.c) rather than the cylindrical Python _EOM --
# which avoids the 1/R centrifugal term and the coordinate singularity at the
# axis, and is naturally well-behaved for autodiff. The public input/output are
# transformed to/from ``Orbit`` order so they match ``Orbit``.
###############################################################################
from .. import get_namespace


def _eom_rhs(y, pot, t, xp):
"""Backend-agnostic rectangular EOM, state y = [x, vx, y, vy, z, vz].

Mirrors galpy.orbit.integrateFullOrbit._rectForce: recover (R, phi) from the
Cartesian position, evaluate the (decorator-free) force layer -- which
dispatches on the array type of (R, z, phi) so this runs and differentiates
under any backend -- and rotate the planar force back to Cartesian. Returns a
length-6 tuple of time-derivatives.
"""
# Imported lazily so importing this module never forces the potential import
# graph at galpy import time.
from ...potential.Potential import (
_evaluatephitorques,
_evaluateRforces,
_evaluatezforces,
)

x, vx, yy, vy, z, vz = y[0], y[1], y[2], y[3], y[4], y[5]
R = xp.sqrt(x**2 + yy**2)
phi = xp.arctan2(yy, x)
cosphi, sinphi = x / R, yy / R
# v=[vR, vT, vz]; only used by velocity-dependent/dissipative forces.
vR = vx * cosphi + vy * sinphi
vT = -vx * sinphi + vy * cosphi
v = [vR, vT, vz]
Rforce = _evaluateRforces(pot, R, z, phi=phi, t=t, v=v)
phitorque = _evaluatephitorques(pot, R, z, phi=phi, t=t, v=v)
ax = cosphi * Rforce - sinphi / R * phitorque
ay = sinphi * Rforce + cosphi / R * phitorque
az = _evaluatezforces(pot, R, z, phi=phi, t=t, v=v)
return vx, ax, vy, ay, vz, az


def _to_eom(xp, vxvv):
"""[R,vR,vT,z,vz,phi] (Orbit order) -> [x,vx,y,vy,z,vz] (rectangular EOM)."""
R, vR, vT, z, vz, phi = (vxvv[i] for i in range(6))
cosphi, sinphi = xp.cos(phi), xp.sin(phi)
return xp.stack(
[
R * cosphi,
vR * cosphi - vT * sinphi,
R * sinphi,
vR * sinphi + vT * cosphi,
z,
vz,
]
)


def _from_eom(xp, ys):
"""[...,x,vx,y,vy,z,vz] -> [...,R,vR,vT,z,vz,phi] (Orbit order)."""
x, vx, yy, vy, z, vz = (ys[..., i] for i in range(6))
R = xp.sqrt(x**2 + yy**2)
phi = xp.arctan2(yy, x) # in [-pi, pi], matching Orbit's wrapped convention
cosphi, sinphi = x / R, yy / R
vR = vx * cosphi + vy * sinphi
vT = -vx * sinphi + vy * cosphi
return xp.stack([R, vR, vT, z, vz, phi], axis=-1)


def integrate_orbit(pot, vxvv, ts, *, rtol=1e-10, atol=1e-10, max_steps=100000):
"""Differentiably integrate a 3D orbit with the backend's ODE solver.

Parameters
----------
pot : Potential (or list) -- must be backend-migrated for the chosen backend.
vxvv : backend array, shape (6,), ``[R, vR, vT, z, vz, phi]`` (Orbit order).
Its namespace selects the integrator: jax -> diffrax, torch -> torchdiffeq.
ts : backend array of output times (monotonic; ts[0] is the initial time).
rtol, atol : solver tolerances.
max_steps : diffrax step cap (jax only).

Returns
-------
backend array, shape ``(len(ts), 6)``, the orbit in ``[R, vR, vT, z, vz, phi]``.

Notes
-----
Differentiable w.r.t. ``vxvv`` (initial conditions) and, when the parameter is
supplied as a backend array, w.r.t. the potential's parameters. ``phi`` is
recovered from the Cartesian position, so it is wrapped to [-pi, pi] as in
``Orbit``.
"""
xp = get_namespace(vxvv)
name = xp.__name__
y0 = _to_eom(xp, vxvv)
# backend-specific integrators live in galpy.backend._jax / ._torch
if "jax" in name:
from .._jax.orbit_ode import integrate as _integrate_jax

ys = _integrate_jax(pot, y0, ts, rtol=rtol, atol=atol, max_steps=max_steps)
elif "torch" in name:
from .._torch.orbit_ode import integrate as _integrate_torch

ys = _integrate_torch(pot, y0, ts, rtol=rtol, atol=atol)
else: # numpy path uses galpy's C/scipy integrators instead
raise NotImplementedError(
"in-backend ODE integration requires a jax or torch input array; "
"for numpy use Orbit.integrate (C / scipy integrators)"
)
return _from_eom(xp, ys)
Loading
Loading