Skip to content

Commit 9501650

Browse files
snitish and mroeschke authored
ENH: Support NamedAggs in kwargs in Rolling/Expanding/EWM agg method (#60549)
* ENH: Support NamedAggs in kwargs in Rolling/Expanding/EWM agg method
* Pre-commit fix
* Fix typing
* Fix typing retry
* Fix typing retry 2
* Update pandas/core/window/rolling.py
  Co-authored-by: Matthew Roeschke <[email protected]>
* Add type ignore
---------
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 069253d commit 9501650

File tree

5 files changed: +111 additions, -7 deletions

doc/source/whatsnew/v3.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ Other enhancements
5656
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
5757
- :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
5858
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
59+
- :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`)
5960
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)
6061
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
6162
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)

pandas/core/window/ewm.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -490,7 +490,7 @@ def online(
490490
klass="Series/Dataframe",
491491
axis="",
492492
)
493-
def aggregate(self, func, *args, **kwargs):
493+
def aggregate(self, func=None, *args, **kwargs):
494494
return super().aggregate(func, *args, **kwargs)
495495

496496
agg = aggregate
@@ -981,7 +981,7 @@ def reset(self) -> None:
981981
"""
982982
self._mean.reset()
983983

984-
def aggregate(self, func, *args, **kwargs):
984+
def aggregate(self, func=None, *args, **kwargs):
985985
raise NotImplementedError("aggregate is not implemented.")
986986

987987
def std(self, bias: bool = False, *args, **kwargs):

pandas/core/window/expanding.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ def _get_window_indexer(self) -> BaseIndexer:
167167
klass="Series/Dataframe",
168168
axis="",
169169
)
170-
def aggregate(self, func, *args, **kwargs):
170+
def aggregate(self, func=None, *args, **kwargs):
171171
return super().aggregate(func, *args, **kwargs)
172172

173173
agg = aggregate

pandas/core/window/rolling.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,10 @@
4444

4545
from pandas.core._numba import executor
4646
from pandas.core.algorithms import factorize
47-
from pandas.core.apply import ResamplerWindowApply
47+
from pandas.core.apply import (
48+
ResamplerWindowApply,
49+
reconstruct_func,
50+
)
4851
from pandas.core.arrays import ExtensionArray
4952
from pandas.core.base import SelectionMixin
5053
import pandas.core.common as com
@@ -646,8 +649,12 @@ def _numba_apply(
646649
out = obj._constructor(result, index=index, columns=columns)
647650
return self._resolve_output(out, obj)
648651

649-
def aggregate(self, func, *args, **kwargs):
652+
def aggregate(self, func=None, *args, **kwargs):
653+
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
650654
result = ResamplerWindowApply(self, func, args=args, kwargs=kwargs).agg()
655+
if isinstance(result, ABCDataFrame) and relabeling:
656+
result = result.iloc[:, order]
657+
result.columns = columns # type: ignore[union-attr]
651658
if result is None:
652659
return self.apply(func, raw=False, args=args, kwargs=kwargs)
653660
return result
@@ -1239,7 +1246,7 @@ def calc(x):
12391246
klass="Series/DataFrame",
12401247
axis="",
12411248
)
1242-
def aggregate(self, func, *args, **kwargs):
1249+
def aggregate(self, func=None, *args, **kwargs):
12431250
result = ResamplerWindowApply(self, func, args=args, kwargs=kwargs).agg()
12441251
if result is None:
12451252
# these must apply directly
@@ -1951,7 +1958,7 @@ def _raise_monotonic_error(self, msg: str):
19511958
klass="Series/Dataframe",
19521959
axis="",
19531960
)
1954-
def aggregate(self, func, *args, **kwargs):
1961+
def aggregate(self, func=None, *args, **kwargs):
19551962
return super().aggregate(func, *args, **kwargs)
19561963

19571964
agg = aggregate

pandas/tests/window/test_groupby.py

+96
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
DatetimeIndex,
77
Index,
88
MultiIndex,
9+
NamedAgg,
910
Series,
1011
Timestamp,
1112
date_range,
@@ -489,6 +490,36 @@ def test_groupby_rolling_subset_with_closed(self):
489490
)
490491
tm.assert_series_equal(result, expected)
491492

493+
def test_groupby_rolling_agg_namedagg(self):
494+
# GH#28333
495+
df = DataFrame(
496+
{
497+
"kind": ["cat", "dog", "cat", "dog", "cat", "dog"],
498+
"height": [9.1, 6.0, 9.5, 34.0, 12.0, 8.0],
499+
"weight": [7.9, 7.5, 9.9, 198.0, 10.0, 42.0],
500+
}
501+
)
502+
result = (
503+
df.groupby("kind")
504+
.rolling(2)
505+
.agg(
506+
total_weight=NamedAgg(column="weight", aggfunc=sum),
507+
min_height=NamedAgg(column="height", aggfunc=min),
508+
)
509+
)
510+
expected = DataFrame(
511+
{
512+
"total_weight": [np.nan, 17.8, 19.9, np.nan, 205.5, 240.0],
513+
"min_height": [np.nan, 9.1, 9.5, np.nan, 6.0, 8.0],
514+
},
515+
index=MultiIndex(
516+
[["cat", "dog"], [0, 1, 2, 3, 4, 5]],
517+
[[0, 0, 0, 1, 1, 1], [0, 2, 4, 1, 3, 5]],
518+
names=["kind", None],
519+
),
520+
)
521+
tm.assert_frame_equal(result, expected)
522+
492523
def test_groupby_subset_rolling_subset_with_closed(self):
493524
# GH 35549
494525
df = DataFrame(
@@ -1134,6 +1165,36 @@ def test_expanding_apply(self, raw, frame):
11341165
expected.index = expected_index
11351166
tm.assert_frame_equal(result, expected)
11361167

1168+
def test_groupby_expanding_agg_namedagg(self):
1169+
# GH#28333
1170+
df = DataFrame(
1171+
{
1172+
"kind": ["cat", "dog", "cat", "dog", "cat", "dog"],
1173+
"height": [9.1, 6.0, 9.5, 34.0, 12.0, 8.0],
1174+
"weight": [7.9, 7.5, 9.9, 198.0, 10.0, 42.0],
1175+
}
1176+
)
1177+
result = (
1178+
df.groupby("kind")
1179+
.expanding(1)
1180+
.agg(
1181+
total_weight=NamedAgg(column="weight", aggfunc=sum),
1182+
min_height=NamedAgg(column="height", aggfunc=min),
1183+
)
1184+
)
1185+
expected = DataFrame(
1186+
{
1187+
"total_weight": [7.9, 17.8, 27.8, 7.5, 205.5, 247.5],
1188+
"min_height": [9.1, 9.1, 9.1, 6.0, 6.0, 6.0],
1189+
},
1190+
index=MultiIndex(
1191+
[["cat", "dog"], [0, 1, 2, 3, 4, 5]],
1192+
[[0, 0, 0, 1, 1, 1], [0, 2, 4, 1, 3, 5]],
1193+
names=["kind", None],
1194+
),
1195+
)
1196+
tm.assert_frame_equal(result, expected)
1197+
11371198

11381199
class TestEWM:
11391200
@pytest.mark.parametrize(
@@ -1162,6 +1223,41 @@ def test_methods(self, method, expected_data):
11621223
)
11631224
tm.assert_frame_equal(result, expected)
11641225

1226+
def test_groupby_ewm_agg_namedagg(self):
1227+
# GH#28333
1228+
df = DataFrame({"A": ["a"] * 4, "B": range(4)})
1229+
result = (
1230+
df.groupby("A")
1231+
.ewm(com=1.0)
1232+
.agg(
1233+
B_mean=NamedAgg(column="B", aggfunc="mean"),
1234+
B_std=NamedAgg(column="B", aggfunc="std"),
1235+
B_var=NamedAgg(column="B", aggfunc="var"),
1236+
)
1237+
)
1238+
expected = DataFrame(
1239+
{
1240+
"B_mean": [
1241+
0.0,
1242+
0.6666666666666666,
1243+
1.4285714285714286,
1244+
2.2666666666666666,
1245+
],
1246+
"B_std": [np.nan, 0.707107, 0.963624, 1.177164],
1247+
"B_var": [np.nan, 0.5, 0.9285714285714286, 1.3857142857142857],
1248+
},
1249+
index=MultiIndex.from_tuples(
1250+
[
1251+
("a", 0),
1252+
("a", 1),
1253+
("a", 2),
1254+
("a", 3),
1255+
],
1256+
names=["A", None],
1257+
),
1258+
)
1259+
tm.assert_frame_equal(result, expected)
1260+
11651261
@pytest.mark.parametrize(
11661262
"method, expected_data",
11671263
[["corr", [np.nan, 1.0, 1.0, 1]], ["cov", [np.nan, 0.5, 0.928571, 1.385714]]],

0 commit comments

Comments
 (0)