Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4496,3 +4496,165 @@ def _flavor_of(*objects: VectorProtocol) -> type[VectorProtocol]:
return handler.MomentumClass
else:
return handler.GenericClass


def _validate_coordinates(fieldnames: tuple[str, ...]) -> None:
"""
Validate coordinate field names for constructing vectors.

This follows the same logic as _check_names in awkward_constructors to ensure
consistent validation across backends.

Raises TypeError if duplicate or conflicting coordinates are detected.
"""
complaint1 = "duplicate coordinates (through momentum-aliases): " + ", ".join(
repr(x) for x in fieldnames
)
complaint2 = (
"unrecognized combination of coordinates, allowed combinations are:\n\n"
" (2D) x= y=\n"
" (2D) rho= phi=\n"
" (3D) x= y= z=\n"
" (3D) x= y= theta=\n"
" (3D) x= y= eta=\n"
" (3D) rho= phi= z=\n"
" (3D) rho= phi= theta=\n"
" (3D) rho= phi= eta=\n"
" (4D) x= y= z= t=\n"
" (4D) x= y= z= tau=\n"
" (4D) x= y= theta= t=\n"
" (4D) x= y= theta= tau=\n"
" (4D) x= y= eta= t=\n"
" (4D) x= y= eta= tau=\n"
" (4D) rho= phi= z= t=\n"
" (4D) rho= phi= z= tau=\n"
" (4D) rho= phi= theta= t=\n"
" (4D) rho= phi= theta= tau=\n"
" (4D) rho= phi= eta= t=\n"
" (4D) rho= phi= eta= tau="
)

is_momentum = False
dimension = 0
fieldnames_copy = list(fieldnames)

# 2D azimuthal coordinates
if "x" in fieldnames_copy and "y" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("y")
if "rho" in fieldnames_copy and "phi" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("rho")
fieldnames_copy.remove("phi")
if "x" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("py")
if "px" in fieldnames_copy and "y" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("y")
if "px" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("py")
if "pt" in fieldnames_copy and "phi" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("pt")
fieldnames_copy.remove("phi")

# 3D longitudinal coordinates
if "z" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("z")
if "theta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("theta")
if "eta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("eta")
if "pz" in fieldnames_copy:
is_momentum = True
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("pz")

# 4D temporal coordinates
if "t" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("t")
if "tau" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("tau")
if "E" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("E")
if "e" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("e")
if "energy" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("energy")
if "M" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("M")
if "m" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("m")
if "mass" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("mass")

# Check if any remaining fieldnames would conflict with already-processed coordinates
# when mapped to generic names (e.g., pt was processed, rho shouldn't remain)
if fieldnames_copy:
# Map all original fieldnames to generic names to detect conflicts
generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
if len(generic_names) != len(set(generic_names)):
raise TypeError(complaint1 if is_momentum else complaint2)
11 changes: 11 additions & 0 deletions src/vector/backends/awkward_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy

from vector._methods import _repr_momentum_to_generic


def _recname(is_momentum: bool, dimension: int) -> str:
name = "Momentum" if is_momentum else "Vector"
Expand Down Expand Up @@ -49,6 +51,7 @@ def _check_names(
dimension = 0
names = []
columns = []
fieldnames_orig = list(fieldnames)

if "x" in fieldnames and "y" in fieldnames:
if dimension != 0:
Expand Down Expand Up @@ -199,6 +202,14 @@ def _check_names(
if dimension == 0:
raise TypeError(complaint1 if is_momentum else complaint2)

# Check if any remaining fieldnames would conflict with already-processed coordinates
# when mapped to generic names (e.g., pt was processed, rho shouldn't remain)
if fieldnames:
# Map all original fieldnames to generic names to detect conflicts
generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames_orig]
if len(generic_names) != len(set(generic_names)):
raise TypeError(complaint1 if is_momentum else complaint2)

for name in fieldnames:
names.append(name)
columns.append(projectable[name])
Expand Down
15 changes: 15 additions & 0 deletions src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
_ltype,
_repr_momentum_to_generic,
_ttype,
_validate_coordinates,
)
from vector._typeutils import BoolCollection, FloatArray, ScalarCollection

Expand Down Expand Up @@ -1190,6 +1191,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -1363,6 +1366,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down Expand Up @@ -1431,6 +1436,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -1664,6 +1671,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down Expand Up @@ -1745,6 +1754,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -2047,6 +2058,8 @@ def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

_validate_coordinates(self.dtype.names or ())

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down Expand Up @@ -2159,6 +2172,8 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:

is_momentum = any(x in _repr_momentum_to_generic for x in names)

_validate_coordinates(names)

if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names):
cls = MomentumNumpy4D if is_momentum else VectorNumpy4D
elif any(x in ("z", "pz", "theta", "eta") for x in names):
Expand Down
4 changes: 2 additions & 2 deletions src/vector/backends/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3230,7 +3230,7 @@ def obj(**coordinates: float) -> VectorObject:
if "E" in coordinates:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also have the and "t" not in generic_coordinates condition?

Copy link
Author

@ikrommyd ikrommyd Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this unnecessary? This is the first if condition where generic_coordinates might be populated with t. So and "t" not in generic_coordinates is a useless check no? Same goes for tau in your other comment below.

Copy link
Author

@ikrommyd ikrommyd Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can add it for "visual OCD reasons" but it's a check that's always going to be false right?

is_momentum = True
generic_coordinates["t"] = coordinates.pop("E")
if "e" in coordinates:
if "e" in coordinates and "t" not in generic_coordinates:
is_momentum = True
generic_coordinates["t"] = coordinates.pop("e")
if "energy" in coordinates and "t" not in generic_coordinates:
Expand All @@ -3239,7 +3239,7 @@ def obj(**coordinates: float) -> VectorObject:
if "M" in coordinates:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, for tau

is_momentum = True
generic_coordinates["tau"] = coordinates.pop("M")
if "m" in coordinates:
if "m" in coordinates and "tau" not in generic_coordinates:
is_momentum = True
generic_coordinates["tau"] = coordinates.pop("m")
if "mass" in coordinates and "tau" not in generic_coordinates:
Expand Down
7 changes: 7 additions & 0 deletions src/vector/backends/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_repr_generic_to_momentum,
_repr_momentum_to_generic,
_ttype,
_validate_coordinates,
)


Expand Down Expand Up @@ -744,6 +745,8 @@ class VectorSympy2D(VectorSympy, Planar, Vector2D):
azimuthal: AzimuthalSympy

def __init__(self, azimuthal: AzimuthalSympy | None = None, **kwargs: sympy.Symbol):
_validate_coordinates(tuple(kwargs))

for k, v in kwargs.copy().items():
kwargs.pop(k)
kwargs[_repr_momentum_to_generic.get(k, k)] = v
Expand Down Expand Up @@ -945,6 +948,8 @@ def __init__(
longitudinal: LongitudinalSympy | None = None,
**kwargs: sympy.Symbol,
):
_validate_coordinates(tuple(kwargs))

for k, v in kwargs.copy().items():
kwargs.pop(k)
kwargs[_repr_momentum_to_generic.get(k, k)] = v
Expand Down Expand Up @@ -1219,6 +1224,8 @@ def __init__(
temporal: TemporalSympy | None = None,
**kwargs: sympy.Symbol,
):
_validate_coordinates(tuple(kwargs))

for k, v in kwargs.copy().items():
kwargs.pop(k)
kwargs[_repr_momentum_to_generic.get(k, k)] = v
Expand Down
Loading