Skip to content

Commit e3203b3

Browse files
jobovyclaude
andauthored
potential/backend: anchor scalar coordinates in migrated analytic/disk compute methods (#992)
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 0c0103a commit e3203b3

5 files changed

Lines changed: 90 additions & 16 deletions

galpy/potential/EllipticalDiskPotential.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
###############################################################################
55
import numpy
66

7-
from ..backend import get_namespace
7+
from ..backend import coerce_coords, get_namespace
88
from ..util import conversion
99
from .planarPotential import planarPotential
1010

@@ -119,13 +119,15 @@ def _smooth(self, t):
119119

120120
def _evaluate(self, R, phi=0.0, t=0.0):
121121
xp = get_namespace(R, phi, t)
122+
R, phi = coerce_coords(xp, R, phi)
122123
smooth = self._smooth(t)
123124
return (
124125
smooth * self._twophio / 2.0 * R**self._p * xp.cos(2.0 * (phi - self._phib))
125126
)
126127

127128
def _Rforce(self, R, phi=0.0, t=0.0):
128129
xp = get_namespace(R, phi, t)
130+
R, phi = coerce_coords(xp, R, phi)
129131
smooth = self._smooth(t)
130132
return (
131133
-smooth
@@ -138,11 +140,13 @@ def _Rforce(self, R, phi=0.0, t=0.0):
138140

139141
def _phitorque(self, R, phi=0.0, t=0.0):
140142
xp = get_namespace(R, phi, t)
143+
R, phi = coerce_coords(xp, R, phi)
141144
smooth = self._smooth(t)
142145
return smooth * self._twophio * R**self._p * xp.sin(2.0 * (phi - self._phib))
143146

144147
def _R2deriv(self, R, phi=0.0, t=0.0):
145148
xp = get_namespace(R, phi, t)
149+
R, phi = coerce_coords(xp, R, phi)
146150
smooth = self._smooth(t)
147151
return (
148152
smooth
@@ -156,6 +160,7 @@ def _R2deriv(self, R, phi=0.0, t=0.0):
156160

157161
def _phi2deriv(self, R, phi=0.0, t=0.0):
158162
xp = get_namespace(R, phi, t)
163+
R, phi = coerce_coords(xp, R, phi)
159164
smooth = self._smooth(t)
160165
return (
161166
-2.0
@@ -167,6 +172,7 @@ def _phi2deriv(self, R, phi=0.0, t=0.0):
167172

168173
def _Rphideriv(self, R, phi=0.0, t=0.0):
169174
xp = get_namespace(R, phi, t)
175+
R, phi = coerce_coords(xp, R, phi)
170176
smooth = self._smooth(t)
171177
return (
172178
-smooth

galpy/potential/KuijkenDubinskiDiskExpansionPotential.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy
88
from scipy import integrate
99

10-
from ..backend import get_namespace
10+
from ..backend import coerce_coords, get_namespace
1111
from ..backend.special import logsumexp
1212
from .Potential import Potential
1313

@@ -147,24 +147,51 @@ def _parse_Sigma_dict_indiv(self, Sigma):
147147
# (get_namespace), so they run under numpy (byte-identical: the numpy
148148
# namespace IS the numpy module), jax, and torch.
149149
stype = Sigma.get("type", "exp")
150+
# These closures are also called directly with numpy/python R (e.g. by
151+
# the Sigma-derivative tests) while the resolved namespace is a forced
152+
# backend, so coerce_coords R onto that backend before xp.exp(R); the
153+
# numpy pass-through keeps the numpy path byte-identical.
150154
if stype == "exp" and not "Rhole" in Sigma:
151155
rd = Sigma.get("h", 1.0 / 3.0)
152156
ta = Sigma.get("amp", 1.0)
153-
ts = lambda R, trd=rd: get_namespace(R).exp(-R / trd)
154-
tds = lambda R, trd=rd: -get_namespace(R).exp(-R / trd) / trd
155-
td2s = lambda R, trd=rd: get_namespace(R).exp(-R / trd) / trd**2.0
157+
158+
def ts(R, trd=rd):
159+
xp = get_namespace(R)
160+
(R,) = coerce_coords(xp, R)
161+
return xp.exp(-R / trd)
162+
163+
def tds(R, trd=rd):
164+
xp = get_namespace(R)
165+
(R,) = coerce_coords(xp, R)
166+
return -xp.exp(-R / trd) / trd
167+
168+
def td2s(R, trd=rd):
169+
xp = get_namespace(R)
170+
(R,) = coerce_coords(xp, R)
171+
return xp.exp(-R / trd) / trd**2.0
172+
156173
elif stype == "expwhole" or (stype == "exp" and "Rhole" in Sigma):
157174
rd = Sigma.get("h", 1.0 / 3.0)
158175
rm = Sigma.get("Rhole", 0.5)
159176
ta = Sigma.get("amp", 1.0)
160-
ts = lambda R, trd=rd, trm=rm: get_namespace(R).exp(-trm / R - R / trd)
161-
tds = lambda R, trd=rd, trm=rm: (
162-
(trm / R**2.0 - 1.0 / trd) * get_namespace(R).exp(-trm / R - R / trd)
163-
)
164-
td2s = lambda R, trd=rd, trm=rm: (
165-
((trm / R**2.0 - 1.0 / trd) ** 2.0 - 2.0 * trm / R**3.0)
166-
* get_namespace(R).exp(-trm / R - R / trd)
167-
)
177+
178+
def ts(R, trd=rd, trm=rm):
179+
xp = get_namespace(R)
180+
(R,) = coerce_coords(xp, R)
181+
return xp.exp(-trm / R - R / trd)
182+
183+
def tds(R, trd=rd, trm=rm):
184+
xp = get_namespace(R)
185+
(R,) = coerce_coords(xp, R)
186+
return (trm / R**2.0 - 1.0 / trd) * xp.exp(-trm / R - R / trd)
187+
188+
def td2s(R, trd=rd, trm=rm):
189+
xp = get_namespace(R)
190+
(R,) = coerce_coords(xp, R)
191+
return (
192+
(trm / R**2.0 - 1.0 / trd) ** 2.0 - 2.0 * trm / R**3.0
193+
) * xp.exp(-trm / R - R / trd)
194+
168195
return (ta, ts, tds, td2s)
169196

170197
def _parse_hz(self, hz, Hz, dHzdz):
@@ -214,20 +241,27 @@ def _parse_hz_dict_indiv(self, hz):
214241
# bit-for-bit on real floats, xp.stack of same-shape inputs ==
215242
# numpy.array of that list, and galpy.backend.special.logsumexp routes
216243
# numpy to scipy.special.logsumexp -- so the numpy path is unchanged.
244+
# As in _parse_Sigma_dict_indiv, these closures are also called directly
245+
# with numpy/python z while the resolved namespace is a forced backend,
246+
# so coerce_coords z onto that backend before xp.abs/exp/sign(z); the
247+
# numpy pass-through keeps the numpy path byte-identical.
217248
htype = hz.get("type", "exp")
218249
if htype == "exp":
219250
zd = hz.get("h", 0.0375)
220251

221252
def th(z, tzd=zd):
222253
xp = get_namespace(z)
254+
(z,) = coerce_coords(xp, z)
223255
return 1.0 / 2.0 / tzd * xp.exp(-xp.abs(z) / tzd)
224256

225257
def tH(z, tzd=zd):
226258
xp = get_namespace(z)
259+
(z,) = coerce_coords(xp, z)
227260
return (xp.exp(-xp.abs(z) / tzd) - 1.0 + xp.abs(z) / tzd) * tzd / 2.0
228261

229262
def tdH(z, tzd=zd):
230263
xp = get_namespace(z)
264+
(z,) = coerce_coords(xp, z)
231265
return 0.5 * xp.sign(z) * (1.0 - xp.exp(-xp.abs(z) / tzd))
232266

233267
elif htype == "sech2":
@@ -236,6 +270,7 @@ def tdH(z, tzd=zd):
236270
# th/tH written so as to avoid overflow in cosh
237271
def th(z, tzd=zd):
238272
xp = get_namespace(z)
273+
(z,) = coerce_coords(xp, z)
239274
return (
240275
xp.exp(
241276
-logsumexp(
@@ -250,13 +285,15 @@ def th(z, tzd=zd):
250285

251286
def tH(z, tzd=zd):
252287
xp = get_namespace(z)
288+
(z,) = coerce_coords(xp, z)
253289
return tzd * (
254290
logsumexp(xp.stack([z / 2.0 / tzd, -z / 2.0 / tzd]), axis=0)
255291
- numpy.log(2.0)
256292
)
257293

258294
def tdH(z, tzd=zd):
259295
xp = get_namespace(z)
296+
(z,) = coerce_coords(xp, z)
260297
return xp.tanh(z / 2.0 / tzd) / 2.0
261298

262299
return (th, tH, tdH)
@@ -265,6 +302,10 @@ def _evaluate(self, R, z, phi=0.0, t=0.0):
265302
# Here and below: out-of-place accumulation (out = out + ...) instead of
266303
# += so torch autograd never sees an in-place op; identical numpy values.
267304
xp = get_namespace(R, z)
305+
# Coerce R/z onto the active backend so xp.sqrt and the Sigma/hz closures
306+
# (xp.exp/xp.abs(...)) receive backend arrays, not numpy/python; numpy
307+
# pass-through keeps this byte-identical.
308+
R, z = coerce_coords(xp, R, z)
268309
r = xp.sqrt(R**2.0 + z**2.0)
269310
out = self._me(R, z, phi=phi, t=t, use_physical=False)
270311
for a, s, H in zip(self._Sigma_amp, self._Sigma, self._Hz):
@@ -273,6 +314,7 @@ def _evaluate(self, R, z, phi=0.0, t=0.0):
273314

274315
def _Rforce(self, R, z, phi=0, t=0):
275316
xp = get_namespace(R, z)
317+
R, z = coerce_coords(xp, R, z)
276318
r = xp.sqrt(R**2.0 + z**2.0)
277319
out = self._me.Rforce(R, z, phi=phi, t=t, use_physical=False)
278320
for a, ds, H in zip(self._Sigma_amp, self._dSigmadR, self._Hz):
@@ -281,6 +323,7 @@ def _Rforce(self, R, z, phi=0, t=0):
281323

282324
def _zforce(self, R, z, phi=0, t=0):
283325
xp = get_namespace(R, z)
326+
R, z = coerce_coords(xp, R, z)
284327
r = xp.sqrt(R**2.0 + z**2.0)
285328
out = self._me.zforce(R, z, phi=phi, t=t, use_physical=False)
286329
for a, s, ds, H, dH in zip(
@@ -294,6 +337,7 @@ def _phitorque(self, R, z, phi=0.0, t=0.0):
294337

295338
def _R2deriv(self, R, z, phi=0.0, t=0.0):
296339
xp = get_namespace(R, z)
340+
R, z = coerce_coords(xp, R, z)
297341
r = xp.sqrt(R**2.0 + z**2.0)
298342
out = self._me.R2deriv(R, z, phi=phi, t=t, use_physical=False)
299343
for a, ds, d2s, H in zip(
@@ -311,6 +355,7 @@ def _R2deriv(self, R, z, phi=0.0, t=0.0):
311355

312356
def _z2deriv(self, R, z, phi=0.0, t=0.0):
313357
xp = get_namespace(R, z)
358+
R, z = coerce_coords(xp, R, z)
314359
r = xp.sqrt(R**2.0 + z**2.0)
315360
out = self._me.z2deriv(R, z, phi=phi, t=t, use_physical=False)
316361
for a, s, ds, d2s, h, H, dH in zip(
@@ -336,6 +381,7 @@ def _z2deriv(self, R, z, phi=0.0, t=0.0):
336381

337382
def _Rzderiv(self, R, z, phi=0.0, t=0.0):
338383
xp = get_namespace(R, z)
384+
R, z = coerce_coords(xp, R, z)
339385
r = xp.sqrt(R**2.0 + z**2.0)
340386
out = self._me.Rzderiv(R, z, phi=phi, t=t, use_physical=False)
341387
for a, ds, d2s, H, dH in zip(
@@ -354,6 +400,7 @@ def _phi2deriv(self, R, z, phi=0.0, t=0.0):
354400

355401
def _dens(self, R, z, phi=0.0, t=0.0):
356402
xp = get_namespace(R, z)
403+
R, z = coerce_coords(xp, R, z)
357404
r = xp.sqrt(R**2.0 + z**2.0)
358405
out = self._me.dens(R, z, phi=phi, t=t, use_physical=False)
359406
for a, s, ds, d2s, h, H, dH in zip(
@@ -395,6 +442,7 @@ def phiME_dens(R, z, phi, dens, Sigma, dSigmadR, d2SigmadR2, hz, Hz, dHzdz, Sigm
395442
"""The density corresponding to phi_ME (backend-agnostic provided that the
396443
user-supplied ``dens`` callable accepts backend arrays)"""
397444
xp = get_namespace(R, z)
445+
R, z = coerce_coords(xp, R, z)
398446
r = xp.sqrt(R**2.0 + z**2.0)
399447
out = dens(R, z, phi)
400448
for a, s, ds, d2s, h, H, dH in zip(

galpy/potential/PowerSphericalPotential.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy
1212
from scipy import special
1313

14-
from ..backend import get_namespace
14+
from ..backend import coerce_coords, get_namespace
1515
from ..util import conversion
1616
from .Potential import Potential
1717

@@ -89,6 +89,7 @@ def _evaluate(self, R, z, phi=0.0, t=0.0):
8989
- Started: 2010-07-10 by Bovy (NYU)
9090
"""
9191
xp = get_namespace(R, z)
92+
R, z = coerce_coords(xp, R, z)
9293
r2 = R**2.0 + z**2.0
9394
if self.alpha == 2.0:
9495
return xp.log(r2) / 2.0
@@ -281,6 +282,7 @@ def _dens(self, R, z, phi=0.0, t=0.0):
281282
- 2013-01-09 - Written - Bovy (IAS)
282283
"""
283284
xp = get_namespace(R, z)
285+
R, z = coerce_coords(xp, R, z)
284286
r = xp.sqrt(R**2.0 + z**2.0)
285287
return (3.0 - self.alpha) / 4.0 / math.pi / r**self.alpha
286288

galpy/potential/PowerSphericalPotentialwCutoff.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy
99
from scipy import special
1010

11-
from ..backend import get_namespace
11+
from ..backend import coerce_coords, get_namespace
1212
from ..backend.special import gamma as _gamma
1313
from ..backend.special import gammainc as _gammainc
1414
from ..util import conversion
@@ -78,6 +78,7 @@ def __init__(
7878

7979
def _evaluate(self, R, z, phi=0.0, t=0.0):
8080
xp = get_namespace(R, z)
81+
R, z = coerce_coords(xp, R, z)
8182
r = xp.sqrt(R**2.0 + z**2.0)
8283
# guard r=0 in the dead branch (the 1/r term -> 0/0=NaN there) so the
8384
# xp.where stays finite under autodiff/jit; the value at r=0 is 0.
@@ -115,30 +116,35 @@ def _rforce(self, r):
115116

116117
def _Rforce(self, R, z, phi=0.0, t=0.0):
117118
xp = get_namespace(R, z)
119+
R, z = coerce_coords(xp, R, z)
118120
r = xp.sqrt(R * R + z * z)
119121
return self._rforce(r) * R / r
120122

121123
def _zforce(self, R, z, phi=0.0, t=0.0):
122124
xp = get_namespace(R, z)
125+
R, z = coerce_coords(xp, R, z)
123126
r = xp.sqrt(R * R + z * z)
124127
return self._rforce(r) * z / r
125128

126129
def _R2deriv(self, R, z, phi=0.0, t=0.0):
127130
xp = get_namespace(R, z)
131+
R, z = coerce_coords(xp, R, z)
128132
r = xp.sqrt(R * R + z * z)
129133
return 4.0 * numpy.pi * r ** (-2.0 - self.alpha) * xp.exp(
130134
-((r / self.rc) ** 2.0)
131135
) * R**2.0 + self._mass(r) / r**5.0 * (z**2.0 - 2.0 * R**2.0)
132136

133137
def _z2deriv(self, R, z, phi=0.0, t=0.0):
134138
xp = get_namespace(R, z)
139+
R, z = coerce_coords(xp, R, z)
135140
r = xp.sqrt(R * R + z * z)
136141
return 4.0 * numpy.pi * r ** (-2.0 - self.alpha) * xp.exp(
137142
-((r / self.rc) ** 2.0)
138143
) * z**2.0 + self._mass(r) / r**5.0 * (R**2.0 - 2.0 * z**2.0)
139144

140145
def _Rzderiv(self, R, z, phi=0.0, t=0.0):
141146
xp = get_namespace(R, z)
147+
R, z = coerce_coords(xp, R, z)
142148
r = xp.sqrt(R * R + z * z)
143149
return (
144150
R
@@ -224,6 +230,7 @@ def _ddenstwobetadr(self, r, beta=0):
224230

225231
def _dens(self, R, z, phi=0.0, t=0.0):
226232
xp = get_namespace(R, z)
233+
R, z = coerce_coords(xp, R, z)
227234
r = xp.sqrt(R**2.0 + z**2.0)
228235
return 1.0 / r**self.alpha * xp.exp(-((r / self.rc) ** 2.0))
229236

0 commit comments

Comments
 (0)