Skip to content

Commit 3eb4d57

Browse files
authored
fix: get rid of cupy<14 quick fix (#685)
1 parent 658bb87 commit 3eb4d57

File tree

2 files changed

+12
-26
lines changed

2 files changed

+12
-26
lines changed

src/vector/_compute/spatial/eta.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import annotations
1414

1515
import typing
16-
from math import inf
16+
from math import inf, nan
1717

1818
import numpy
1919

@@ -30,16 +30,12 @@
3030
)
3131

3232

33-
# TODO: https://github.com/scikit-hep/vector/issues/615
34-
# revert back to `nan_to_num` implementation once
35-
# https://github.com/cupy/cupy/issues/9143 is fixed.
36-
# `lib.where` works but there is no SymPy equivalent for the function.
3733
def xy_z(lib, x, y, z):
38-
return (
39-
lib.where(
40-
z != 0, lib.arcsinh(lib.where(z != 0, z / lib.sqrt(x**2 + y**2), z)), z
41-
)
42-
* 1
34+
return lib.nan_to_num(
35+
lib.arcsinh(z / lib.sqrt(x**2 + y**2)),
36+
nan=lib.nan_to_num((z != 0) * inf, posinf=nan),
37+
posinf=inf,
38+
neginf=-inf,
4339
)
4440

4541

@@ -56,12 +52,13 @@ def xy_eta(lib, x, y, eta):
5652
xy_eta.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
5753

5854

59-
# TODO: https://github.com/scikit-hep/vector/issues/615
60-
# revert back to `nan_to_num` implementation once
61-
# https://github.com/cupy/cupy/issues/9143 is fixed.
62-
# `lib.where` works but there is no SymPy equivalent for the function.
6355
def rhophi_z(lib, rho, phi, z):
64-
return lib.where(z != 0, lib.arcsinh(lib.where(z != 0, z / rho, z)), z) * 1
56+
return lib.nan_to_num(
57+
lib.arcsinh(z / rho),
58+
nan=lib.nan_to_num((z != 0) * inf, posinf=nan),
59+
posinf=inf,
60+
neginf=-inf,
61+
)
6562

6663

6764
def rhophi_theta(lib, rho, phi, theta):

src/vector/backends/sympy.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,6 @@ class _lib:
4343
"""a wrapper that maps numpy functions to sympy functions (or custom implementations)"""
4444

4545
# functions modified specifically for sympy
46-
47-
# TODO: https://github.com/scikit-hep/vector/issues/615
48-
# should NOT be used as a replacement for np.where as it works specifically for the
49-
# case of vector/_compute/spatial/eta.py. `where` returns the second argument and
50-
# ignores the third (first argument is a boolean condition) because the purpose of
51-
# this function is to handle exceptional values — we know that the "normal" values
52-
# are in the second argument and the "exceptional" ones are in the third argument.
53-
# remove once https://github.com/cupy/cupy/issues/9143 is fixed.
54-
def where(self, val1: sympy.Expr, val2: sympy.Expr, val3: sympy.Expr) -> sympy.Expr:
55-
return val2
56-
5746
def nan_to_num(self, val: sympy.Expr, **kwargs: typing.Any) -> sympy.Expr:
5847
return val
5948

0 commit comments

Comments
 (0)