Skip to content

Commit 16cf978

Browse files
committed
✨ Add support for RollingGroupby, ExpandingGroupby
1 parent 28074c1 commit 16cf978

4 files changed

Lines changed: 273 additions & 0 deletions

File tree

src/mapply/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,7 @@ def init(
9595
)
9696

9797
setattr(PandasObject, apply_name, apply)
98+
99+
from pandas.core.window.rolling import BaseWindowGroupby
100+
101+
setattr(BaseWindowGroupby, apply_name, apply)

src/mapply/_window_groupby.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# BSD 3-Clause License
2+
#
3+
# Copyright (c) 2024, ddelange, <ddelange@delange.dev>
4+
#
5+
# All rights reserved.
6+
#
7+
# Redistribution and use in source and binary forms, with or without
8+
# modification, are permitted provided that the following conditions are met:
9+
#
10+
# 1. Redistributions of source code must retain the above copyright notice, this
11+
# list of conditions and the following disclaimer.
12+
#
13+
# 2. Redistributions in binary form must reproduce the above copyright notice,
14+
# this list of conditions and the following disclaimer in the documentation
15+
# and/or other materials provided with the distribution.
16+
#
17+
# 3. Neither the name of the copyright holder nor the names of its
18+
# contributors may be used to endorse or promote products derived from
19+
# this software without specific prior written permission.
20+
#
21+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
#
32+
# SPDX-License-Identifier: BSD-3-Clause
33+
import logging
from collections.abc import Callable
from typing import Any

from mapply.parallel import multiprocessing_imap, tqdm

logger = logging.getLogger(__name__)


def run_window_groupby_apply(
    window_groupby: Any,
    func: Callable,
    *,
    n_workers: int,
    progressbar: bool,
    args: tuple[Any, ...] = (),
    **kwargs: Any,
):
    """Apply func to each group's window in parallel using multiprocessing_imap.

    Each group is sliced out of the underlying frame/series, an equivalent
    (non-grouped) rolling/expanding window is re-created on the slice with the
    same parameters, and ``window.apply(func)`` runs per group in a worker.

    Args:
        window_groupby: RollingGroupby or ExpandingGroupby instance.
        func: Function to apply to each window of each group.
        n_workers: Amount of workers (processes) to spawn.
        progressbar: Whether to display a progressbar.
        args: Additional positional arguments to pass to func.
        **kwargs: Additional keyword arguments to pass to the window's apply
            (e.g. ``raw``).

    Returns:
        Series or DataFrame from concatenating all per-group apply results.

    Raises:
        TypeError: If window_groupby is neither RollingGroupby nor
            ExpandingGroupby (e.g. an EWM groupby).
    """
    # deferred imports: keep pandas out of this module's import time
    from pandas import concat
    from pandas.core.window.expanding import ExpandingGroupby
    from pandas.core.window.rolling import RollingGroupby

    # Capture just enough of the window spec to rebuild an equivalent window
    # on a plain (non-grouped) slice of the data.
    if isinstance(window_groupby, ExpandingGroupby):
        window_kwargs = {
            "min_periods": window_groupby.min_periods,
        }
        window_method = "expanding"
    elif isinstance(window_groupby, RollingGroupby):
        window_kwargs = {
            "window": window_groupby.window,
            "min_periods": window_groupby.min_periods,
            "center": window_groupby.center,
            "on": window_groupby.on,
            "closed": window_groupby.closed,
        }
        window_method = "rolling"
    else:
        # anything else (e.g. ExponentialMovingWindowGroupby) is unsupported
        msg = f"Unsupported window groupby type: {type(window_groupby).__name__}"
        raise TypeError(msg)

    # NOTE(review): _grouper and _as_index are private pandas attributes and
    # may shift between pandas versions — confirm against the supported range.
    grouper = window_groupby._grouper  # noqa: SLF001
    indices = grouper.indices  # group key -> positional row indices
    result_index = grouper.result_index  # one entry per group key
    obj = window_groupby.obj  # the underlying DataFrame/Series
    as_index = window_groupby._as_index  # noqa: SLF001
    groupby_names = grouper.names  # names of the grouping level(s)

    # lazy generator: yield (key, group_slice) without materializing all groups
    def _group_iter():
        for key in result_index:
            yield key, obj.iloc[indices[key]]

    def _process_group(key_and_data):
        # runs in a worker: rebuild the window on the slice and apply func
        key, group_data = key_and_data
        window_obj = getattr(group_data, window_method)(**window_kwargs)
        result = window_obj.apply(func, args=args, **kwargs)
        return key, result

    # generator with length defined (for progressbar)
    groups = tqdm(_group_iter(), disable=True, total=len(result_index))
    processed = multiprocessing_imap(
        _process_group,
        groups,
        n_workers=n_workers,
        progressbar=progressbar,
    )

    # consume lazily from the multiprocessing_imap generator
    keys = []
    parts = []
    for key, part in processed:
        keys.append(key)
        parts.append(part)

    if not parts:
        # delegate to native pandas for the empty case to preserve index dtypes
        return window_groupby.apply(func, args=args, **kwargs)

    # prepend the group key level(s) to each part's original index levels
    result = concat(parts, keys=keys, names=groupby_names + list(obj.index.names))

    if not as_index:
        # mirror groupby(..., as_index=False): move key level(s) into columns
        result = result.reset_index(level=list(range(len(groupby_names))))

    return result

