backend: coerce all-numpy operands in promote_scalars under a forced backend #991
jobovy wants to merge 2 commits into
Codecov Report
❌ Patch coverage is …
Additional details and impacted files

| Coverage Diff | feat/backends | #991 | +/- |
|---|---|---|---|
| Coverage | 99.93% | 46.87% | -53.06% |
| Files | 254 | 254 | |
| Lines | 39836 | 40668 | +832 |
| Branches | 838 | 843 | +5 |
| Hits | 39810 | 19064 | -20746 |
| Misses | 26 | 21604 | +21578 |
…backend

promote_scalars passed its inputs through unchanged when none was a backend array to anchor on (`ref is None`), assuming "the namespace's functions handle scalars". That holds for jax but NOT for torch: torch.cos/sqrt/... reject numpy.float64 / python floats, so under a forced torch backend every promote_scalars caller (the coords transforms cyl_to_rect / rect_to_cyl / ..., OblateStaeckelWrapperPotential) crashed on all-numpy inputs. Coerce the operands via coerce_coords in that branch instead.

Routing that branch through coerce_coords surfaced that #987's promote_scalars refactor had silently dropped the device-reject fallback (the device-less asarray retry when a namespace rejects the ref's .device value -- array-api jax exposes .device as the string 'cpu' and jnp.asarray(device='cpu') raises ValueError), leaving its test a no-op (the mock ref is no longer detected as a backend array after #987's is_backend_array switch). Restore the fallback in asarray_on_device (catch TypeError / ValueError -> device-less asarray; a genuine dtype error re-raises from the fallback so it is not masked) and rewrite the test to exercise asarray_on_device directly and deterministically.

The numpy path is byte-identical (the `xp is numpy` guard short-circuits, and asarray_on_device's device branch is only taken when a backend array supplies a device). jax value-identical under x64.

Fixes the migrated RotateAndTilt / Offset / OblateStaeckel / Kuzmin wrapper torch entries that route coordinates through these transforms, plus 11 test_coords and 11 test_quantity torch cases. New tests/test_backend_coerce.py covers the coercion branch.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
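The device-reject fallback described in this commit message can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's code: the signature of `asarray_on_device` is guessed, and `StrictNamespace` and `FakeRef` are hypothetical stand-ins for a namespace that rejects a string device and for a reference array exposing `.device`.

```python
import numpy as np


class StrictNamespace:
    """Hypothetical namespace whose asarray rejects any explicit device."""

    @staticmethod
    def asarray(obj, dtype=None, device=None):
        if device is not None:
            # Mimics a namespace that refuses the ref's .device value.
            raise ValueError(f"unsupported device: {device!r}")
        return np.asarray(obj, dtype=dtype)


class FakeRef:
    """Hypothetical reference array exposing .device as a plain string."""

    device = "cpu"


def asarray_on_device(xp, value, ref=None, dtype=None):
    """Sketch (assumed signature): convert value, preferring ref's device."""
    device = getattr(ref, "device", None)
    if device is None:
        return xp.asarray(value, dtype=dtype)
    try:
        return xp.asarray(value, dtype=dtype, device=device)
    except (TypeError, ValueError):
        # Device-less retry. A genuine dtype/value error is raised again
        # by this call, so the fallback does not mask it.
        return xp.asarray(value, dtype=dtype)


out = asarray_on_device(StrictNamespace, 1.5, ref=FakeRef(), dtype=np.float64)
```

In this sketch the first call raises `ValueError` because of the device argument and the retry succeeds, while an input that cannot be converted at all (for example the string `"abc"` with a float dtype) still raises from the retry.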
9713b4a to 0d3c7de
All-backend test status (jax / torch)

Green is achieved via the checked-in xfail-ledger.

Overall:
- jax: 1058 passed · 264 xfail · 725 deferred
- torch: 998 passed · 1039 xfail · 1 deferred · 6 FAIL/ERR

Ledger size: 2357 entries (jax=284, torch=2073).
…oses

The promote_scalars all-numpy coercion (this PR) makes the coords transforms return backend arrays under a forced torch backend. The UNMIGRATED streamdf (Track F) feeds that output into its numpy track-building, so these 6 tests now correctly fail (wrong track / missing _interpolatedObsTrackAA). They were accidentally passing because the old coords pass-through kept streamdf's numpy path alive under forced torch.

They join the 29 existing streamdf-torch xfails (streamdf is unmigrated; the numpy / default path is byte-identical and unaffected) and get un-ledgered when streamdf is migrated.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
What
`promote_scalars` passed its inputs through unchanged when none was a backend array to anchor on (`ref is None`), assuming "the namespace's functions handle scalars." That holds for jax but not torch: `torch.cos`/`sqrt`/… reject `numpy.float64` / python floats. So under a forced torch backend, every `promote_scalars` caller (the `coords` transforms `cyl_to_rect`/`rect_to_cyl`/… and `OblateStaeckelWrapperPotential`) crashed on all-numpy inputs.

This is the root cause behind the migrated RotateAndTilt / Offset / OblateStaeckel / Kuzmin wrapper torch failures (they route coordinates through these transforms). The fix coerces the operands via `coerce_coords` in that branch. That is one spot in the central coercion helper, rather than patching each transform.
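The control flow of the fix can be illustrated with a small self-contained sketch. It is not the PR's code: `StrictBackend` is a hypothetical namespace that, like torch in the description above, only accepts its own array type, and this `promote_scalars` is a simplified stand-in whose signature and internals are assumptions (the real branch goes through `coerce_coords`).

```python
import numpy as np


class StrictBackend:
    """Hypothetical namespace whose functions only accept its own arrays."""

    class Array:
        def __init__(self, data, dtype=np.float64):
            self.data = np.asarray(data, dtype=dtype)

    @classmethod
    def asarray(cls, obj, dtype=np.float64):
        return obj if isinstance(obj, cls.Array) else cls.Array(obj, dtype)

    @classmethod
    def cos(cls, x):
        if not isinstance(x, cls.Array):
            raise TypeError(f"cos(): expected Array, got {type(x).__name__}")
        return cls.Array(np.cos(x.data))


def promote_scalars(xp, *args):
    """Simplified stand-in showing the three branches discussed above."""
    if xp is np:
        # numpy guard: inputs pass through untouched (object-identical).
        return args
    ref = next((a for a in args if isinstance(a, xp.Array)), None)
    if ref is None:
        # Old behaviour: `return args`, which strict backends then rejected.
        # Fixed behaviour: coerce every operand to a backend array.
        return tuple(xp.asarray(a) for a in args)
    # Anchored path: promote the other operands to match the reference dtype.
    return tuple(xp.asarray(a, dtype=ref.data.dtype) for a in args)


R, phi = np.float64(1.0), 0.5
same = promote_scalars(np, R, phi)                   # untouched under numpy
R_b, phi_b = promote_scalars(StrictBackend, R, phi)  # coerced, no anchor
x = StrictBackend.cos(phi_b)                         # accepted by the backend
```

Calling `StrictBackend.cos(phi)` directly on the raw float raises `TypeError` in this sketch, which mirrors the crash the PR describes for all-numpy inputs under a forced torch backend.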
Safety (adversarially reviewed, two independent passes, CPU-forced)
- `xp is numpy` guard short-circuits first; verified object-identical pass-through and identical SHA-256 output hashes for `cyl_to_rect`/`rect_to_cyl`/`cyl_to_spher`/`Rz_to_uv`/`uv_to_Rz` over a 2000-point grid.
- … `0.0` and autodiff still flows (`jax.grad` through a coords transform is finite). jax failure sets are byte-identical (functional no-op for jax).
- `test_coords` torch 20→9 failures (fixes 11), `test_quantity` torch 110→99 (fixes 11); the remaining failure sets are strict subsets of the previous ones; OblateStaeckelWrapper unaffected.
- … (`integrateFullOrbit.py`) pins `xp=numpy`, hitting the guard; no caller depends on the old pass-through returning a raw scalar.
- … `numpy-array + scalar` torch bug (the array was anchored-on but itself left un-coerced).

Tests
New `tests/test_backend_coerce.py`: numpy object-identity pass-through; forced-jax/torch coercion to backend float64 (the fixed branch); the anchored path unchanged; and a coords-transform integration check (fails at `coords.py:1164` without the fix). Lives in the `test_backend*` coverage shard.

File-disjoint from #990 (`Potential.py`) and the merged #989. Part of the torch potential burndown; the per-potential `_anchor` fixes (PowerSpherical/SpiralArms/EllipticalDisk/…) are a separate themed PR.

🤖 Generated with Claude Code