Skip to content

Commit 84af22a

Browse files
committed
add numba testing now too, probably will revert as it is painfully slow
1 parent 75e5d54 commit 84af22a

File tree

1 file changed

+51
-2
lines changed

1 file changed

+51
-2
lines changed

tests/test_pr_659.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
ak = pytest.importorskip("awkward")
1717
sympy = pytest.importorskip("sympy")
18+
numba = pytest.importorskip("numba")
19+
20+
pytestmark = [pytest.mark.awkward, pytest.mark.sympy, pytest.mark.numba]
1821

19-
pytestmark = [pytest.mark.awkward, pytest.mark.sympy]
2022

2123
ALL_COORDINATES = [
2224
"x",
@@ -41,7 +43,7 @@
4143
]
4244

4345
MOMENTUM_COORDINATES = {"px", "py", "pz", "pt", "E", "e", "energy", "M", "m", "mass"}
44-
46+
AZIMUTHAL_COORDS = {"x", "y", "rho", "phi", "px", "py", "pt"}
4547
TEMPORAL_COORDS = {"t", "tau", "E", "e", "energy", "M", "m", "mass"}
4648
LONGITUDINAL_COORDS = {"z", "theta", "eta", "pz"}
4749

@@ -204,6 +206,25 @@ def _get_sympy_class(coords):
204206
return vector.MomentumSympy4D if has_momentum else vector.VectorSympy4D
205207

206208

