Skip to content

Commit 9ff2271

Browse files
jobovyclaude
andauthored
backend: centralize data-coercion helpers into galpy.backend._coerce (#983)c
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent fcad3f7 commit 9ff2271

10 files changed

Lines changed: 202 additions & 124 deletions

galpy/backend/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,17 @@
66
# numpy, JAX, and PyTorch. Backend selection follows the data first (the type
77
# of the array arguments), with an explicit ``xp=`` override and a
88
# context-manager/global default as fallbacks. See ``_resolver`` for details.
9+
#
10+
# See ``_coerce`` for the data-coercion helpers (bringing numpy/Python data
11+
# onto the active backend, anchoring stored constants) and ``_namespaces``
12+
# for the namespace-resolution and dtype/device primitives they build on.
913
###############################################################################
14+
from ._coerce import (
15+
as_backend_constant,
16+
coerce_coords,
17+
promote_scalars,
18+
zeros_like_backend,
19+
)
1020
from ._namespaces import (
1121
asarray_on_device,
1222
device_of,
@@ -33,4 +43,8 @@
3343
"match_input_dtype",
3444
"device_of",
3545
"asarray_on_device",
46+
"as_backend_constant",
47+
"coerce_coords",
48+
"promote_scalars",
49+
"zeros_like_backend",
3650
]

galpy/backend/_coerce.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
###############################################################################
2+
# galpy.backend._coerce: backend DATA-coercion helpers.
3+
###############################################################################
4+
"""Backend data-coercion helpers: the single home for bringing numpy/Python
5+
data onto the active jax/torch backend.
6+
7+
PURPOSE
8+
-------
9+
This module does two related jobs:
10+
11+
* it brings numpy/Python *coordinate* data onto the active backend's array
12+
type (so e.g. ``torch.sqrt`` -- which rejects ``numpy.float64`` -- and
13+
``Tensor`` arithmetic see real backend arrays), and
14+
* it anchors *stored numpy constants* (rotation matrices, lookup tables, a
15+
zero reference coordinate) onto the dtype/device of an input array, so the
16+
constant joins the computation as a same-dtype/same-device backend array.
17+
18+
It is the single home for data-coercion. Namespace *resolution* (which backend
19+
a call dispatches to) and the dtype/device *primitives* it builds on live in
20+
``_namespaces.py`` (``is_backend_array``, ``device_of``, ``asarray_on_device``,
21+
``match_input_dtype``); this module only consumes those primitives.
22+
23+
THE CORE INVARIANT
24+
------------------
25+
Every function here is a STRICT PASS-THROUGH when ``xp is numpy``: it returns
26+
its inputs OBJECT-IDENTICALLY (no asarray, no copy, no dtype touch). This is
27+
what keeps the numpy code path BYTE-IDENTICAL to galpy's historical behaviour.
28+
Any new coercion helper added to this module MUST preserve this invariant --
29+
guard the work behind ``if xp is numpy: return <inputs unchanged>`` first.
30+
31+
WHEN TO USE EACH
32+
----------------
33+
* ``coerce_coords(xp, *coords)`` -- at the PUBLIC INPUT BOUNDARY (the
34+
``@potential_physical_input`` decorator) to bring coordinate arguments onto
35+
the backend: plain Python/int scalars become float64 (galpy's interior
36+
precision), float arrays keep their dtype (so the float32 exit-cast policy
37+
still applies), and ``None`` passes through.
38+
* ``promote_scalars(xp, *vals)`` -- INSIDE coordinate transforms to promote
39+
plain Python scalars sitting alongside array arguments, anchored on the
40+
dtype/device of the first array, so mixed scalar/array inputs work on a
41+
backend whose functions require arrays.
42+
* ``as_backend_constant(xp, value, ref)`` -- to anchor a single STORED numpy
43+
constant (a rotation matrix, an offset, a table) on a backend ``ref`` array
44+
derived from the coordinate inputs.
45+
* ``zeros_like_backend(xp, R)`` -- for a backend ZERO reference coordinate
46+
(e.g. the ``z = 0`` plane a spherical-in-disguise wrapper feeds its wrapped
47+
potential).
48+
49+
WHY float64-INTERIOR / DEVICE-ANCHORING
50+
---------------------------------------
51+
galpy computes in float64 internally: a bare ``asarray`` of a Python float
52+
yields torch float32 and silently misses galpy's tolerances, so plain scalars
53+
are lifted to ``xp.float64`` while genuine float arrays keep their own dtype.
54+
Anchoring constants and promoted scalars on an input array's dtype/device keeps
55+
the whole computation on one device and at one precision, which is required for
56+
torch (cross-device / mixed-dtype ops raise) and correct for jax.
57+
"""
58+
59+
import numpy
60+
61+
from ._namespaces import (
62+
_is_floating_dtype,
63+
asarray_on_device,
64+
device_of,
65+
)
66+
67+
68+
def coerce_coords(xp, *coords):
69+
"""Bring coordinate inputs onto the active backend's array type.
70+
71+
The dominant non-numpy failure mode is "the namespace resolved to a backend
72+
(forced harness, or a user mixing a backend tensor with a numpy/python arg)
73+
but a coordinate is still numpy/python", which torch rejects strictly
74+
(``torch.sqrt(numpy.float64)`` raises; ``numpy.ndarray * Tensor`` raises).
75+
Coercing the coordinates to backend arrays once, at the public input
76+
boundary, fixes it for every potential at once.
77+
78+
Rules (applied only when the backend is NOT numpy):
79+
* ``None`` is passed through (axisymmetric ``phi=None`` etc.).
80+
* a coordinate that already carries a *floating* dtype (a numpy/backend
81+
float32/float64 array or scalar) is moved onto the backend with its
82+
dtype PRESERVED, so the float32/exit-cast policy (``match_input_dtype``)
83+
still applies.
84+
* a plain Python scalar (``1.0``/``1``) or an integer array is brought to
85+
the backend's float64 -- galpy's interior precision; a bare ``asarray``
86+
of a Python float would give torch float32 and miss the tolerances.
87+
88+
The numpy backend is a strict pass-through (``coords`` returned object-
89+
identical) -> the numpy path stays byte-identical.
90+
"""
91+
if xp is numpy:
92+
return coords
93+
dev = device_of(*coords)
94+
out = []
95+
for c in coords:
96+
if c is None:
97+
out.append(c)
98+
continue
99+
dt = getattr(c, "dtype", None)
100+
if dt is not None and _is_floating_dtype(dt):
101+
out.append(asarray_on_device(xp, c, dev)) # preserve float dtype
102+
else:
103+
out.append(asarray_on_device(xp, c, dev, dtype=xp.float64))
104+
return tuple(out)
105+
106+
107+
def promote_scalars(xp, *vals):
108+
"""Promote plain Python scalars among ``vals`` to the active non-numpy
109+
namespace, anchored on the dtype/device of the first array argument, so
110+
that e.g. torch functions -- which require Tensors -- accept the mixed
111+
scalar/array inputs that the numpy path has always supported. The numpy
112+
path passes everything through untouched (byte-identical)."""
113+
if xp is numpy:
114+
return vals
115+
ref = next((v for v in vals if hasattr(v, "ndim")), None)
116+
if ref is None:
117+
# Nothing to anchor on (e.g. all-scalar inputs under a forced backend
118+
# default): pass through, the namespace's functions handle scalars
119+
return vals
120+
dtype = getattr(ref, "dtype", None)
121+
device = getattr(ref, "device", None)
122+
123+
def _promote(v):
124+
if hasattr(v, "ndim"):
125+
return v
126+
try:
127+
return xp.asarray(v, dtype=dtype, device=device)
128+
except (TypeError, ValueError):
129+
# namespace without a device kwarg, or one that rejects the device
130+
# value (array-api jax exposes .device as the string 'cpu', which
131+
# jnp.asarray(device=...) refuses with ValueError). jax tracks device
132+
# automatically, so a plain asarray is correct; torch's .device is a
133+
# real object and never hits this fallback.
134+
return xp.asarray(v, dtype=dtype)
135+
136+
return tuple(_promote(v) for v in vals)
137+
138+
139+
def as_backend_constant(xp, value, ref):
140+
"""Bring a stored numpy constant (rotation matrix / offset) into the active
141+
namespace, anchored on the dtype/device of ``ref`` (a backend array derived
142+
from the coordinate inputs). The numpy path passes the stored array through
143+
untouched (byte-identical)."""
144+
if xp is numpy:
145+
return value
146+
dtype = getattr(ref, "dtype", None)
147+
device = getattr(ref, "device", None)
148+
try:
149+
return xp.asarray(value, dtype=dtype, device=device)
150+
except TypeError: # pragma: no cover - namespace without device= kwarg
151+
return xp.asarray(value, dtype=dtype)
152+
153+
154+
def zeros_like_backend(xp, R):
155+
"""The numpy path passes the plain scalar through untouched
156+
(byte-identical); on a non-numpy backend the z = 0 reference
157+
coordinate is anchored on the inputs so the wrapped potential sees a
158+
backend array (torch functions require Tensors) on the right
159+
device/dtype."""
160+
return 0.0 if xp is numpy else xp.zeros_like(R)

galpy/backend/_namespaces.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -149,45 +149,6 @@ def asarray_on_device(xp, a, device, dtype=None):
149149
return xp.asarray(a, dtype=dtype, device=device)
150150

151151

152-
def coerce_coords(xp, *coords):
153-
"""Bring coordinate inputs onto the active backend's array type.
154-
155-
The dominant non-numpy failure mode is "the namespace resolved to a backend
156-
(forced harness, or a user mixing a backend tensor with a numpy/python arg)
157-
but a coordinate is still numpy/python", which torch rejects strictly
158-
(``torch.sqrt(numpy.float64)`` raises; ``numpy.ndarray * Tensor`` raises).
159-
Coercing the coordinates to backend arrays once, at the public input
160-
boundary, fixes it for every potential at once.
161-
162-
Rules (applied only when the backend is NOT numpy):
163-
* ``None`` is passed through (axisymmetric ``phi=None`` etc.).
164-
* a coordinate that already carries a *floating* dtype (a numpy/backend
165-
float32/float64 array or scalar) is moved onto the backend with its
166-
dtype PRESERVED, so the float32/exit-cast policy (``match_input_dtype``)
167-
still applies.
168-
* a plain Python scalar (``1.0``/``1``) or an integer array is brought to
169-
the backend's float64 -- galpy's interior precision; a bare ``asarray``
170-
of a Python float would give torch float32 and miss the tolerances.
171-
172-
The numpy backend is a strict pass-through (``coords`` returned object-
173-
identical) -> the numpy path stays byte-identical.
174-
"""
175-
if xp is numpy:
176-
return coords
177-
dev = device_of(*coords)
178-
out = []
179-
for c in coords:
180-
if c is None:
181-
out.append(c)
182-
continue
183-
dt = getattr(c, "dtype", None)
184-
if dt is not None and _is_floating_dtype(dt):
185-
out.append(asarray_on_device(xp, c, dev)) # preserve float dtype
186-
else:
187-
out.append(asarray_on_device(xp, c, dev, dtype=xp.float64))
188-
return tuple(out)
189-
190-
191152
def namespace_for_name(name):
192153
"""Map a backend name ('numpy'|'jax'|'torch') to its array namespace module.
193154

galpy/potential/KuzminLikeWrapperPotential.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
###############################################################################
55
import numpy
66

7-
from ..backend import get_namespace
7+
from ..backend import get_namespace, zeros_like_backend
88
from ..util import conversion
99
from .Potential import (
1010
_evaluatePotentials,
@@ -124,38 +124,30 @@ def _d2xidRdz(self, R, z):
124124
* ((self._a + xp.sqrt(self._b2 + z**2.0)) ** 2.0 + R**2.0) ** 1.5
125125
)
126126

127-
def _zero_like(self, xp, R):
128-
# The numpy path passes the plain scalar through untouched
129-
# (byte-identical); on a non-numpy backend the z = 0 reference
130-
# coordinate is anchored on the inputs so the wrapped potential sees a
131-
# backend array (torch functions require Tensors) on the right
132-
# device/dtype.
133-
return 0.0 if xp is numpy else xp.zeros_like(R)
134-
135127
def _evaluate(self, R, z, phi=0.0, t=0.0):
136128
xp = get_namespace(R, z, phi, t)
137129
return _evaluatePotentials(
138-
self._pot, self._xi(R, z), self._zero_like(xp, R), phi=phi, t=t
130+
self._pot, self._xi(R, z), zeros_like_backend(xp, R), phi=phi, t=t
139131
)
140132

141133
def _Rforce(self, R, z, phi=0.0, t=0.0):
142134
xp = get_namespace(R, z, phi, t)
143135
return _evaluateRforces(
144-
self._pot, self._xi(R, z), self._zero_like(xp, R), phi=phi, t=t
136+
self._pot, self._xi(R, z), zeros_like_backend(xp, R), phi=phi, t=t
145137
) * self._dxidR(R, z)
146138

147139
def _zforce(self, R, z, phi=0.0, t=0.0):
148140
xp = get_namespace(R, z, phi, t)
149141
return _evaluateRforces(
150-
self._pot, self._xi(R, z), self._zero_like(xp, R), phi=phi, t=t
142+
self._pot, self._xi(R, z), zeros_like_backend(xp, R), phi=phi, t=t
151143
) * self._dxidz(R, z)
152144

153145
def _phitorque(self, R, z, phi=0.0, t=0.0):
154146
return 0.0
155147

156148
def _R2deriv(self, R, z, phi=0.0, t=0.0):
157149
xp = get_namespace(R, z, phi, t)
158-
zero = self._zero_like(xp, R)
150+
zero = zeros_like_backend(xp, R)
159151
return evaluateR2derivs(
160152
self._pot, self._xi(R, z), zero, phi=phi, t=t
161153
) * self._dxidR(R, z) ** 2.0 - _evaluateRforces(
@@ -164,7 +156,7 @@ def _R2deriv(self, R, z, phi=0.0, t=0.0):
164156

165157
def _z2deriv(self, R, z, phi=0.0, t=0.0):
166158
xp = get_namespace(R, z, phi, t)
167-
zero = self._zero_like(xp, R)
159+
zero = zeros_like_backend(xp, R)
168160
return evaluateR2derivs(
169161
self._pot, self._xi(R, z), zero, phi=phi, t=t
170162
) * self._dxidz(R, z) ** 2.0 - _evaluateRforces(
@@ -173,7 +165,7 @@ def _z2deriv(self, R, z, phi=0.0, t=0.0):
173165

174166
def _Rzderiv(self, R, z, phi=0.0, t=0.0):
175167
xp = get_namespace(R, z, phi, t)
176-
zero = self._zero_like(xp, R)
168+
zero = zeros_like_backend(xp, R)
177169
return evaluateR2derivs(
178170
self._pot, self._xi(R, z), zero, phi=phi, t=t
179171
) * self._dxidR(R, z) * self._dxidz(R, z) - _evaluateRforces(

galpy/potential/OblateStaeckelWrapperPotential.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import numpy
1111

1212
from galpy.util import conversion, coords
13-
from galpy.util.coords import _promote_scalars_for
1413

15-
from ..backend import get_namespace
14+
from ..backend import get_namespace, promote_scalars
1615
from .Potential import (
1716
_APY_LOADED,
1817
_evaluatePotentials,
@@ -474,17 +473,17 @@ def _d2Vdv2(self, v):
474473

475474
def _staeckel_prefactor(u, v):
476475
xp = get_namespace(u, v)
477-
u, v = _promote_scalars_for(xp, u, v)
476+
u, v = promote_scalars(xp, u, v)
478477
return xp.sinh(u) ** 2.0 + xp.sin(v) ** 2.0
479478

480479

481480
def _dstaeckel_prefactordudv(u, v):
482481
xp = get_namespace(u, v)
483-
u, v = _promote_scalars_for(xp, u, v)
482+
u, v = promote_scalars(xp, u, v)
484483
return (2.0 * xp.sinh(u) * xp.cosh(u), 2.0 * xp.sin(v) * xp.cos(v))
485484

486485

487486
def _dstaeckel_prefactord2ud2v(u, v):
488487
xp = get_namespace(u, v)
489-
u, v = _promote_scalars_for(xp, u, v)
488+
u, v = promote_scalars(xp, u, v)
490489
return (2.0 * xp.cosh(2.0 * u), 2.0 * xp.cos(2.0 * v))

galpy/potential/RotateAndTiltWrapperPotential.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
###############################################################################
55
import numpy
66

7-
from ..backend import get_namespace
7+
from ..backend import as_backend_constant, get_namespace
88
from ..util import _rotate_to_arbitrary_vector, conversion, coords
99
from .Potential import (
1010
_evaluatephitorques,
@@ -23,21 +23,6 @@
2323
from .WrapperPotential import WrapperPotential
2424

2525

26-
def _as_xp_constant(xp, value, ref):
27-
"""Bring a stored numpy constant (rotation matrix / offset) into the active
28-
namespace, anchored on the dtype/device of ``ref`` (a backend array derived
29-
from the coordinate inputs). The numpy path passes the stored array through
30-
untouched (byte-identical)."""
31-
if xp is numpy:
32-
return value
33-
dtype = getattr(ref, "dtype", None)
34-
device = getattr(ref, "device", None)
35-
try:
36-
return xp.asarray(value, dtype=dtype, device=device)
37-
except TypeError: # pragma: no cover - namespace without device= kwarg
38-
return xp.asarray(value, dtype=dtype)
39-
40-
4126
# Only implement 3D wrapper
4227
class RotateAndTiltWrapperPotential(WrapperPotential):
4328
"""Potential wrapper that allows a potential to be rotated in 3D
@@ -194,9 +179,9 @@ def _rect_transformed(self, xp, R, z, phi, guard_inf=False):
194179
x, y, z = coords.cyl_to_rect(R, phi, z)
195180
xyzp = xp.stack([x, y, z])
196181
if not self._norot:
197-
xyzp = _as_xp_constant(xp, self._rot, xyzp) @ xyzp
182+
xyzp = as_backend_constant(xp, self._rot, xyzp) @ xyzp
198183
if self._offset is not None:
199-
xyzp = xyzp + _as_xp_constant(xp, self._offset, xyzp)
184+
xyzp = xyzp + as_backend_constant(xp, self._offset, xyzp)
200185
return xyzp
201186

202187
@check_potential_inputs_not_arrays
@@ -233,7 +218,7 @@ def _force_xyz(self, R, z, phi=0.0, t=0.0):
233218
xforcep = xp.cos(phip) * Rforcep - xp.sin(phip) * phitorquep / Rp
234219
yforcep = xp.sin(phip) * Rforcep + xp.cos(phip) * phitorquep / Rp
235220
Fxyzp = xp.stack([xforcep, yforcep, zforcep])
236-
return _as_xp_constant(xp, self._inv_rot, Fxyzp) @ Fxyzp
221+
return as_backend_constant(xp, self._inv_rot, Fxyzp) @ Fxyzp
237222

238223
@check_potential_inputs_not_arrays
239224
def _R2deriv(self, R, z, phi=0.0, t=0.0):
@@ -342,8 +327,8 @@ def _2ndderiv_xyz(self, R, z, phi=0.0, t=0.0):
342327
xp.stack([xzderivp, yzderivp, z2derivp]),
343328
]
344329
)
345-
inv_rot = _as_xp_constant(xp, self._inv_rot, deriv2p)
346-
inv_rot_T = _as_xp_constant(xp, self._inv_rot.T, deriv2p)
330+
inv_rot = as_backend_constant(xp, self._inv_rot, deriv2p)
331+
inv_rot_T = as_backend_constant(xp, self._inv_rot.T, deriv2p)
347332
return inv_rot @ (deriv2p @ inv_rot_T)
348333

349334
@check_potential_inputs_not_arrays

0 commit comments

Comments
 (0)