Skip to content

Commit 4b4b454

Browse files
jobovy and claude committed
backend: promote _under_jax_trace -> _namespaces.under_jax_trace (shared trace predicate)
The jit/grad/vmap tracer check that gates the eager-loop-vs-lax.fori_loop choice is general (bracket expansion, bisection, any future rolled loop), not a root-finder detail, so it moves next to is_backend_array/device_of in galpy.backend._namespaces and both call sites in optimize.py plus the duplicate inline check in _jax/optimize._bisect_root now share it.

Also add a jit+grad-through-Staeckel test: Staeckel's turning points call the module-level bisect_root DIRECTLY (not via brentq), so its under_jax_trace(a,b) -> lax.fori_loop branch was uncovered (the Adiabatic/Spherical/Vertical grad tests route through brentq/iterate_bracket instead). The new test traces actionsFreqsAngles under jax.jit (asserts the jaxpr is rolled) and jax.grad (finite dJr/dR), closing the codecov/patch gap on optimize.py:219-222.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent ac13c97 commit 4b4b454

4 files changed

Lines changed: 54 additions & 21 deletions

File tree

galpy/backend/_jax/optimize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ def _bisect_root(f, a, b, xp, *, xtol, maxiter):
4949
"""
5050
import jax
5151

52+
from .._namespaces import under_jax_trace
5253
from ..optimize import bisect_root, bisect_step, n_bisect_steps
5354

54-
if not any(isinstance(x, jax.core.Tracer) for x in (a, b)):
55+
if not under_jax_trace(a, b): # entered directly with a concrete bracket
5556
return bisect_root(f, a, b, xp, xtol=xtol, maxiter=maxiter)
5657
lo = xp.asarray(a) * 1.0
5758
hi = xp.asarray(b) * 1.0

galpy/backend/_namespaces.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ def is_backend_array(x):
4747
return False
4848

4949

50+
def under_jax_trace(*xs):
51+
"""True iff jax is imported AND one of ``xs`` is a jax tracer (jit/grad/vmap).
52+
53+
The predicate that gates the eager-loop-vs-``lax.fori_loop`` choice wherever
54+
galpy rolls a fixed-schedule loop (bracket expansion, bisection, ...): the
55+
eager Python loop stays byte-identical and ~9x faster outside a trace, while
56+
under a jax trace the same body is rolled into a ``fori_loop`` so its ``n``
57+
embedded copies of the physics closure do not unroll into the user's jaxpr.
58+
59+
Cheap on numpy/torch and on plain (untraced) jax arrays: if ``jax`` is not
60+
even imported we short-circuit to ``False`` (via ``sys.modules``, so the
61+
numpy/torch eager paths never import jax). This is deliberately gated on
62+
``sys.modules`` rather than the ``_JAX_LOADED`` install flag so a jax-
63+
installed-but-unused run (pure numpy/torch) keeps the eager hot path from
64+
importing jax at all.
65+
"""
66+
import sys
67+
68+
if "jax" not in sys.modules:
69+
return False
70+
import jax
71+
72+
return any(isinstance(x, jax.core.Tracer) for x in xs)
73+
74+
5075
def _is_floating_dtype(dtype):
5176
"""True for real floating-point dtypes of any backend.
5277

galpy/backend/optimize.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
# first-order sensitivity; a Hessian through x* would require a true
4747
# custom_jvp differentiating the implicit relation a second time.
4848
###############################################################################
49-
from ._namespaces import is_backend_array
49+
from ._namespaces import is_backend_array, under_jax_trace
5050
from ._resolver import get_namespace
5151

5252
# Default bracketing tolerance (matches scipy.optimize.brentq's xtol default, so
@@ -156,23 +156,6 @@ def bisect_step(lo, hi, slo, f, xp):
156156
return xp.where(same, mid, lo), xp.where(same, hi, mid)
157157

158158

