Skip to content

Commit 63249f2

Browse files
authored
ENH: Add dtype argument to str.decode (#60940)
* ENH: Add dtype argument to str.decode * Refinements * cleanup * cleanup * type-hint fixup * Simplify condition * lint
1 parent f77398c commit 63249f2

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Other enhancements
3737
updated to work correctly with NumPy >= 2 (:issue:`57739`)
3838
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
3939
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
40+
- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`)
4041
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
4142
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4243

pandas/core/strings/accessor.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
is_numeric_dtype,
3535
is_object_dtype,
3636
is_re,
37+
is_string_dtype,
3738
)
3839
from pandas.core.dtypes.dtypes import (
3940
ArrowDtype,
@@ -2102,7 +2103,9 @@ def slice_replace(self, start=None, stop=None, repl=None):
21022103
result = self._data.array._str_slice_replace(start, stop, repl)
21032104
return self._wrap_result(result)
21042105

2105-
def decode(self, encoding, errors: str = "strict"):
2106+
def decode(
2107+
self, encoding, errors: str = "strict", dtype: str | DtypeObj | None = None
2108+
):
21062109
"""
21072110
Decode character string in the Series/Index using indicated encoding.
21082111
@@ -2116,6 +2119,12 @@ def decode(self, encoding, errors: str = "strict"):
21162119
errors : str, optional
21172120
Specifies the error handling scheme.
21182121
Possible values are those supported by :meth:`bytes.decode`.
2122+
dtype : str or dtype, optional
2123+
The dtype of the result. When not ``None``, must be either a string or
2124+
object dtype. When ``None``, the dtype of the result is determined by
2125+
``pd.options.future.infer_string``.
2126+
2127+
.. versionadded:: 2.3.0
21192128
21202129
Returns
21212130
-------
@@ -2137,6 +2146,10 @@ def decode(self, encoding, errors: str = "strict"):
21372146
2 ()
21382147
dtype: object
21392148
"""
2149+
if dtype is not None and not is_string_dtype(dtype):
2150+
raise ValueError(f"dtype must be string or object, got {dtype=}")
2151+
if dtype is None and get_option("future.infer_string"):
2152+
dtype = "str"
21402153
# TODO: Add a similar _bytes interface.
21412154
if encoding in _cpython_optimized_decoders:
21422155
# CPython optimized implementation
@@ -2146,7 +2159,6 @@ def decode(self, encoding, errors: str = "strict"):
21462159
f = lambda x: decoder(x, errors)[0]
21472160
arr = self._data.array
21482161
result = arr._str_map(f)
2149-
dtype = "str" if get_option("future.infer_string") else None
21502162
return self._wrap_result(result, dtype=dtype)
21512163

21522164
@forbid_nonstring_types(["bytes"])

pandas/tests/strings/test_strings.py

+24
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,30 @@ def test_decode_errors_kwarg():
601601
tm.assert_series_equal(result, expected)
602602

603603

604+
def test_decode_string_dtype(string_dtype):
605+
# https://github.com/pandas-dev/pandas/pull/60940
606+
ser = Series([b"a", b"b"])
607+
result = ser.str.decode("utf-8", dtype=string_dtype)
608+
expected = Series(["a", "b"], dtype=string_dtype)
609+
tm.assert_series_equal(result, expected)
610+
611+
612+
def test_decode_object_dtype(object_dtype):
613+
# https://github.com/pandas-dev/pandas/pull/60940
614+
ser = Series([b"a", rb"\ud800"])
615+
result = ser.str.decode("utf-8", dtype=object_dtype)
616+
expected = Series(["a", r"\ud800"], dtype=object_dtype)
617+
tm.assert_series_equal(result, expected)
618+
619+
620+
def test_decode_bad_dtype():
621+
# https://github.com/pandas-dev/pandas/pull/60940
622+
ser = Series([b"a", b"b"])
623+
msg = "dtype must be string or object, got dtype='int64'"
624+
with pytest.raises(ValueError, match=msg):
625+
ser.str.decode("utf-8", dtype="int64")
626+
627+
604628
@pytest.mark.parametrize(
605629
"form, expected",
606630
[

0 commit comments

Comments
 (0)