1515
1616ak = pytest .importorskip ("awkward" )
1717sympy = pytest .importorskip ("sympy" )
18+ numba = pytest .importorskip ("numba" )
19+
20+ pytestmark = [pytest .mark .awkward , pytest .mark .sympy , pytest .mark .numba ]
1821
19- pytestmark = [pytest .mark .awkward , pytest .mark .sympy ]
2022
2123ALL_COORDINATES = [
2224 "x" ,
4143]
4244
4345MOMENTUM_COORDINATES = {"px" , "py" , "pz" , "pt" , "E" , "e" , "energy" , "M" , "m" , "mass" }
44-
46+ AZIMUTHAL_COORDS = { "x" , "y" , "rho" , "phi" , "px" , "py" , "pt" }
4547TEMPORAL_COORDS = {"t" , "tau" , "E" , "e" , "energy" , "M" , "m" , "mass" }
4648LONGITUDINAL_COORDS = {"z" , "theta" , "eta" , "pz" }
4749
@@ -204,6 +206,25 @@ def _get_sympy_class(coords):
204206 return vector .MomentumSympy4D if has_momentum else vector .VectorSympy4D
205207
206208
209+ def _numba_obj (combo ):
210+ """Create vector.obj inside a jitted function with the given coordinates."""
211+ kwargs = ", " .join (f"{ c } =1.0" for c in combo )
212+ local_ns = {"vector" : vector , "numba" : numba }
213+ exec (f"@numba.njit\n def f():\n return vector.obj({ kwargs } )" , local_ns )
214+ return local_ns ["f" ]
215+
216+
217+ def _will_numba_error (combo ):
218+ """Check if numba will error. Numba errors on duplicates, no valid azimuthal, or extra azimuthal coords."""
219+ if _has_duplicate (combo ):
220+ return True
221+ if not _has_valid_2_subset (combo ):
222+ return True
223+ # Numba errors if there are more than 2 azimuthal coordinates (canonical form)
224+ canonical_azimuthal = {_to_canonical (c ) for c in combo if c in AZIMUTHAL_COORDS }
225+ return len (canonical_azimuthal ) > 2
226+
227+
207228@pytest .mark .parametrize (
208229 "combo" ,
209230 ALL_2_COMBINATIONS ,
@@ -214,14 +235,18 @@ def test_2_combinations(combo):
214235 is_momentum = _is_momentum (combo )
215236 error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
216237
238+ numba_error_pattern = "duplicate coordinates|unrecognized combination"
239+
217240 if is_valid :
218241 v_obj = vector .obj (** dict .fromkeys (combo , 1.0 ))
242+ v_numba = _numba_obj (combo )()
219243 v_numpy = vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
220244 v_awkward = vector .Array (ak .Array ({c : [1.0 , 2.0 ] for c in combo }))
221245 v_zip = vector .zip ({c : np .array ([1.0 , 2.0 ]) for c in combo })
222246 v_sympy = _get_sympy_class (combo )(** {c : sympy .Symbol (c ) for c in combo })
223247
224248 assert isinstance (v_obj , Momentum ) == is_momentum
249+ assert isinstance (v_numba , Momentum ) == is_momentum
225250 assert isinstance (v_numpy , Momentum ) == is_momentum
226251 assert isinstance (v_awkward , Momentum ) == is_momentum
227252 assert isinstance (v_zip , Momentum ) == is_momentum
@@ -230,6 +255,9 @@ def test_2_combinations(combo):
230255 with pytest .raises (TypeError , match = error_pattern ):
231256 vector .obj (** dict .fromkeys (combo , 1.0 ))
232257
258+ with pytest .raises (numba .TypingError , match = numba_error_pattern ):
259+ _numba_obj (combo )()
260+
233261 with pytest .raises (TypeError , match = error_pattern ):
234262 vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
235263
@@ -253,26 +281,37 @@ def test_3_combinations(combo):
253281 has_valid_2 = _has_valid_2_subset (combo )
254282 is_momentum = _is_momentum (combo )
255283 error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
284+ numba_error_pattern = "duplicate coordinates|unrecognized combination"
256285
257286 if is_valid :
258287 v_obj = vector .obj (** dict .fromkeys (combo , 1.0 ))
288+ v_numba = _numba_obj (combo )()
259289 v_numpy = vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
260290 v_awkward = vector .Array (ak .Array ({c : [1.0 , 2.0 ] for c in combo }))
261291 v_zip = vector .zip ({c : np .array ([1.0 , 2.0 ]) for c in combo })
262292 v_sympy = _get_sympy_class (combo )(** {c : sympy .Symbol (c ) for c in combo })
263293
264294 assert isinstance (v_obj , Momentum ) == is_momentum
295+ assert isinstance (v_numba , Momentum ) == is_momentum
265296 assert isinstance (v_numpy , Momentum ) == is_momentum
266297 assert isinstance (v_awkward , Momentum ) == is_momentum
267298 assert isinstance (v_zip , Momentum ) == is_momentum
268299 assert isinstance (v_sympy , Momentum ) == is_momentum
269300 else :
301+ # obj and sympy are strict - always error for invalid combos
270302 with pytest .raises (TypeError , match = error_pattern ):
271303 vector .obj (** dict .fromkeys (combo , 1.0 ))
272304
273305 with pytest .raises (TypeError , match = error_pattern ):
274306 _get_sympy_class (combo )(** {c : sympy .Symbol (c ) for c in combo })
275307
308+ # numba is permissive like numpy/awkward/zip
309+ if _will_numba_error (combo ):
310+ with pytest .raises (numba .TypingError , match = numba_error_pattern ):
311+ _numba_obj (combo )()
312+ else :
313+ _numba_obj (combo )()
314+
276315 if has_valid_2 and not _will_error_for_non_obj (combo ):
277316 # numpy/awkward/zip create a 2D vector with extra fields
278317 v_numpy = vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
@@ -304,15 +343,18 @@ def test_4_combinations(combo):
304343 has_valid_2 = _has_valid_2_subset (combo )
305344 is_momentum = _is_momentum (combo )
306345 error_pattern = "duplicate coordinates|unrecognized combination|must have a structured dtype|specify"
346+ numba_error_pattern = "duplicate coordinates|unrecognized combination"
307347
308348 if is_valid :
309349 v_obj = vector .obj (** dict .fromkeys (combo , 1.0 ))
350+ v_numba = _numba_obj (combo )()
310351 v_numpy = vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
311352 v_awkward = vector .Array (ak .Array ({c : [1.0 , 2.0 ] for c in combo }))
312353 v_zip = vector .zip ({c : np .array ([1.0 , 2.0 ]) for c in combo })
313354 v_sympy = _get_sympy_class (combo )(** {c : sympy .Symbol (c ) for c in combo })
314355
315356 assert isinstance (v_obj , Momentum ) == is_momentum
357+ assert isinstance (v_numba , Momentum ) == is_momentum
316358 assert isinstance (v_numpy , Momentum ) == is_momentum
317359 assert isinstance (v_awkward , Momentum ) == is_momentum
318360 assert isinstance (v_zip , Momentum ) == is_momentum
@@ -325,6 +367,13 @@ def test_4_combinations(combo):
325367 with pytest .raises (TypeError , match = error_pattern ):
326368 _get_sympy_class (combo )(** {c : sympy .Symbol (c ) for c in combo })
327369
370+ # numba is permissive like numpy/awkward/zip
371+ if _will_numba_error (combo ):
372+ with pytest .raises (numba .TypingError , match = numba_error_pattern ):
373+ _numba_obj (combo )()
374+ else :
375+ _numba_obj (combo )()
376+
328377 if has_valid_3 and not _will_error_for_non_obj (combo ):
329378 # numpy/awkward/zip create a 3D vector with extra fields
330379 v_numpy = vector .array ({c : np .array ([1.0 , 2.0 ]) for c in combo })
0 commit comments