77import numpy
88from scipy import integrate
99
10- from ..backend import get_namespace
10+ from ..backend import coerce_coords , get_namespace
1111from ..backend .special import logsumexp
1212from .Potential import Potential
1313
@@ -147,24 +147,51 @@ def _parse_Sigma_dict_indiv(self, Sigma):
147147 # (get_namespace), so they run under numpy (byte-identical: the numpy
148148 # namespace IS the numpy module), jax, and torch.
149149 stype = Sigma .get ("type" , "exp" )
150+ # These closures are also called directly with numpy/python R (e.g. by
151+ # the Sigma-derivative tests) while the resolved namespace is a forced
152+ # backend, so coerce_coords R onto that backend before xp.exp(R); the
153+ # numpy pass-through keeps the numpy path byte-identical.
150154 if stype == "exp" and not "Rhole" in Sigma :
151155 rd = Sigma .get ("h" , 1.0 / 3.0 )
152156 ta = Sigma .get ("amp" , 1.0 )
153- ts = lambda R , trd = rd : get_namespace (R ).exp (- R / trd )
154- tds = lambda R , trd = rd : - get_namespace (R ).exp (- R / trd ) / trd
155- td2s = lambda R , trd = rd : get_namespace (R ).exp (- R / trd ) / trd ** 2.0
157+
158+ def ts (R , trd = rd ):
159+ xp = get_namespace (R )
160+ (R ,) = coerce_coords (xp , R )
161+ return xp .exp (- R / trd )
162+
163+ def tds (R , trd = rd ):
164+ xp = get_namespace (R )
165+ (R ,) = coerce_coords (xp , R )
166+ return - xp .exp (- R / trd ) / trd
167+
168+ def td2s (R , trd = rd ):
169+ xp = get_namespace (R )
170+ (R ,) = coerce_coords (xp , R )
171+ return xp .exp (- R / trd ) / trd ** 2.0
172+
156173 elif stype == "expwhole" or (stype == "exp" and "Rhole" in Sigma ):
157174 rd = Sigma .get ("h" , 1.0 / 3.0 )
158175 rm = Sigma .get ("Rhole" , 0.5 )
159176 ta = Sigma .get ("amp" , 1.0 )
160- ts = lambda R , trd = rd , trm = rm : get_namespace (R ).exp (- trm / R - R / trd )
161- tds = lambda R , trd = rd , trm = rm : (
162- (trm / R ** 2.0 - 1.0 / trd ) * get_namespace (R ).exp (- trm / R - R / trd )
163- )
164- td2s = lambda R , trd = rd , trm = rm : (
165- ((trm / R ** 2.0 - 1.0 / trd ) ** 2.0 - 2.0 * trm / R ** 3.0 )
166- * get_namespace (R ).exp (- trm / R - R / trd )
167- )
177+
178+ def ts (R , trd = rd , trm = rm ):
179+ xp = get_namespace (R )
180+ (R ,) = coerce_coords (xp , R )
181+ return xp .exp (- trm / R - R / trd )
182+
183+ def tds (R , trd = rd , trm = rm ):
184+ xp = get_namespace (R )
185+ (R ,) = coerce_coords (xp , R )
186+ return (trm / R ** 2.0 - 1.0 / trd ) * xp .exp (- trm / R - R / trd )
187+
188+ def td2s (R , trd = rd , trm = rm ):
189+ xp = get_namespace (R )
190+ (R ,) = coerce_coords (xp , R )
191+ return (
192+ (trm / R ** 2.0 - 1.0 / trd ) ** 2.0 - 2.0 * trm / R ** 3.0
193+ ) * xp .exp (- trm / R - R / trd )
194+
168195 return (ta , ts , tds , td2s )
169196
170197 def _parse_hz (self , hz , Hz , dHzdz ):
@@ -214,20 +241,27 @@ def _parse_hz_dict_indiv(self, hz):
214241 # bit-for-bit on real floats, xp.stack of same-shape inputs ==
215242 # numpy.array of that list, and galpy.backend.special.logsumexp routes
216243 # numpy to scipy.special.logsumexp -- so the numpy path is unchanged.
244+ # As in _parse_Sigma_dict_indiv, these closures are also called directly
245+ # with numpy/python z while the resolved namespace is a forced backend,
246+ # so coerce_coords z onto that backend before xp.abs/exp/sign(z); the
247+ # numpy pass-through keeps the numpy path byte-identical.
217248 htype = hz .get ("type" , "exp" )
218249 if htype == "exp" :
219250 zd = hz .get ("h" , 0.0375 )
220251
221252 def th (z , tzd = zd ):
222253 xp = get_namespace (z )
254+ (z ,) = coerce_coords (xp , z )
223255 return 1.0 / 2.0 / tzd * xp .exp (- xp .abs (z ) / tzd )
224256
225257 def tH (z , tzd = zd ):
226258 xp = get_namespace (z )
259+ (z ,) = coerce_coords (xp , z )
227260 return (xp .exp (- xp .abs (z ) / tzd ) - 1.0 + xp .abs (z ) / tzd ) * tzd / 2.0
228261
229262 def tdH (z , tzd = zd ):
230263 xp = get_namespace (z )
264+ (z ,) = coerce_coords (xp , z )
231265 return 0.5 * xp .sign (z ) * (1.0 - xp .exp (- xp .abs (z ) / tzd ))
232266
233267 elif htype == "sech2" :
@@ -236,6 +270,7 @@ def tdH(z, tzd=zd):
236270 # th/tH written so as to avoid overflow in cosh
237271 def th (z , tzd = zd ):
238272 xp = get_namespace (z )
273+ (z ,) = coerce_coords (xp , z )
239274 return (
240275 xp .exp (
241276 - logsumexp (
@@ -250,13 +285,15 @@ def th(z, tzd=zd):
250285
251286 def tH (z , tzd = zd ):
252287 xp = get_namespace (z )
288+ (z ,) = coerce_coords (xp , z )
253289 return tzd * (
254290 logsumexp (xp .stack ([z / 2.0 / tzd , - z / 2.0 / tzd ]), axis = 0 )
255291 - numpy .log (2.0 )
256292 )
257293
258294 def tdH (z , tzd = zd ):
259295 xp = get_namespace (z )
296+ (z ,) = coerce_coords (xp , z )
260297 return xp .tanh (z / 2.0 / tzd ) / 2.0
261298
262299 return (th , tH , tdH )
@@ -265,6 +302,10 @@ def _evaluate(self, R, z, phi=0.0, t=0.0):
265302 # Here and below: out-of-place accumulation (out = out + ...) instead of
266303 # += so torch autograd never sees an in-place op; identical numpy values.
267304 xp = get_namespace (R , z )
305+ # Coerce R/z onto the active backend so xp.sqrt and the Sigma/hz closures
306+ # (xp.exp/xp.abs(...)) receive backend arrays, not numpy/python; numpy
307+ # pass-through keeps this byte-identical.
308+ R , z = coerce_coords (xp , R , z )
268309 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
269310 out = self ._me (R , z , phi = phi , t = t , use_physical = False )
270311 for a , s , H in zip (self ._Sigma_amp , self ._Sigma , self ._Hz ):
@@ -273,6 +314,7 @@ def _evaluate(self, R, z, phi=0.0, t=0.0):
273314
274315 def _Rforce (self , R , z , phi = 0 , t = 0 ):
275316 xp = get_namespace (R , z )
317+ R , z = coerce_coords (xp , R , z )
276318 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
277319 out = self ._me .Rforce (R , z , phi = phi , t = t , use_physical = False )
278320 for a , ds , H in zip (self ._Sigma_amp , self ._dSigmadR , self ._Hz ):
@@ -281,6 +323,7 @@ def _Rforce(self, R, z, phi=0, t=0):
281323
282324 def _zforce (self , R , z , phi = 0 , t = 0 ):
283325 xp = get_namespace (R , z )
326+ R , z = coerce_coords (xp , R , z )
284327 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
285328 out = self ._me .zforce (R , z , phi = phi , t = t , use_physical = False )
286329 for a , s , ds , H , dH in zip (
@@ -294,6 +337,7 @@ def _phitorque(self, R, z, phi=0.0, t=0.0):
294337
295338 def _R2deriv (self , R , z , phi = 0.0 , t = 0.0 ):
296339 xp = get_namespace (R , z )
340+ R , z = coerce_coords (xp , R , z )
297341 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
298342 out = self ._me .R2deriv (R , z , phi = phi , t = t , use_physical = False )
299343 for a , ds , d2s , H in zip (
@@ -311,6 +355,7 @@ def _R2deriv(self, R, z, phi=0.0, t=0.0):
311355
312356 def _z2deriv (self , R , z , phi = 0.0 , t = 0.0 ):
313357 xp = get_namespace (R , z )
358+ R , z = coerce_coords (xp , R , z )
314359 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
315360 out = self ._me .z2deriv (R , z , phi = phi , t = t , use_physical = False )
316361 for a , s , ds , d2s , h , H , dH in zip (
@@ -336,6 +381,7 @@ def _z2deriv(self, R, z, phi=0.0, t=0.0):
336381
337382 def _Rzderiv (self , R , z , phi = 0.0 , t = 0.0 ):
338383 xp = get_namespace (R , z )
384+ R , z = coerce_coords (xp , R , z )
339385 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
340386 out = self ._me .Rzderiv (R , z , phi = phi , t = t , use_physical = False )
341387 for a , ds , d2s , H , dH in zip (
@@ -354,6 +400,7 @@ def _phi2deriv(self, R, z, phi=0.0, t=0.0):
354400
355401 def _dens (self , R , z , phi = 0.0 , t = 0.0 ):
356402 xp = get_namespace (R , z )
403+ R , z = coerce_coords (xp , R , z )
357404 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
358405 out = self ._me .dens (R , z , phi = phi , t = t , use_physical = False )
359406 for a , s , ds , d2s , h , H , dH in zip (
@@ -395,6 +442,7 @@ def phiME_dens(R, z, phi, dens, Sigma, dSigmadR, d2SigmadR2, hz, Hz, dHzdz, Sigm
395442 """The density corresponding to phi_ME (backend-agnostic provided that the
396443 user-supplied ``dens`` callable accepts backend arrays)"""
397444 xp = get_namespace (R , z )
445+ R , z = coerce_coords (xp , R , z )
398446 r = xp .sqrt (R ** 2.0 + z ** 2.0 )
399447 out = dens (R , z , phi )
400448 for a , s , ds , d2s , h , H , dH in zip (
0 commit comments