Skip to content

Commit

Permalink
perf: reduce python overhead for awkward backend (#554)
Browse files Browse the repository at this point in the history
* awkward backend: reduce python overhead by applying multiple operations in one single broadcasting traversal

* add a fast path for noops, otherwise ak.transform does a non-negligible overhead...
  • Loading branch information
pfackeldey authored Jan 15, 2025
1 parent a61d2ef commit 78cd02f
Show file tree
Hide file tree
Showing 88 changed files with 402 additions and 118 deletions.
3 changes: 2 additions & 1 deletion src/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _import_awkward() -> None:
if not typing.TYPE_CHECKING:
VectorAwkward = None
else:
from vector.backends.awkward import VectorAwkward
from vector.backends.awkward import VectorAwkward, awkward_transform

try:
import sympy # type: ignore[import-untyped]
Expand Down Expand Up @@ -143,6 +143,7 @@ def _import_awkward() -> None:
"arr",
"array",
"awk",
"awkward_transform",
"dim",
"obj",
"register_awkward",
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/Et.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/Et2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/Mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/Mt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostX_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
beta,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostX_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
gamma,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostY_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
beta,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostY_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
gamma,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostZ_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
beta,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/boostZ_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
gamma,
*v.azimuthal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/boost_beta3.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/boost_p4.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/deltaRapidityPhi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ def dispatch(
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/deltaRapidityPhi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ def dispatch(
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/is_lightlike.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
tolerance,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/is_spacelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
tolerance,
*v.azimuthal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/is_timelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
tolerance,
*v.azimuthal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/isclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ def dispatch(
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
rtol,
atol,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/not_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/rapidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def dispatch(factor: typing.Any, v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
factor,
*v.azimuthal.elements,
Expand Down
5 changes: 3 additions & 2 deletions src/vector/_compute/lorentz/subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
),
)
with numpy.errstate(all="ignore"):
return _handler_of(v1, v2)._wrap_result(
handler = _handler_of(v1, v2)
return handler._wrap_result(
_flavor_of(v1, v2),
function(
handler._wrap_dispatched_function(function)(
_lib_of(v1, v2),
*v1.azimuthal.elements,
*v1.longitudinal.elements,
Expand Down
20 changes: 19 additions & 1 deletion src/vector/_compute/lorentz/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def xy_z_t(lib, x, y, z, t):
return t


xy_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_z_tau(lib, x, y, z, tau):
return lib.sqrt(t2.xy_z_tau(lib, x, y, z, tau))

Expand All @@ -45,6 +48,9 @@ def xy_theta_t(lib, x, y, theta, t):
return t


xy_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_theta_tau(lib, x, y, theta, tau):
return lib.sqrt(t2.xy_theta_tau(lib, x, y, theta, tau))

Expand All @@ -53,6 +59,9 @@ def xy_eta_t(lib, x, y, eta, t):
return t


xy_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_eta_tau(lib, x, y, eta, tau):
return lib.sqrt(t2.xy_eta_tau(lib, x, y, eta, tau))

Expand All @@ -61,6 +70,9 @@ def rhophi_z_t(lib, rho, phi, z, t):
return t


rhophi_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_z_tau(lib, rho, phi, z, tau):
return lib.sqrt(t2.rhophi_z_tau(lib, rho, phi, z, tau))

Expand All @@ -69,6 +81,9 @@ def rhophi_theta_t(lib, rho, phi, theta, t):
return t


rhophi_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_theta_tau(lib, rho, phi, theta, tau):
return lib.sqrt(t2.rhophi_theta_tau(lib, rho, phi, theta, tau))

Expand All @@ -77,6 +92,9 @@ def rhophi_eta_t(lib, rho, phi, eta, t):
return t


rhophi_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_eta_tau(lib, rho, phi, eta, tau):
return lib.sqrt(t2.rhophi_eta_tau(lib, rho, phi, eta, tau))

Expand Down Expand Up @@ -110,7 +128,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/lorentz/t2.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def dispatch(v: typing.Any) -> typing.Any:
with numpy.errstate(all="ignore"):
return v._wrap_result(
_flavor_of(v),
function(
v._wrap_dispatched_function(function)(
v.lib,
*v.azimuthal.elements,
*v.longitudinal.elements,
Expand Down
Loading

0 comments on commit 78cd02f

Please sign in to comment.