Skip to content

Commit 4e20195

Browse files
authored
ENH(string dtype): Implement cumsum for Python-backed strings (#60938)
* ENH(string dtype): Implement cumsum for Python-backed strings
* cleanups
* cleanups
* type-hint fixup
* More type fixes
* Use quotes for cast
* Refinements
* type-ignore
1 parent d4dff29 commit 4e20195

File tree

5 files changed

+92
-20
lines changed

5 files changed

+92
-20
lines changed

doc/source/whatsnew/v2.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Other enhancements
3838
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
3939
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
4040
- :meth:`Series.str.decode` has gained the ``dtype`` argument to control the dtype of the result (:issue:`60940`)
41-
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
41+
- The cumulative methods :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` are now implemented for ``StringDtype`` columns (:issue:`60633`)
4242
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4343

4444
.. ---------------------------------------------------------------------------

pandas/core/arrays/string_.py

+83
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
)
5050

5151
from pandas.core import (
52+
missing,
5253
nanops,
5354
ops,
5455
)
@@ -870,6 +871,88 @@ def _reduce(
870871

871872
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
872873

874+
def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArray:
875+
"""
876+
Return an ExtensionArray performing an accumulation operation.
877+
878+
The underlying data type might change.
879+
880+
Parameters
881+
----------
882+
name : str
883+
Name of the function, supported values are:
884+
- cummin
885+
- cummax
886+
- cumsum
887+
- cumprod
888+
skipna : bool, default True
889+
If True, skip NA values.
890+
**kwargs
891+
Additional keyword arguments passed to the accumulation function.
892+
Currently, there is no supported kwarg.
893+
894+
Returns
895+
-------
896+
array
897+
898+
Raises
899+
------
900+
NotImplementedError : subclass does not define accumulations
901+
"""
902+
if name == "cumprod":
903+
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
904+
raise TypeError(msg)
905+
906+
# We may need to strip out trailing NA values
907+
tail: np.ndarray | None = None
908+
na_mask: np.ndarray | None = None
909+
ndarray = self._ndarray
910+
np_func = {
911+
"cumsum": np.cumsum,
912+
"cummin": np.minimum.accumulate,
913+
"cummax": np.maximum.accumulate,
914+
}[name]
915+
916+
if self._hasna:
917+
na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray))
918+
if np.all(na_mask):
919+
return type(self)(ndarray)
920+
if skipna:
921+
if name == "cumsum":
922+
ndarray = np.where(na_mask, "", ndarray)
923+
else:
924+
# We can retain the running min/max by forward/backward filling.
925+
ndarray = ndarray.copy()
926+
missing.pad_or_backfill_inplace(
927+
ndarray,
928+
method="pad",
929+
axis=0,
930+
)
931+
missing.pad_or_backfill_inplace(
932+
ndarray,
933+
method="backfill",
934+
axis=0,
935+
)
936+
else:
937+
# When not skipping NA values, the result should be null from
938+
# the first NA value onward.
939+
idx = np.argmax(na_mask)
940+
tail = np.empty(len(ndarray) - idx, dtype="object")
941+
tail[:] = self.dtype.na_value
942+
ndarray = ndarray[:idx]
943+
944+
# mypy: Cannot call function of unknown type
945+
np_result = np_func(ndarray) # type: ignore[operator]
946+
947+
if tail is not None:
948+
np_result = np.hstack((np_result, tail))
949+
elif na_mask is not None:
950+
# Argument 2 to "where" has incompatible type "NAType | float"
951+
np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type]
952+
953+
result = type(self)(np_result)
954+
return result
955+
873956
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
874957
if self.dtype.na_value is np.nan and result is libmissing.NA:
875958
# the masked_reductions use pd.NA -> convert to np.nan

pandas/tests/apply/test_str.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66

77
from pandas.compat import (
8-
HAS_PYARROW,
98
WASM,
109
)
1110

@@ -162,17 +161,10 @@ def test_agg_cython_table_series(series, func, expected):
162161
),
163162
),
164163
)
165-
def test_agg_cython_table_transform_series(request, series, func, expected):
164+
def test_agg_cython_table_transform_series(series, func, expected):
166165
# GH21224
167166
# test transforming functions in
168167
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
169-
if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
170-
request.applymarker(
171-
pytest.mark.xfail(
172-
raises=NotImplementedError,
173-
reason="TODO(infer_string) cumsum not yet implemented for string",
174-
)
175-
)
176168
warn = None if isinstance(func, str) else FutureWarning
177169
with tm.assert_produces_warning(warn, match="is currently using Series.*"):
178170
result = series.agg(func)

pandas/tests/extension/test_string.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,7 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
196196

197197
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
198198
assert isinstance(ser.dtype, StorageExtensionDtype)
199-
return ser.dtype.storage == "pyarrow" and op_name in [
200-
"cummin",
201-
"cummax",
202-
"cumsum",
203-
]
199+
return op_name in ["cummin", "cummax", "cumsum"]
204200

205201
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
206202
dtype = cast(StringDtype, tm.get_dtype(obj))

pandas/tests/series/test_cumulative.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,14 @@ def test_cumprod_timedelta(self):
265265
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
266266
],
267267
)
268-
def test_cum_methods_pyarrow_strings(
269-
self, pyarrow_string_dtype, data, op, skipna, expected_data
268+
def test_cum_methods_ea_strings(
269+
self, string_dtype_no_object, data, op, skipna, expected_data
270270
):
271-
# https://github.com/pandas-dev/pandas/pull/60633
272-
ser = pd.Series(data, dtype=pyarrow_string_dtype)
271+
# https://github.com/pandas-dev/pandas/pull/60633 - pyarrow
272+
# https://github.com/pandas-dev/pandas/pull/60938 - Python
273+
ser = pd.Series(data, dtype=string_dtype_no_object)
273274
method = getattr(ser, op)
274-
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
275+
expected = pd.Series(expected_data, dtype=string_dtype_no_object)
275276
result = method(skipna=skipna)
276277
tm.assert_series_equal(result, expected)
277278

0 commit comments

Comments
 (0)