159-
def _under_jax_trace(*xs):
160-
"""True iff jax is imported AND one of ``xs`` is a jax tracer (jit/grad/vmap).
161-
162-
Cheap on numpy/torch and on plain jax arrays: if ``jax`` is not even imported
163-
we short-circuit to ``False`` (so the eager Python loops stay byte-identical
164-
and we never import jax on the numpy/torch path). Under a jax trace the n
165-
copies of the physics closure would otherwise unroll into the user's jaxpr.
166-
"""
167-
import sys
168-
169-
if "jax" not in sys.modules:
170-
return False
171-
import jax
172-
173-
return any(isinstance(x, jax.core.Tracer) for x in xs)
174-
175-
176159
def iterate_bracket(step, x0, n):
177160
"""Run ``x = step(x)`` ``n`` times -- a fixed-schedule, branch-free bracket
178161
expansion/contraction where ``step`` is a single ``xp.where`` update over the
@@ -183,7 +166,7 @@ def iterate_bracket(step, x0, n):
183166
``step`` so the ``n`` embedded copies of ``f`` do NOT unroll into the jaxpr
184167
(mirrors the bisection rolling; eager stays ~9x faster than ``fori_loop``).
185168
"""
186-
if _under_jax_trace(x0):
169+
if under_jax_trace(x0):
187170
import jax
188171

189172
return jax.lax.fori_loop(0, n, lambda _, x: step(x), x0)
@@ -216,7 +199,7 @@ def bisect_root(f, a, b, xp, *, xtol, maxiter):
216199
# DIRECT callers jit-fast too -- notably the Staeckel umin/umax/vmin turning
217200
# points, which call bisect_root straight (not via brentq). Eager keeps the
218201
# Python loop (bit-identical, ~9x faster than fori_loop).
219-
if _under_jax_trace(a, b):
202+
if under_jax_trace(a, b):
220203
from ._jax.optimize import _bisect_root
221204

222205
return _bisect_root(f, a, b, xp, xtol=xtol, maxiter=maxiter)

tests/test_backend_actionAngle.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,30 @@ def test_staeckel_actions_vs_c(backend):
841841
numpy.testing.assert_allclose(_np(jz_b), numpy.asarray(jz_c), rtol=1e-8, atol=1e-9)
842842

843843

844+
@pytest.mark.skipif("jax" not in BACKENDS, reason="jax not installed")
845+
def test_staeckel_jit_grad_rolls_direct_bisection():
846+
# Staeckel's turning points call the module-level ``bisect_root`` DIRECTLY
847+
# (not via ``brentq``). Under a jax trace (jit/grad) the bracket endpoints are
848+
# tracers, so ``bisect_root`` dispatches to the rolled ``lax.fori_loop`` kernel
849+
# -- the ~100-step bisection does NOT unroll ~100 copies of the Staeckel
850+
# integrand into the user's jaxpr. Covers the ``under_jax_trace(a, b)`` branch
851+
# of ``galpy.backend.optimize.bisect_root`` reached only by these direct
852+
# callers (the brentq-based AA methods route through a different kernel).
853+
aA = actionAngleStaeckel(pot=MWPotential2014, delta=0.45, c=False)
854+
R = jnp.asarray(_STK[0])
855+
rest = tuple(jnp.asarray(v) for v in _STK[1:])
856+
jr_e, _, jz_e = aA(R, *rest) # eager (Python-loop) reference
857+
jr_j, _, jz_j = jax.jit(lambda r: aA(r, *rest))(R) # traced (fori_loop) value
858+
numpy.testing.assert_allclose(_np(jr_j), _np(jr_e), rtol=1e-8, atol=1e-10)
859+
numpy.testing.assert_allclose(_np(jz_j), _np(jz_e), rtol=1e-8, atol=1e-10)
860+
# the jaxpr is ROLLED: a loop primitive, not ~100 unrolled bisection steps.
861+
txt = str(jax.make_jaxpr(lambda r: aA(r, *rest)[0])(R))
862+
assert ("while" in txt) or ("scan" in txt)
863+
# grad flows through the direct-bisection turning points: finite dJr/dR.
864+
g = jax.grad(lambda r: jnp.sum(aA(r, *rest)[0]))(R)
865+
assert numpy.all(numpy.isfinite(_np(g)))
866+
867+
844868
@pytest.mark.parametrize("backend", BACKENDS)
845869
def test_staeckel_unbound_backend_no_raise(backend):
846870
# An unbound orbit raises UnboundError on the numpy path (eager), but must NOT

0 commit comments

Comments
 (0)