209+
def _numba_obj(combo):
210+
"""Create vector.obj inside a jitted function with the given coordinates."""
211+
kwargs = ", ".join(f"{c}=1.0" for c in combo)
212+
local_ns = {"vector": vector, "numba": numba}
213+
exec(f"@numba.njit\ndef f():\n return vector.obj({kwargs})", local_ns)
214+
return local_ns["f"]
215+
216+
217+
def _will_numba_error(combo):
218+
"""Check if numba will error. Numba errors on duplicates, no valid azimuthal, or extra azimuthal coords."""
219+
if _has_duplicate(combo):
220+
return True
221+
if not _has_valid_2_subset(combo):
222+
return True
223+
# Numba errors if there are more than 2 azimuthal coordinates (canonical form)
224+
canonical_azimuthal = {_to_canonical(c) for c in combo if c in AZIMUTHAL_COORDS}
225+
return len(canonical_azimuthal) > 2
226+
227+
207228
@pytest.mark.parametrize(
208229
"combo",
209230
ALL_2_COMBINATIONS,
@@ -214,14 +235,18 @@ def test_2_combinations(combo):
214235
is_momentum = _is_momentum(combo)
215236
error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
216237

238+
numba_error_pattern = "duplicate coordinates|unrecognized combination"
239+
217240
if is_valid:
218241
v_obj = vector.obj(**dict.fromkeys(combo, 1.0))
242+
v_numba = _numba_obj(combo)()
219243
v_numpy = vector.array({c: np.array([1.0, 2.0]) for c in combo})
220244
v_awkward = vector.Array(ak.Array({c: [1.0, 2.0] for c in combo}))
221245
v_zip = vector.zip({c: np.array([1.0, 2.0]) for c in combo})
222246
v_sympy = _get_sympy_class(combo)(**{c: sympy.Symbol(c) for c in combo})
223247

224248
assert isinstance(v_obj, Momentum) == is_momentum
249+
assert isinstance(v_numba, Momentum) == is_momentum
225250
assert isinstance(v_numpy, Momentum) == is_momentum
226251
assert isinstance(v_awkward, Momentum) == is_momentum
227252
assert isinstance(v_zip, Momentum) == is_momentum
@@ -230,6 +255,9 @@ def test_2_combinations(combo):
230255
with pytest.raises(TypeError, match=error_pattern):
231256
vector.obj(**dict.fromkeys(combo, 1.0))
232257

258+
with pytest.raises(numba.TypingError, match=numba_error_pattern):
259+
_numba_obj(combo)()
260+
233261
with pytest.raises(TypeError, match=error_pattern):
234262
vector.array({c: np.array([1.0, 2.0]) for c in combo})
235263

@@ -253,26 +281,37 @@ def test_3_combinations(combo):
253281
has_valid_2 = _has_valid_2_subset(combo)
254282
is_momentum = _is_momentum(combo)
255283
error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
284+
numba_error_pattern = "duplicate coordinates|unrecognized combination"
256285

257286
if is_valid:
258287
v_obj = vector.obj(**dict.fromkeys(combo, 1.0))
288+
v_numba = _numba_obj(combo)()
259289
v_numpy = vector.array({c: np.array([1.0, 2.0]) for c in combo})
260290
v_awkward = vector.Array(ak.Array({c: [1.0, 2.0] for c in combo}))
261291
v_zip = vector.zip({c: np.array([1.0, 2.0]) for c in combo})
262292
v_sympy = _get_sympy_class(combo)(**{c: sympy.Symbol(c) for c in combo})
263293

264294
assert isinstance(v_obj, Momentum) == is_momentum
295+
assert isinstance(v_numba, Momentum) == is_momentum
265296
assert isinstance(v_numpy, Momentum) == is_momentum
266297
assert isinstance(v_awkward, Momentum) == is_momentum
267298
assert isinstance(v_zip, Momentum) == is_momentum
268299
assert isinstance(v_sympy, Momentum) == is_momentum
269300
else:
301+
# obj and sympy are strict - always error for invalid combos
270302
with pytest.raises(TypeError, match=error_pattern):
271303
vector.obj(**dict.fromkeys(combo, 1.0))
272304

273305
with pytest.raises(TypeError, match=error_pattern):
274306
_get_sympy_class(combo)(**{c: sympy.Symbol(c) for c in combo})
275307

308+
# numba is permissive like numpy/awkward/zip
309+
if _will_numba_error(combo):
310+
with pytest.raises(numba.TypingError, match=numba_error_pattern):
311+
_numba_obj(combo)()
312+
else:
313+
_numba_obj(combo)()
314+
276315
if has_valid_2 and not _will_error_for_non_obj(combo):
277316
# numpy/awkward/zip create a 2D vector with extra fields
278317
v_numpy = vector.array({c: np.array([1.0, 2.0]) for c in combo})
@@ -304,15 +343,18 @@ def test_4_combinations(combo):
304343
has_valid_2 = _has_valid_2_subset(combo)
305344
is_momentum = _is_momentum(combo)
306345
error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
346+
numba_error_pattern = "duplicate coordinates|unrecognized combination"
307347

308348
if is_valid:
309349
v_obj = vector.obj(**dict.fromkeys(combo, 1.0))
350+
v_numba = _numba_obj(combo)()
310351
v_numpy = vector.array({c: np.array([1.0, 2.0]) for c in combo})
311352
v_awkward = vector.Array(ak.Array({c: [1.0, 2.0] for c in combo}))
312353
v_zip = vector.zip({c: np.array([1.0, 2.0]) for c in combo})
313354
v_sympy = _get_sympy_class(combo)(**{c: sympy.Symbol(c) for c in combo})
314355

315356
assert isinstance(v_obj, Momentum) == is_momentum
357+
assert isinstance(v_numba, Momentum) == is_momentum
316358
assert isinstance(v_numpy, Momentum) == is_momentum
317359
assert isinstance(v_awkward, Momentum) == is_momentum
318360
assert isinstance(v_zip, Momentum) == is_momentum
@@ -325,6 +367,13 @@ def test_4_combinations(combo):
325367
with pytest.raises(TypeError, match=error_pattern):
326368
_get_sympy_class(combo)(**{c: sympy.Symbol(c) for c in combo})
327369

370+
# numba is permissive like numpy/awkward/zip
371+
if _will_numba_error(combo):
372+
with pytest.raises(numba.TypingError, match=numba_error_pattern):
373+
_numba_obj(combo)()
374+
else:
375+
_numba_obj(combo)()
376+
328377
if has_valid_3 and not _will_error_for_non_obj(combo):
329378
# numpy/awkward/zip create a 3D vector with extra fields
330379
v_numpy = vector.array({c: np.array([1.0, 2.0]) for c in combo})

0 commit comments

Comments
 (0)