src/mapply/mapply.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from typing import Any
5151

5252
from mapply._groupby import run_groupwise_apply
53+
from mapply._window_groupby import run_window_groupby_apply
5354
from mapply.parallel import N_CORES, multiprocessing_imap
5455

5556
DEFAULT_CHUNK_SIZE = 100
@@ -120,6 +121,17 @@ def mapply( # noqa: PLR0913
120121
from numpy import arange, array_split
121122
from pandas import Series, concat
122123
from pandas.core.groupby import GroupBy
124+
from pandas.core.window.rolling import BaseWindowGroupby
125+
126+
if isinstance(df_or_series, BaseWindowGroupby):
127+
return run_window_groupby_apply(
128+
df_or_series,
129+
func,
130+
n_workers=n_workers,
131+
progressbar=progressbar,
132+
args=args,
133+
**kwargs,
134+
)
123135

124136
if isinstance(df_or_series, GroupBy):
125137
return run_groupwise_apply(

tests/test_mapply.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,143 @@ def fn(x):
154154
series = pd.Series({"a": list(range(100))})
155155

156156
assert isinstance(series.mapply(sum).iloc[0], np.int64)
157+
158+
159+
def test_rolling_groupby_mapply():
    """Assert RollingGroupby behaviour is equivalent."""
    mapply.init(progressbar=False, chunk_size=1)

    np.random.seed(42)  # noqa: NPY002
    frame = pd.DataFrame(
        {
            "A": np.random.randint(0, 100, 200),  # noqa: NPY002
            "B": np.random.randint(0, 100, 200),  # noqa: NPY002
            "group": [i // 100 for i in range(200)],
        },
    )

    # plain rolling window with a custom func
    pd.testing.assert_frame_equal(
        frame.groupby("group").rolling(3).apply(lambda x: x.sum()),
        frame.groupby("group").rolling(3).mapply(lambda x: x.sum()),
    )

    # min_periods
    pd.testing.assert_frame_equal(
        frame.groupby("group").rolling(5, min_periods=2).apply(lambda x: x.mean()),
        frame.groupby("group").rolling(5, min_periods=2).mapply(lambda x: x.mean()),
    )

    # centered window
    pd.testing.assert_frame_equal(
        frame.groupby("group").rolling(3, center=True).apply(lambda x: x.max()),
        frame.groupby("group").rolling(3, center=True).mapply(lambda x: x.max()),
    )

    # single-column selection yields a Series
    pd.testing.assert_series_equal(
        frame.groupby("group")["A"].rolling(3).apply(lambda x: x.sum()),
        frame.groupby("group")["A"].rolling(3).mapply(lambda x: x.sum()),
    )

    # group keys as columns instead of index levels
    pd.testing.assert_frame_equal(
        frame.groupby("group", as_index=False).rolling(3).apply(lambda x: x.sum()),
        frame.groupby("group", as_index=False).rolling(3).mapply(lambda x: x.sum()),
    )

    # grouping on two keys
    frame["group2"] = [0, 1] * 100
    pd.testing.assert_frame_equal(
        frame.groupby(["group", "group2"]).rolling(3).apply(lambda x: x.sum()),
        frame.groupby(["group", "group2"]).rolling(3).mapply(lambda x: x.sum()),
    )

    # time-based window keyed on the 'on' column
    ts_frame = pd.DataFrame(
        {
            "A": np.random.randint(0, 100, 200),  # noqa: NPY002
            "dt": pd.date_range("2020-01-01", periods=200, freq="D"),
            "group": [i // 100 for i in range(200)],
        },
    )
    pd.testing.assert_frame_equal(
        ts_frame.groupby("group").rolling("3D", on="dt").apply(lambda x: x.sum()),
        ts_frame.groupby("group").rolling("3D", on="dt").mapply(lambda x: x.sum()),
    )

    # zero rows: delegates to native pandas
    pd.testing.assert_frame_equal(
        frame.iloc[:0].groupby("group").rolling(3).apply(lambda x: x.sum()),
        frame.iloc[:0].groupby("group").rolling(3).mapply(lambda x: x.sum()),
    )

    # single worker: no pool is spawned
    mapply.init(progressbar=False, chunk_size=1, n_workers=1)
    pd.testing.assert_frame_equal(
        frame.groupby("group").rolling(3).apply(lambda x: x.sum()),
        frame.groupby("group").rolling(3).mapply(lambda x: x.sum()),
    )
234+
235+
236+
def test_expanding_groupby_mapply():
    """Assert ExpandingGroupby behaviour is equivalent."""
    mapply.init(progressbar=False, chunk_size=1)

    np.random.seed(42)  # noqa: NPY002
    frame = pd.DataFrame(
        {
            "A": np.random.randint(0, 100, 200),  # noqa: NPY002
            "B": np.random.randint(0, 100, 200),  # noqa: NPY002
            "group": [i // 100 for i in range(200)],
        },
    )

    # plain expanding window with a custom func
    pd.testing.assert_frame_equal(
        frame.groupby("group").expanding().apply(lambda x: x.sum()),
        frame.groupby("group").expanding().mapply(lambda x: x.sum()),
    )

    # min_periods
    pd.testing.assert_frame_equal(
        frame.groupby("group").expanding(min_periods=3).apply(lambda x: x.mean()),
        frame.groupby("group").expanding(min_periods=3).mapply(lambda x: x.mean()),
    )

    # single-column selection yields a Series
    pd.testing.assert_series_equal(
        frame.groupby("group")["A"].expanding().apply(lambda x: x.sum()),
        frame.groupby("group")["A"].expanding().mapply(lambda x: x.sum()),
    )

    # group keys as columns instead of index levels
    pd.testing.assert_frame_equal(
        frame.groupby("group", as_index=False).expanding().apply(lambda x: x.sum()),
        frame.groupby("group", as_index=False).expanding().mapply(lambda x: x.sum()),
    )

    # grouping on two keys
    frame["group2"] = [0, 1] * 100
    pd.testing.assert_frame_equal(
        frame.groupby(["group", "group2"]).expanding().apply(lambda x: x.sum()),
        frame.groupby(["group", "group2"]).expanding().mapply(lambda x: x.sum()),
    )

    # zero rows: delegates to native pandas
    pd.testing.assert_frame_equal(
        frame.iloc[:0].groupby("group").expanding().apply(lambda x: x.sum()),
        frame.iloc[:0].groupby("group").expanding().mapply(lambda x: x.sum()),
    )

    # single worker: no pool is spawned
    mapply.init(progressbar=False, chunk_size=1, n_workers=1)
    pd.testing.assert_frame_equal(
        frame.groupby("group").expanding().apply(lambda x: x.sum()),
        frame.groupby("group").expanding().mapply(lambda x: x.sum()),
    )

    # window groupby types other than rolling/expanding (e.g. EWM) are rejected
    mapply.init(progressbar=False, chunk_size=1)
    with pytest.raises(TypeError, match="Unsupported window groupby type"):
        frame.groupby("group").ewm(span=3).mapply(lambda x: x.sum())

0 commit comments

Comments
 (0)