Skip to content

potential/backend: anchor construction-time scalar phi to float64 in EllipsoidalPotential#989

Merged
jobovy merged 1 commit into
feat/backendsfrom
backend/ellipsoidal-construct-float64-amp
Jun 20, 2026
Merged

potential/backend: anchor construction-time scalar phi to float64 in EllipsoidalPotential#989
jobovy merged 1 commit into
feat/backendsfrom
backend/ellipsoidal-construct-float64-amp

Conversation

@jobovy

@jobovy jobovy commented Jun 20, 2026

Copy link
Copy Markdown
Owner

What

Anchor a construction-time python-scalar phi to float64 in EllipsoidalPotential._anchor_phi.

During __init__, normalize() calls Rforce(1., 0.) with plain python scalars before _backend_compatible is set, so the @potential_physical_input boundary does not coerce them. Under a forced non-numpy backend _anchor_phi then saw a dtype-less scalar R and fell back to xp.asarray(phi, dtype=None) — which on torch (whose CI jobs keep the float32 default, unlike jax's x64) produced a float32 phi, dropping the computed _amp to float32 and silently losing precision in every later float64-coordinate evaluation.

The fix: when R carries no dtype, anchor phi to xp.float64 (galpy's interior precision).

 if xp is numpy or is_backend_array(phi):
     return phi
-return xp.asarray(phi, dtype=getattr(R, "dtype", None))
+dtype = getattr(R, "dtype", None)
+if dtype is None:
+    dtype = xp.float64
+return xp.asarray(phi, dtype=dtype)

Result

Fixes the 6 oblate/prolate/triaxial Hernquist & Jaffe test_amp_mult_divide torch failures (baseline residuals ~1e-7, i.e. float32-magnitude, vs the 1e-10 tolerance). They flip xfail→XPASS (ledger stays green via strict=False; the stale entries are pruned in the separate ledger-regen PR, not here).

Safety (adversarially reviewed, CPU-forced)

  • numpy byte-identical — the xp is numpy guard short-circuits before any new code; verified max abs diff 0.0 over a full evaluate/forces/2nd-deriv grid.
  • exit-cast policy preserved — a genuine float32 R array still anchors phi to float32 (the dtype is None branch is only taken for a dtype-less scalar).
  • autodiff w.r.t. phi untouched — a real backend phi array returns via the is_backend_array short-circuit (no cast/detach).
  • jax unaffected (x64 already gives float64); no regressions across numpy/jax/torch in test_backend_ellipsoidal.py (687 passed).
  • New test test_construct_normalize_amp_is_float64 reproduces the CI torch float32-default condition and pins the construction-time _amp to float64 (fails on baseline, passes with the fix). It also covers the new branch under the coverage-uploading test_backend* shard.

Notes / divergences

  • Investigated under "stabilize the GL sum order", but the Gauss-Legendre quadrature was provably not the cause (vectorizing _potInt left the tests failing). The fix is a one-spot dtype anchor — the branch was renamed off gl-stabilize to reflect that.
  • Follow-up (separate PR): the Ellipsoidal subclasses (PerfectEllipsoid, TwoPowerTriaxial/TriaxialHernquist/TriaxialJaffe, TriaxialGaussian, PowerTriaxial) still set _backend_compatible=True after normalize(), unlike the ~10 potentials backend: central coordinate coercion — fix systemic torch scalar-input rejection #960 reordered (only TriaxialNFW was reordered). Reordering them is the root-cause cleanup; this guard is the correct, byte-identical, independently-defensible boundary fix and keeps the branch covered.
  • triaxialLogarithmicHaloPotential / mockRotatedAndTiltedTriaxialLogHaloPotential torch entries are out of scope — LogarithmicHaloPotential is not an EllipsoidalPotential subclass (already backend: central coordinate coercion — fix systemic torch scalar-input rejection #960-reordered; any residual float32 there is a different mechanism).

🤖 Generated with Claude Code

@codecov

codecov Bot commented Jun 20, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.93%. Comparing base (69659c6) to head (4effc36).
⚠️ Report is 2 commits behind head on feat/backends.

Additional details and impacted files
@@               Coverage Diff               @@
##           feat/backends     #989    +/-   ##
===============================================
  Coverage          99.93%   99.93%            
===============================================
  Files                254      254            
  Lines              39820    40276   +456     
  Branches             837      837            
===============================================
+ Hits               39795    40251   +456     
  Misses                25       25            

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

…EllipsoidalPotential

During __init__, normalize() calls Rforce(1., 0.) with plain python scalars
before _backend_compatible is set, so the @potential_physical_input boundary
does not coerce them. Under a forced non-numpy backend _anchor_phi then saw a
dtype-less scalar R and fell back to xp.asarray(phi, dtype=None) -- which on
torch (whose CI jobs keep the float32 default, unlike jax's x64) produced a
float32 phi, dropping the computed _amp to float32 and silently losing
precision in every later float64-coordinate evaluation. Anchor the scalar phi
to xp.float64 (galpy's interior precision) when R carries no dtype.

Fixes the 6 oblate/prolate/triaxial Hernquist & Jaffe amp_mult_divide torch
failures (residuals ~1e-7, float32-magnitude). The numpy path stays
byte-identical (the `xp is numpy` guard short-circuits first); genuine float32
R inputs still keep float32 (the exit-cast policy is preserved); autodiff
w.r.t. a backend `phi` array is untouched (is_backend_array short-circuits).
Adds a forced-backend construction test pinning the normalize() _amp to float64
(reproduces the CI torch float32-default condition).

Note: this was investigated under "stabilize the GL sum order", but the
Gauss-Legendre quadrature was provably not the cause (vectorizing _potInt left
the tests failing) -- the fix is a one-spot dtype anchor, hence the branch was
renamed off "gl-stabilize".

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
@jobovy jobovy force-pushed the backend/ellipsoidal-construct-float64-amp branch from f0e31f2 to 4effc36 Compare June 20, 2026 20:33
@jobovy jobovy enabled auto-merge (squash) June 20, 2026 21:45
@github-actions

github-actions Bot commented Jun 20, 2026

Copy link
Copy Markdown
Contributor

All-backend test status (jax / torch)

Commit a771a5bed478b5db2c9da183d394492dec6a7114

Green is achieved via the checked-in xfail-ledger (tests/backend_xfail.txt, applied xfail(strict=False)), so the metric to watch is the shrinking xfail count (burndown), not a raw pass count. A FAIL/ERR is an un-ledgered regression (reds the run). Because the ledger is non-strict, a now-passing ledgered test is a plain pass here (no per-push XPASS); burndown candidates -- in both directions -- are surfaced by the scheduled regen run, which rewrites the ledger from real outcomes. deferred is a separate burndown: tests skipped because they are unrunnable under the backend until the port is vectorized (see tests/backend_slow_skip.txt), e.g. the jax spherical-DF sampling/quadrature tests pending the Track F DF migration.

Overall: jax: 1056 passed · 266 xfail · 725 deferred | torch: 833 passed · 1210 xfail · 1 deferred

Ledger size: 2357 entries (jax=284, torch=2073).

Test shard jax torch
actionAngle ✅ 112 pass · 89 xfail ✅ 26 pass · 175 xfail
sphericaldf ✅ 164 pass · 26 xfail · 28 deferred ✅ 8 pass · 210 xfail
conversion + util + misc ✅ 85 pass · 6 xfail · 1 deferred ✅ 41 pass · 51 xfail
potential + scf + multipole — (no result) — (no result)
quantity + coords ✅ 287 pass · 49 xfail ✅ 189 pass · 147 xfail
orbit (energy/Jacobi + from_name) ✅ 0 pass · 0 xfail · 115 deferred ✅ 63 pass · 52 xfail
orbit + orbits (main) ✅ 0 pass · 0 xfail · 578 deferred ✅ 248 pass · 327 xfail
evolveddiskdf ✅ 35 pass · 0 xfail ✅ 32 pass · 3 xfail
jeans + dynamfric ✅ 17 pass · 2 xfail · 2 deferred ✅ 7 pass · 13 xfail · 1 deferred
qdf + pv2qdf + streamgapdf_impulse + noninertial ✅ 57 pass · 75 xfail · 1 deferred ✅ 14 pass · 119 xfail
streamgapdf ✅ 28 pass · 2 xfail ✅ 27 pass · 3 xfail
diskdf ✅ 129 pass · 0 xfail ✅ 112 pass · 17 xfail
streamdf + streamspraydf + streamTrack ✅ 142 pass · 17 xfail ✅ 66 pass · 93 xfail
Per-shard counts
Test shard backend pass xfail deferred XPASS fail error
actionAngle jax 112 89 0 0 0 0
actionAngle torch 26 175 0 0 0 0
sphericaldf jax 164 26 28 0 0 0
sphericaldf torch 8 210 0 0 0 0
conversion + util + misc jax 85 6 1 0 0 0
conversion + util + misc torch 41 51 0 0 0 0
potential + scf + multipole jax
potential + scf + multipole torch
quantity + coords jax 287 49 0 0 0 0
quantity + coords torch 189 147 0 0 0 0
orbit (energy/Jacobi + from_name) jax 0 0 115 0 0 0
orbit (energy/Jacobi + from_name) torch 63 52 0 0 0 0
orbit + orbits (main) jax 0 0 578 0 0 0
orbit + orbits (main) torch 248 327 0 0 0 0
evolveddiskdf jax 35 0 0 0 0 0
evolveddiskdf torch 32 3 0 0 0 0
jeans + dynamfric jax 17 2 2 0 0 0
jeans + dynamfric torch 7 13 1 0 0 0
qdf + pv2qdf + streamgapdf_impulse + noninertial jax 57 75 1 0 0 0
qdf + pv2qdf + streamgapdf_impulse + noninertial torch 14 119 0 0 0 0
streamgapdf jax 28 2 0 0 0 0
streamgapdf torch 27 3 0 0 0 0
diskdf jax 129 0 0 0 0 0
diskdf torch 112 17 0 0 0 0
streamdf + streamspraydf + streamTrack jax 142 17 0 0 0 0
streamdf + streamspraydf + streamTrack torch 66 93 0 0 0 0

@jobovy jobovy merged commit 509ea17 into feat/backends Jun 20, 2026
146 checks passed
@jobovy jobovy deleted the backend/ellipsoidal-construct-float64-amp branch June 20, 2026 23:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant