Skip to content

Commit c36da3f

Browse files
authored
ENH(string dtype): Make str.decode return str dtype (#60709)
* TST(string dtype): Make str.decode return str dtype * Test fixups * pytables fixup * Simplify * whatsnew * fix implementation
1 parent c430c61 commit c36da3f

File tree

6 files changed

+28
-18
lines changed

6 files changed

+28
-18
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Other enhancements
3535
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
3636
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
3737
updated to work correctly with NumPy >= 2 (:issue:`57739`)
38+
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
3839
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
3940
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
4041
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)

pandas/core/strings/accessor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import get_option
16+
1517
from pandas._libs import lib
1618
from pandas._typing import (
1719
AlignJoin,
@@ -400,7 +402,9 @@ def cons_row(x):
400402
# This is a mess.
401403
_dtype: DtypeObj | str | None = dtype
402404
vdtype = getattr(result, "dtype", None)
403-
if self._is_string:
405+
if _dtype is not None:
406+
pass
407+
elif self._is_string:
404408
if is_bool_dtype(vdtype):
405409
_dtype = result.dtype
406410
elif returns_string:
@@ -2141,9 +2145,9 @@ def decode(self, encoding, errors: str = "strict"):
21412145
decoder = codecs.getdecoder(encoding)
21422146
f = lambda x: decoder(x, errors)[0]
21432147
arr = self._data.array
2144-
# assert isinstance(arr, (StringArray,))
21452148
result = arr._str_map(f)
2146-
return self._wrap_result(result)
2149+
dtype = "str" if get_option("future.infer_string") else None
2150+
return self._wrap_result(result, dtype=dtype)
21472151

21482152
@forbid_nonstring_types(["bytes"])
21492153
def encode(self, encoding, errors: str = "strict"):

pandas/io/pytables.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5233,7 +5233,9 @@ def _unconvert_string_array(
52335233
dtype = f"U{itemsize}"
52345234

52355235
if isinstance(data[0], bytes):
5236-
data = Series(data, copy=False).str.decode(encoding, errors=errors)._values
5236+
ser = Series(data, copy=False).str.decode(encoding, errors=errors)
5237+
data = ser.to_numpy()
5238+
data.flags.writeable = True
52375239
else:
52385240
data = data.astype(dtype, copy=False).astype(object, copy=False)
52395241

pandas/io/sas/sas7bdat.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import numpy as np
2424

25+
from pandas._config import get_option
26+
2527
from pandas._libs.byteswap import (
2628
read_double_with_byteswap,
2729
read_float_with_byteswap,
@@ -699,6 +701,7 @@ def _chunk_to_dataframe(self) -> DataFrame:
699701
rslt = {}
700702

701703
js, jb = 0, 0
704+
infer_string = get_option("future.infer_string")
702705
for j in range(self.column_count):
703706
name = self.column_names[j]
704707

@@ -715,6 +718,9 @@ def _chunk_to_dataframe(self) -> DataFrame:
715718
rslt[name] = pd.Series(self._string_chunk[js, :], index=ix, copy=False)
716719
if self.convert_text and (self.encoding is not None):
717720
rslt[name] = self._decode_string(rslt[name].str)
721+
if infer_string:
722+
rslt[name] = rslt[name].astype("str")
723+
718724
js += 1
719725
else:
720726
self.close()

pandas/tests/io/sas/test_sas7bdat.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import numpy as np
88
import pytest
99

10-
from pandas._config import using_string_dtype
11-
1210
from pandas.compat._constants import (
1311
IS64,
1412
WASM,
@@ -20,10 +18,6 @@
2018

2119
from pandas.io.sas.sas7bdat import SAS7BDATReader
2220

23-
pytestmark = pytest.mark.xfail(
24-
using_string_dtype(), reason="TODO(infer_string)", strict=False
25-
)
26-
2721

2822
@pytest.fixture
2923
def dirpath(datapath):
@@ -246,11 +240,13 @@ def test_zero_variables(datapath):
246240
pd.read_sas(fname)
247241

248242

249-
def test_zero_rows(datapath):
243+
@pytest.mark.parametrize("encoding", [None, "utf8"])
244+
def test_zero_rows(datapath, encoding):
250245
# GH 18198
251246
fname = datapath("io", "sas", "data", "zero_rows.sas7bdat")
252-
result = pd.read_sas(fname)
253-
expected = pd.DataFrame([{"char_field": "a", "num_field": 1.0}]).iloc[:0]
247+
result = pd.read_sas(fname, encoding=encoding)
248+
str_value = b"a" if encoding is None else "a"
249+
expected = pd.DataFrame([{"char_field": str_value, "num_field": 1.0}]).iloc[:0]
254250
tm.assert_frame_equal(result, expected)
255251

256252

@@ -409,7 +405,7 @@ def test_0x40_control_byte(datapath):
409405
fname = datapath("io", "sas", "data", "0x40controlbyte.sas7bdat")
410406
df = pd.read_sas(fname, encoding="ascii")
411407
fname = datapath("io", "sas", "data", "0x40controlbyte.csv")
412-
df0 = pd.read_csv(fname, dtype="object")
408+
df0 = pd.read_csv(fname, dtype="str")
413409
tm.assert_frame_equal(df, df0)
414410

415411

pandas/tests/strings/test_strings.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def test_repeat_with_null(any_string_dtype, arg, repeat):
9595

9696
def test_empty_str_methods(any_string_dtype):
9797
empty_str = empty = Series(dtype=any_string_dtype)
98+
empty_inferred_str = Series(dtype="str")
9899
if is_object_or_nan_string_dtype(any_string_dtype):
99100
empty_int = Series(dtype="int64")
100101
empty_bool = Series(dtype=bool)
@@ -154,7 +155,7 @@ def test_empty_str_methods(any_string_dtype):
154155
tm.assert_series_equal(empty_str, empty.str.rstrip())
155156
tm.assert_series_equal(empty_str, empty.str.wrap(42))
156157
tm.assert_series_equal(empty_str, empty.str.get(0))
157-
tm.assert_series_equal(empty_object, empty_bytes.str.decode("ascii"))
158+
tm.assert_series_equal(empty_inferred_str, empty_bytes.str.decode("ascii"))
158159
tm.assert_series_equal(empty_bytes, empty.str.encode("ascii"))
159160
# ismethods should always return boolean (GH 29624)
160161
tm.assert_series_equal(empty_bool, empty.str.isalnum())
@@ -566,7 +567,7 @@ def test_string_slice_out_of_bounds(any_string_dtype):
566567
def test_encode_decode(any_string_dtype):
567568
ser = Series(["a", "b", "a\xe4"], dtype=any_string_dtype).str.encode("utf-8")
568569
result = ser.str.decode("utf-8")
569-
expected = ser.map(lambda x: x.decode("utf-8")).astype(object)
570+
expected = Series(["a", "b", "a\xe4"], dtype="str")
570571
tm.assert_series_equal(result, expected)
571572

572573

@@ -596,7 +597,7 @@ def test_decode_errors_kwarg():
596597
ser.str.decode("cp1252")
597598

598599
result = ser.str.decode("cp1252", "ignore")
599-
expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype(object)
600+
expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype("str")
600601
tm.assert_series_equal(result, expected)
601602

602603

@@ -751,5 +752,5 @@ def test_get_with_dict_label():
751752
def test_series_str_decode():
752753
# GH 22613
753754
result = Series([b"x", b"y"]).str.decode(encoding="UTF-8", errors="strict")
754-
expected = Series(["x", "y"], dtype="object")
755+
expected = Series(["x", "y"], dtype="str")
755756
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)