Skip to content

Commit 8943c97

Browse files
ENH: Allow JIT compilation with an internal API (#61032)
1 parent ddd0aa8 commit 8943c97

File tree

7 files changed

+341
-51
lines changed

7 files changed

+341
-51
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Other enhancements
6767
- :class:`Rolling` and :class:`Expanding` now support aggregations ``first`` and ``last`` (:issue:`33155`)
6868
- :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
6969
- :meth:`.DataFrameGroupBy.transform`, :meth:`.SeriesGroupBy.transform`, :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, :meth:`.SeriesGroupBy.apply`, :meth:`.DataFrameGroupBy.apply` now support ``kurt`` (:issue:`40139`)
70+
- :meth:`DataFrame.apply` supports using third-party execution engines like the Bodo.ai JIT compiler (:issue:`60668`)
7071
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
7172
- :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`)
7273
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)

pandas/api/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""public toolkit API"""
22

33
from pandas.api import (
4+
executors,
45
extensions,
56
indexers,
67
interchange,
@@ -9,6 +10,7 @@
910
)
1011

1112
__all__ = [
13+
"executors",
1214
"extensions",
1315
"indexers",
1416
"interchange",

pandas/api/executors/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
Public API for function executor engines to be used with ``map`` and ``apply``.
3+
"""
4+
5+
from pandas.core.apply import BaseExecutionEngine
6+
7+
__all__ = ["BaseExecutionEngine"]

pandas/core/apply.py

+104
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,110 @@
7474
ResType = dict[int, Any]
7575

7676

77+
class BaseExecutionEngine(abc.ABC):
78+
"""
79+
Base class for execution engines for map and apply methods.
80+
81+
An execution engine receives all the parameters of a call to
82+
``apply`` or ``map``, such as the data container, the function,
83+
etc. and takes care of running the execution.
84+
85+
Supporting different engines allows functions to be JIT compiled,
86+
run in parallel, and more, in addition to the default executor which
87+
simply runs the code with the Python interpreter and pandas.
88+
"""
89+
90+
@staticmethod
91+
@abc.abstractmethod
92+
def map(
93+
data: Series | DataFrame | np.ndarray,
94+
func: AggFuncType,
95+
args: tuple,
96+
kwargs: dict[str, Any],
97+
decorator: Callable | None,
98+
skip_na: bool,
99+
):
100+
"""
101+
Executor method to run functions elementwise.
102+
103+
In general, pandas uses ``map`` for running functions elementwise,
104+
but ``Series.apply`` with the default ``by_row='compat'`` will also
105+
call this executor function.
106+
107+
Parameters
108+
----------
109+
data : Series, DataFrame or NumPy ndarray
110+
The object to use for the data. Some methods implement a ``raw``
111+
parameter which will convert the original pandas object to a
112+
NumPy array, which will then be passed here to the executor.
113+
func : function or NumPy ufunc
114+
The function to execute.
115+
args : tuple
116+
Positional arguments to be passed to ``func``.
117+
kwargs : dict
118+
Keyword arguments to be passed to ``func``.
119+
decorator : function, optional
120+
For JIT compilers and other engines that need to decorate the
121+
function ``func``, this is the decorator to use. While the
122+
executor may already know which is the decorator to use, this
123+
is useful as for a single executor the user can specify for
124+
example ``numba.jit`` or ``numba.njit(nogil=True)``, and this
125+
decorator parameter will contain the exact decorator from the
126+
executor the user wants to use.
127+
skip_na : bool
128+
Whether the function should be called for missing values or not.
129+
This is specified by the pandas user as ``map(na_action=None)``
130+
or ``map(na_action='ignore')``.
131+
"""
132+
133+
@staticmethod
134+
@abc.abstractmethod
135+
def apply(
136+
data: Series | DataFrame | np.ndarray,
137+
func: AggFuncType,
138+
args: tuple,
139+
kwargs: dict[str, Any],
140+
decorator: Callable,
141+
axis: Axis,
142+
):
143+
"""
144+
Executor method to run functions by an axis.
145+
146+
While we can see ``map`` as executing the function for each cell
147+
in a ``DataFrame`` (or ``Series``), ``apply`` will execute the
148+
function for each column (or row).
149+
150+
Parameters
151+
----------
152+
data : Series, DataFrame or NumPy ndarray
153+
The object to use for the data. Some methods implement a ``raw``
154+
parameter which will convert the original pandas object to a
155+
NumPy array, which will then be passed here to the executor.
156+
func : function or NumPy ufunc
157+
The function to execute.
158+
args : tuple
159+
Positional arguments to be passed to ``func``.
160+
kwargs : dict
161+
Keyword arguments to be passed to ``func``.
162+
decorator : function, optional
163+
For JIT compilers and other engines that need to decorate the
164+
function ``func``, this is the decorator to use. While the
165+
executor may already know which is the decorator to use, this
166+
is useful as for a single executor the user can specify for
167+
example ``numba.jit`` or ``numba.njit(nogil=True)``, and this
168+
decorator parameter will contain the exact decorator from the
169+
executor the user wants to use.
170+
axis : {0 or 'index', 1 or 'columns'}
171+
0 or 'index' should execute the function passing each column as
172+
parameter. 1 or 'columns' should execute the function passing
173+
each row as parameter. The default executor engine passes rows
174+
as pandas ``Series``. Other executor engines should probably
175+
expect functions to be implemented this way for compatibility.
176+
But passing rows as other data structures is technically possible
177+
as long as the function ``func`` is implemented accordingly.
178+
"""
179+
180+
77181
def frame_apply(
78182
obj: DataFrame,
79183
func: AggFuncType,

pandas/core/frame.py

+119-36
Original file line numberDiff line numberDiff line change
@@ -10275,7 +10275,7 @@ def apply(
1027510275
result_type: Literal["expand", "reduce", "broadcast"] | None = None,
1027610276
args=(),
1027710277
by_row: Literal[False, "compat"] = "compat",
10278-
engine: Literal["python", "numba"] = "python",
10278+
engine: Callable | None | Literal["python", "numba"] = None,
1027910279
engine_kwargs: dict[str, bool] | None = None,
1028010280
**kwargs,
1028110281
):
@@ -10339,35 +10339,32 @@ def apply(
1033910339
1034010340
.. versionadded:: 2.1.0
1034110341
10342-
engine : {'python', 'numba'}, default 'python'
10343-
Choose between the python (default) engine or the numba engine in apply.
10342+
engine : decorator or {'python', 'numba'}, optional
10343+
Choose the execution engine to use. If not provided the function
10344+
will be executed by the regular Python interpreter.
1034410345
10345-
The numba engine will attempt to JIT compile the passed function,
10346-
which may result in speedups for large DataFrames.
10347-
It also supports the following engine_kwargs :
10346+
Other options include JIT compilers such as Numba and Bodo, which in some
10347+
cases can speed up the execution. To use an executor you can provide
10348+
the decorators ``numba.jit``, ``numba.njit`` or ``bodo.jit``. You can
10349+
also provide the decorator with parameters, like ``numba.jit(nogil=True)``.
1034810350
10349-
- nopython (compile the function in nopython mode)
10350-
- nogil (release the GIL inside the JIT compiled function)
10351-
- parallel (try to apply the function in parallel over the DataFrame)
10351+
Not all functions can be executed with all execution engines. In general,
10352+
JIT compilers will require type stability in the function (no variable
10353+
should change data type during the execution). And not all pandas and
10354+
NumPy APIs are supported. Check the engine documentation [1]_ and [2]_
10355+
for limitations.
1035210356
10353-
Note: Due to limitations within numba/how pandas interfaces with numba,
10354-
you should only use this if raw=True
10355-
10356-
Note: The numba compiler only supports a subset of
10357-
valid Python/numpy operations.
10357+
.. warning::
1035810358
10359-
Please read more about the `supported python features
10360-
<https://numba.pydata.org/numba-doc/dev/reference/pysupported.html>`_
10361-
and `supported numpy features
10362-
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
10363-
in numba to learn what you can or cannot use in the passed function.
10359+
String parameters will stop being supported in a future pandas version.
1036410360
1036510361
.. versionadded:: 2.2.0
1036610362
1036710363
engine_kwargs : dict
1036810364
Pass keyword arguments to the engine.
1036910365
This is currently only used by the numba engine,
1037010366
see the documentation for the engine argument for more information.
10367+
1037110368
**kwargs
1037210369
Additional keyword arguments to pass as keywords arguments to
1037310370
`func`.
@@ -10390,6 +10387,13 @@ def apply(
1039010387
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
1039110388
for more details.
1039210389
10390+
References
10391+
----------
10392+
.. [1] `Numba documentation
10393+
<https://numba.readthedocs.io/en/stable/index.html>`_
10394+
.. [2] `Bodo documentation
10395+
<https://docs.bodo.ai/latest/>`_
10396+
1039310397
Examples
1039410398
--------
1039510399
>>> df = pd.DataFrame([[4, 9]] * 3, columns=["A", "B"])
@@ -10458,22 +10462,99 @@ def apply(
1045810462
0 1 2
1045910463
1 1 2
1046010464
2 1 2
10465+
10466+
Advanced users can speed up their code by using a Just-in-time (JIT) compiler
10467+
with ``apply``. The main JIT compilers available for pandas are Numba and Bodo.
10468+
In general, JIT compilation is only possible when the function passed to
10469+
``apply`` has type stability (variables in the function do not change their
10470+
type during the execution).
10471+
10472+
>>> import bodo
10473+
>>> df.apply(lambda x: x.A + x.B, axis=1, engine=bodo.jit)
10474+
10475+
Note that JIT compilation is only recommended for functions that take a
10476+
significant amount of time to run. Fast functions are unlikely to run faster
10477+
with JIT compilation.
1046110478
"""
10462-
from pandas.core.apply import frame_apply
10479+
if engine is None or isinstance(engine, str):
10480+
from pandas.core.apply import frame_apply
1046310481

10464-
op = frame_apply(
10465-
self,
10466-
func=func,
10467-
axis=axis,
10468-
raw=raw,
10469-
result_type=result_type,
10470-
by_row=by_row,
10471-
engine=engine,
10472-
engine_kwargs=engine_kwargs,
10473-
args=args,
10474-
kwargs=kwargs,
10475-
)
10476-
return op.apply().__finalize__(self, method="apply")
10482+
if engine is None:
10483+
engine = "python"
10484+
10485+
if engine not in ["python", "numba"]:
10486+
raise ValueError(f"Unknown engine '{engine}'")
10487+
10488+
op = frame_apply(
10489+
self,
10490+
func=func,
10491+
axis=axis,
10492+
raw=raw,
10493+
result_type=result_type,
10494+
by_row=by_row,
10495+
engine=engine,
10496+
engine_kwargs=engine_kwargs,
10497+
args=args,
10498+
kwargs=kwargs,
10499+
)
10500+
return op.apply().__finalize__(self, method="apply")
10501+
elif hasattr(engine, "__pandas_udf__"):
10502+
if result_type is not None:
10503+
raise NotImplementedError(
10504+
f"{result_type=} only implemented for the default engine"
10505+
)
10506+
10507+
agg_axis = self._get_agg_axis(self._get_axis_number(axis))
10508+
10509+
# one axis is empty
10510+
if not all(self.shape):
10511+
func = cast(Callable, func)
10512+
try:
10513+
if axis == 0:
10514+
r = func(Series([], dtype=np.float64), *args, **kwargs)
10515+
else:
10516+
r = func(
10517+
Series(index=self.columns, dtype=np.float64),
10518+
*args,
10519+
**kwargs,
10520+
)
10521+
except Exception:
10522+
pass
10523+
else:
10524+
if not isinstance(r, Series):
10525+
if len(agg_axis):
10526+
r = func(Series([], dtype=np.float64), *args, **kwargs)
10527+
else:
10528+
r = np.nan
10529+
10530+
return self._constructor_sliced(r, index=agg_axis)
10531+
return self.copy()
10532+
10533+
data: DataFrame | np.ndarray = self
10534+
if raw:
10535+
# This will upcast the whole DataFrame to the same type,
10536+
# and likely result in an object 2D array.
10537+
# We should probably pass a list of 1D arrays instead, at
10538+
# least for ``axis=0``
10539+
data = self.values
10540+
result = engine.__pandas_udf__.apply(
10541+
data=data,
10542+
func=func,
10543+
args=args,
10544+
kwargs=kwargs,
10545+
decorator=engine,
10546+
axis=axis,
10547+
)
10548+
if raw:
10549+
if result.ndim == 2:
10550+
return self._constructor(
10551+
result, index=self.index, columns=self.columns
10552+
)
10553+
else:
10554+
return self._constructor_sliced(result, index=agg_axis)
10555+
return result
10556+
else:
10557+
raise ValueError(f"Unknown engine {engine}")
1047710558

1047810559
def map(
1047910560
self, func: PythonFuncType, na_action: Literal["ignore"] | None = None, **kwargs
@@ -10590,9 +10671,11 @@ def _append(
1059010671

1059110672
index = Index(
1059210673
[other.name],
10593-
name=self.index.names
10594-
if isinstance(self.index, MultiIndex)
10595-
else self.index.name,
10674+
name=(
10675+
self.index.names
10676+
if isinstance(self.index, MultiIndex)
10677+
else self.index.name
10678+
),
1059610679
)
1059710680
row_df = other.to_frame().T
1059810681
# infer_objects is needed for

pandas/tests/api/test_api.py

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pandas import api
77
import pandas._testing as tm
88
from pandas.api import (
9+
executors as api_executors,
910
extensions as api_extensions,
1011
indexers as api_indexers,
1112
interchange as api_interchange,
@@ -243,6 +244,7 @@ def test_depr(self):
243244

244245
class TestApi(Base):
245246
allowed_api_dirs = [
247+
"executors",
246248
"types",
247249
"extensions",
248250
"indexers",
@@ -338,6 +340,7 @@ class TestApi(Base):
338340
"ExtensionArray",
339341
"ExtensionScalarOpsMixin",
340342
]
343+
allowed_api_executors = ["BaseExecutionEngine"]
341344

342345
def test_api(self):
343346
self.check(api, self.allowed_api_dirs)
@@ -357,6 +360,9 @@ def test_api_indexers(self):
357360
def test_api_extensions(self):
358361
self.check(api_extensions, self.allowed_api_extensions)
359362

363+
def test_api_executors(self):
364+
self.check(api_executors, self.allowed_api_executors)
365+
360366

361367
class TestErrors(Base):
362368
def test_errors(self):

0 commit comments

Comments
 (0)