Skip to content

Commit 3d5c84b

Browse files
committed
Backport PR pandas-dev#60940: ENH: Add dtype argument to str.decode
1 parent e0f47b7 commit 3d5c84b

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Other enhancements
3737
updated to raise FutureWarning with NumPy >= 2 (:issue:`60340`)
3838
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
3939
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
40+
- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`)
4041
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
4142
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4243

pandas/core/strings/accessor.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
is_list_like,
3434
is_object_dtype,
3535
is_re,
36+
is_string_dtype,
3637
)
3738
from pandas.core.dtypes.dtypes import (
3839
ArrowDtype,
@@ -1981,7 +1982,9 @@ def slice_replace(self, start=None, stop=None, repl=None):
19811982
result = self._data.array._str_slice_replace(start, stop, repl)
19821983
return self._wrap_result(result)
19831984

1984-
def decode(self, encoding, errors: str = "strict"):
1985+
def decode(
1986+
self, encoding, errors: str = "strict", dtype: str | DtypeObj | None = None
1987+
):
19851988
"""
19861989
Decode character string in the Series/Index using indicated encoding.
19871990
@@ -1992,6 +1995,14 @@ def decode(self, encoding, errors: str = "strict"):
19921995
----------
19931996
encoding : str
19941997
errors : str, optional
1998+
Specifies the error handling scheme.
1999+
Possible values are those supported by :meth:`bytes.decode`.
2000+
dtype : str or dtype, optional
2001+
The dtype of the result. When not ``None``, must be either a string or
2002+
object dtype. When ``None``, the dtype of the result is determined by
2003+
``pd.options.future.infer_string``.
2004+
2005+
.. versionadded:: 2.3.0
19952006
19962007
Returns
19972008
-------
@@ -2008,6 +2019,10 @@ def decode(self, encoding, errors: str = "strict"):
20082019
2 ()
20092020
dtype: object
20102021
"""
2022+
if dtype is not None and not is_string_dtype(dtype):
2023+
raise ValueError(f"dtype must be string or object, got {dtype=}")
2024+
if dtype is None and get_option("future.infer_string"):
2025+
dtype = "str"
20112026
# TODO: Add a similar _bytes interface.
20122027
if encoding in _cpython_optimized_decoders:
20132028
# CPython optimized implementation
@@ -2017,7 +2032,6 @@ def decode(self, encoding, errors: str = "strict"):
20172032
f = lambda x: decoder(x, errors)[0]
20182033
arr = self._data.array
20192034
result = arr._str_map(f)
2020-
dtype = "str" if get_option("future.infer_string") else None
20212035
return self._wrap_result(result, dtype=dtype)
20222036

20232037
@forbid_nonstring_types(["bytes"])

pandas/tests/strings/test_strings.py

+24
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,30 @@ def test_decode_errors_kwarg():
599599
tm.assert_series_equal(result, expected)
600600

601601

602+
def test_decode_string_dtype(string_dtype):
603+
# https://github.com/pandas-dev/pandas/pull/60940
604+
ser = Series([b"a", b"b"])
605+
result = ser.str.decode("utf-8", dtype=string_dtype)
606+
expected = Series(["a", "b"], dtype=string_dtype)
607+
tm.assert_series_equal(result, expected)
608+
609+
610+
def test_decode_object_dtype(object_dtype):
611+
# https://github.com/pandas-dev/pandas/pull/60940
612+
ser = Series([b"a", rb"\ud800"])
613+
result = ser.str.decode("utf-8", dtype=object_dtype)
614+
expected = Series(["a", r"\ud800"], dtype=object_dtype)
615+
tm.assert_series_equal(result, expected)
616+
617+
618+
def test_decode_bad_dtype():
619+
# https://github.com/pandas-dev/pandas/pull/60940
620+
ser = Series([b"a", b"b"])
621+
msg = "dtype must be string or object, got dtype='int64'"
622+
with pytest.raises(ValueError, match=msg):
623+
ser.str.decode("utf-8", dtype="int64")
624+
625+
602626
@pytest.mark.parametrize(
603627
"form, expected",
604628
[

0 commit comments

Comments
 (0)