diff --git a/pandas/__init__.py b/pandas/__init__.py index c570fb8d70204..0cc0a2075355b 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -61,6 +61,7 @@ PeriodDtype, IntervalDtype, DatetimeTZDtype, + ListDtype, StringDtype, BooleanDtype, # missing @@ -261,6 +262,7 @@ "Interval", "IntervalDtype", "IntervalIndex", + "ListDtype", "MultiIndex", "NaT", "NamedAgg", diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index de603beff7836..7eaa9b17ee2a1 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -834,6 +834,8 @@ cpdef ndarray[object] ensure_string_array( if isinstance(val, bytes): # GH#49658 discussion of desired behavior here result[i] = val.decode() + elif util.is_array(val): + result[i] = str(val.tolist()) elif not util.is_float_object(val): # f"{val}" is faster than str(val) result[i] = f"{val}" diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index daa5187cdb636..958de0b61e542 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -54,6 +54,7 @@ TimedeltaArray, ) from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin +from pandas.core.arrays.list_ import ListDtype from pandas.core.arrays.string_ import StringDtype from pandas.core.indexes.api import safe_sort_index @@ -824,6 +825,11 @@ def assert_extension_array_equal( [np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined] ), "wrong missing value sentinels" + # TODO: not every array type may be convertible to NumPy; should catch here + if isinstance(left.dtype, ListDtype) and isinstance(right.dtype, ListDtype): + assert left._pa_array == right._pa_array + return + left_valid = left[~left_na].to_numpy(dtype=object) right_valid = right[~right_na].to_numpy(dtype=object) if check_exact: diff --git a/pandas/core/api.py b/pandas/core/api.py index ec12d543d8389..414b07ad802a9 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -40,6 +40,7 @@ UInt32Dtype, UInt64Dtype, ) +from pandas.core.arrays.list_ import ListDtype from pandas.core.arrays.string_ import StringDtype from pandas.core.construction import array # noqa: ICN001 from pandas.core.flags import Flags @@ -103,6 +104,7 @@ "Interval", "IntervalDtype", "IntervalIndex", + "ListDtype", "MultiIndex", "NaT", "NamedAgg", diff --git a/pandas/core/arrays/arrow/accessors.py b/pandas/core/arrays/arrow/accessors.py index b220a94d032b5..e5ee23906ddf4 100644 --- a/pandas/core/arrays/arrow/accessors.py +++ b/pandas/core/arrays/arrow/accessors.py @@ -18,6 +18,8 @@ from pandas.core.dtypes.common import is_list_like +from pandas.core.arrays.list_ import ListDtype + if not pa_version_under10p1: import pyarrow as pa import pyarrow.compute as pc @@ -106,7 +108,7 @@ def len(self) -> Series: ... [1, 2, 3], ... [3], ... ], - ... dtype=pd.ArrowDtype(pa.list_(pa.int64())), + ... dtype=pd.ListDtype(pa.int64()), ... ) >>> s.list.len() 0 3 @@ -189,7 +191,7 @@ def __getitem__(self, key: int | slice) -> Series: sliced = pc.list_slice(self._pa_array, start, stop, step) return Series( sliced, - dtype=ArrowDtype(sliced.type), + dtype=ListDtype(sliced.type.value_type), index=self._data.index, name=self._data.name, ) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index afa219f611992..441e3bce9bda9 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -428,7 +428,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar: """ if isinstance(value, pa.Scalar): pa_scalar = value - elif isna(value): + elif not is_list_like(value) and isna(value): pa_scalar = pa.scalar(None, type=pa_type) else: # Workaround https://github.com/apache/arrow/issues/37291 @@ -1350,7 +1350,16 @@ def take( # TODO(ARROW-9433): Treat negative indices as NULL indices_array = pa.array(indices_array, mask=fill_mask) result = self._pa_array.take(indices_array) - if isna(fill_value): + if is_list_like(fill_value): + # TODO: this should be hit by ListArray. Ideally we do: + # pc.replace_with_mask(result, fill_mask, pa.scalar(fill_value)) + # but pyarrow does not yet implement that for list types + new_values = [ + fill_value if should_fill else x.as_py() + for x, should_fill in zip(result, fill_mask) + ] + return type(self)(new_values) + elif isna(fill_value): return type(self)(result) # TODO: ArrowNotImplementedError: Function fill_null has no # kernel matching input types (array[string], scalar[string]) diff --git a/pandas/core/arrays/list_.py b/pandas/core/arrays/list_.py new file mode 100644 index 0000000000000..bfddbe5ce2c07 --- /dev/null +++ b/pandas/core/arrays/list_.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.compat import HAS_PYARROW +from pandas.util._decorators import set_module + +from pandas.core.dtypes.base import ( + ExtensionDtype, + register_extension_dtype, +) +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_integer_dtype, + is_string_dtype, +) +from pandas.core.dtypes.dtypes import ArrowDtype + +from pandas.core.arrays.arrow.array import ArrowExtensionArray +from pandas.core.arrays.base import ExtensionArray + +if TYPE_CHECKING: + from collections.abc import Sequence + from pandas._typing import ( + type_t, + ArrayLike, + AstypeArg, + DtypeObj, + Shape, + ) + +import re + +import pyarrow as pa + + +def string_to_pyarrow_type(string: str) -> pa.DataType: + # TODO: combine this with to_pyarrow_type in pandas.core.arrays.arrow ? + pater = r"list\[(.*)\]" + + if mtch := re.search(pater, string): + value_type = mtch.groups()[0] + match value_type: + # TODO: is there a pyarrow function get a type from the string? + case "string" | "large_string": + return pa.large_list(pa.large_string()) + case "int64": + return pa.large_list(pa.int64()) + # TODO: need to implement many more here, including nested + + raise ValueError(f"Cannot map {string} to a pyarrow list type") + + +def transpose_homogeneous_list( + arrays: Sequence[ListArray], +) -> list[ListArray]: + # TODO: this is the same as transpose_homogeneous_pyarrow + # but returns the ListArray instead of an ArrowExtensionArray + # should consolidate these + arrays = list(arrays) + nrows, ncols = len(arrays[0]), len(arrays) + indices = np.arange(nrows * ncols).reshape(ncols, nrows).T.reshape(-1) + arr = pa.chunked_array([chunk for arr in arrays for chunk in arr._pa_array.chunks]) + arr = arr.take(indices) + return [ListArray(arr.slice(i * ncols, ncols)) for i in range(nrows)] + + +@register_extension_dtype +@set_module("pandas") +class ListDtype(ArrowDtype): + """ + An ExtensionDtype suitable for storing homogeneous lists of data. + """ + + _is_immutable = True + + def __init__(self, value_dtype: pa.DataType) -> None: + super().__init__(pa.large_list(value_dtype)) + + @classmethod + def construct_from_string(cls, string: str): + if not isinstance(string, str): + raise TypeError( + f"'construct_from_string' expects a string, got {type(string)}" + ) + + try: + pa_type = string_to_pyarrow_type(string) + except ValueError as e: + raise TypeError( + f"Cannot construct a '{cls.__name__}' from '{string}'" + ) from e + + return cls(pa_type) + + @property + def name(self) -> str: # type: ignore[override] + """ + A string identifying the data type. + """ + return f"list[{self.pyarrow_dtype.value_type!s}]" + + @property + def kind(self) -> str: + # TODO(wayd): our extension interface says this field should be the + # NumPy type character, but no such thing exists for list + # This uses the Arrow C Data exchange code instead + return "+L" + + @classmethod + def construct_array_type(cls) -> type_t[ListArray]: + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return ListArray + + def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: + for dtype in dtypes: + if ( + isinstance(dtype, ListDtype) + and self.pyarrow_dtype.value_type == dtype.pyarrow_dtype.value_type + ): + continue + else: + return None + + return ListDtype(self.pyarrow_dtype.value_type) + + +class ListArray(ArrowExtensionArray): + __array_priority__ = 1000 + + def __init__( + self, values: pa.Array | pa.ChunkedArray | list | ListArray, value_type=None + ) -> None: + if not HAS_PYARROW: + raise NotImplementedError("ListArray requires pyarrow to be installed") + + if isinstance(values, type(self)): + self._pa_array = values._pa_array + else: + if value_type is None: + if isinstance(values, (pa.Array, pa.ChunkedArray)): + parent_type = values.type + if not isinstance(parent_type, (pa.ListType, pa.LargeListType)): + # TODO: maybe implement native casts in pyarrow + new_values = [ + [x.as_py()] if x.is_valid else None for x in values + ] + values = pa.array(new_values, type=pa.large_list(parent_type)) + + value_type = values.type.value_type + else: + value_type = pa.array(values).type.value_type + + if value_type == pa.string(): + value_type = pa.large_string() + + if not isinstance(values, pa.ChunkedArray): + arr = pa.array(values, type=pa.large_list(value_type), from_pandas=True) + self._pa_array = pa.chunked_array(arr, type=pa.large_list(value_type)) + else: + self._pa_array = values + + @property + def _dtype(self): + return ListDtype(self._pa_array.type.value_type) + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False): + if isinstance(scalars, ListArray): + return cls(scalars) + elif isinstance(scalars, pa.Scalar): + scalars = [scalars] + return cls(scalars) + + try: + values = pa.array(scalars, from_pandas=True) + except TypeError: + # TypeError: object of type 'NoneType' has no len() if you have + # pa.ListScalar(None). Upstream issue in Arrow - see: + # https://github.com/apache/arrow/issues/40319 + values = pa.array(scalars.to_pylist(), from_pandas=True) + + if values.type == "null" and dtype is not None: + pa_type = string_to_pyarrow_type(str(dtype)) + values = pa.array(values, type=pa_type) + + return cls(values) + + @classmethod + def _box_pa( + cls, value, pa_type: pa.DataType | None = None + ) -> pa.Array | pa.ChunkedArray | pa.Scalar: + """ + Box value into a pyarrow Array, ChunkedArray or Scalar. + + Parameters + ---------- + value : any + pa_type : pa.DataType | None + + Returns + ------- + pa.Array or pa.ChunkedArray or pa.Scalar + """ + if ( + isinstance(value, (pa.ListScalar, pa.LargeListScalar)) + or isinstance(value, list) + or value is None + ): + return cls._box_pa_scalar(value, pa_type) + return cls._box_pa_array(value, pa_type) + + def __getitem__(self, item): + if isinstance(item, (np.ndarray, ExtensionArray)): + if is_bool_dtype(item.dtype): + mask_len = len(item) + if mask_len != len(self): + raise IndexError( + f"Boolean index has wrong length: {mask_len} " + f"instead of {len(self)}" + ) + pos = np.array(range(len(item))) + + if isinstance(item, ExtensionArray): + mask = pos[item.fillna(False)] + else: + mask = pos[item] + return type(self)(self._pa_array.take(mask)) + elif is_integer_dtype(item.dtype): + if isinstance(item, ExtensionArray) and item.isna().any(): + msg = "Cannot index with an integer indexer containing NA values" + raise ValueError(msg) + + indexer = pa.array(item) + return type(self)(self._pa_array.take(indexer)) + elif isinstance(item, int): + value = self._pa_array[item] + if value.is_valid: + return value.as_py() + else: + return self.dtype.na_value + elif isinstance(item, list): + # pyarrow does not support taking yet from an empty list + # https://github.com/apache/arrow/issues/39917 + if item: + try: + result = self._pa_array.take(item) + except pa.lib.ArrowInvalid as e: + if "Could not convert " in str(e): + msg = ( + "Cannot index with an integer indexer containing NA values" + ) + raise ValueError(msg) from e + raise e + else: + result = pa.array([], type=self._pa_array.type) + + return type(self)(result) + + try: + result = type(self)(self._pa_array[item]) + except TypeError as e: + msg = ( + "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis " + "(`None`) and integer or boolean arrays are valid indices" + ) + raise IndexError(msg) from e + + return result + + def __setitem__(self, key, value) -> None: + msg = "ListArray does not support item assignment via setitem" + raise TypeError(msg) + + @classmethod + def _empty(cls, shape: Shape, dtype: ExtensionDtype): + """ + Create an ExtensionArray with the given shape and dtype. + + See also + -------- + ExtensionDtype.empty + ExtensionDtype.empty is the 'official' public version of this API. + """ + if isinstance(shape, tuple): + if len(shape) > 1: + raise ValueError("ListArray may only be 1-D") + else: + length = shape[0] + else: + length = shape + + return cls._from_sequence([None] * length, dtype=dtype) + + def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: + if is_string_dtype(dtype) and not isinstance(dtype, ExtensionDtype): + return np.array([str(x) for x in self], dtype=dtype) + + return super().astype(dtype, copy) + + def __eq__(self, other): + if isinstance(other, list): + from pandas.arrays import BooleanArray + + mask = np.array([False] * len(self)) + values = np.array([x.as_py() == other for x in self._pa_array]) + return BooleanArray(values, mask) + elif isinstance(other, (pa.ListScalar, pa.LargeListScalar)): + from pandas.arrays import BooleanArray + + # TODO: pyarrow.compute does not implement equal for lists + # https://github.com/apache/arrow/issues/45167 + # TODO: pyarrow doesn't compare missing values in Python as missing??? + # arr = pa.array([1, 2, None]) + # pc.equal(arr, arr[2]) returns all nulls but + # arr[2] == arr[2] returns True + mask = np.array([False] * len(self)) + values = np.array([x == other for x in self._pa_array]) + return BooleanArray(values, mask) + + return super().__eq__(other) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 02878b36a379e..e32355c8fe5f7 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -135,6 +135,7 @@ PeriodArray, TimedeltaArray, ) +from pandas.core.arrays.list_ import ListDtype from pandas.core.arrays.sparse import SparseFrameAccessor from pandas.core.construction import ( ensure_wrapped_if_datetimelike, @@ -821,7 +822,7 @@ def __init__( if len(data) > 0: if is_dataclass(data[0]): data = dataclasses_to_dicts(data) - if not isinstance(data, np.ndarray) and treat_as_nested(data): + if not isinstance(data, np.ndarray) and treat_as_nested(data, dtype): # exclude ndarray as we may have cast it a few lines above if columns is not None: columns = ensure_index(columns) @@ -3800,6 +3801,15 @@ def transpose( new_values = transpose_homogeneous_masked_arrays( cast(Sequence[BaseMaskedArray], self._iter_column_arrays()) ) + elif isinstance(first_dtype, ListDtype): + from pandas.core.arrays.list_ import ( + ListArray, + transpose_homogeneous_list, + ) + + new_values = transpose_homogeneous_list( + cast(Sequence[ListArray], self._iter_column_arrays()) + ) elif isinstance(first_dtype, ArrowDtype): # We have arrow EAs with the same dtype. We can transpose faster. from pandas.core.arrays.arrow.array import ( diff --git a/pandas/core/generic.py b/pandas/core/generic.py index de7fb3682fb4f..438f349c152b1 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -149,6 +149,7 @@ ) from pandas.core.array_algos.replace import should_use_regex from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.list_ import ListDtype from pandas.core.base import PandasObject from pandas.core.construction import extract_array from pandas.core.flags import Flags @@ -7012,11 +7013,20 @@ def fillna( stacklevel=2, ) + holds_list_array = False + if isinstance(self, ABCSeries) and isinstance(self.dtype, ListDtype): + holds_list_array = True + elif isinstance(self, ABCDataFrame) and any( + isinstance(x, ListDtype) for x in self.dtypes + ): + holds_list_array = True + if isinstance(value, (list, tuple)): - raise TypeError( - '"value" parameter must be a scalar or dict, but ' - f'you passed a "{type(value).__name__}"' - ) + if not holds_list_array: + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + f'you passed a "{type(value).__name__}"' + ) # set the default here, so functions examining the signature # can detect if something was set (e.g. in groupby) (GH9221) @@ -7036,7 +7046,9 @@ def fillna( value = Series(value) value = value.reindex(self.index) value = value._values - elif not is_list_like(value): + elif ( + isinstance(value, list) and isinstance(self.dtype, ListDtype) + ) or not is_list_like(value): pass else: raise TypeError( @@ -7100,7 +7112,7 @@ def fillna( else: return result - elif not is_list_like(value): + elif holds_list_array or not is_list_like(value): if axis == 1: result = self.T.fillna(value=value, limit=limit).T new_data = result._mgr diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index dfff34656f82b..af038c2d6751f 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -46,6 +46,7 @@ common as com, ) from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.list_ import ListDtype from pandas.core.arrays.string_ import StringDtype from pandas.core.construction import ( array as pd_array, @@ -452,7 +453,7 @@ def nested_data_to_arrays( return arrays, columns, index -def treat_as_nested(data) -> bool: +def treat_as_nested(data, dtype) -> bool: """ Check if we should use nested_data_to_arrays. """ @@ -460,6 +461,7 @@ def treat_as_nested(data) -> bool: len(data) > 0 and is_list_like(data[0]) and getattr(data[0], "ndim", 1) == 1 + and not isinstance(dtype, ListDtype) and not (isinstance(data, ExtensionArray) and data.ndim == 2) ) diff --git a/pandas/core/series.py b/pandas/core/series.py index 4fa8b86fa4c16..612539217168b 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -111,6 +111,7 @@ StructAccessor, ) from pandas.core.arrays.categorical import CategoricalAccessor +from pandas.core.arrays.list_ import ListDtype from pandas.core.arrays.sparse import SparseAccessor from pandas.core.arrays.string_ import StringDtype from pandas.core.construction import ( @@ -494,7 +495,7 @@ def __init__( if not is_list_like(data): data = [data] index = default_index(len(data)) - elif is_list_like(data): + elif is_list_like(data) and not isinstance(dtype, ListDtype): com.require_length_match(data, index) # create/copy the manager diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index c1d9f5ea4d25c..233b963633057 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -72,6 +72,7 @@ class TestPDApi(Base): "RangeIndex", "Series", "SparseDtype", + "ListDtype", "StringDtype", "Timedelta", "TimedeltaIndex", diff --git a/pandas/tests/extension/list/__init__.py b/pandas/tests/extension/list/__init__.py deleted file mode 100644 index 0f3f2f3537788..0000000000000 --- a/pandas/tests/extension/list/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from pandas.tests.extension.list.array import ( - ListArray, - ListDtype, - make_data, -) - -__all__ = ["ListArray", "ListDtype", "make_data"] diff --git a/pandas/tests/extension/list/array.py b/pandas/tests/extension/list/array.py deleted file mode 100644 index da53bdcb4e37e..0000000000000 --- a/pandas/tests/extension/list/array.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Test extension array for storing nested data in a pandas container. - -The ListArray stores an ndarray of lists. -""" - -from __future__ import annotations - -import numbers -import string -from typing import TYPE_CHECKING - -import numpy as np - -from pandas.core.dtypes.base import ExtensionDtype - -import pandas as pd -from pandas.api.types import ( - is_object_dtype, - is_string_dtype, -) -from pandas.core.arrays import ExtensionArray - -if TYPE_CHECKING: - from pandas._typing import type_t - - -class ListDtype(ExtensionDtype): - type = list - name = "list" - na_value = np.nan - - @classmethod - def construct_array_type(cls) -> type_t[ListArray]: - """ - Return the array type associated with this dtype. - - Returns - ------- - type - """ - return ListArray - - -class ListArray(ExtensionArray): - dtype = ListDtype() - __array_priority__ = 1000 - - def __init__(self, values, dtype=None, copy=False) -> None: - if not isinstance(values, np.ndarray): - raise TypeError("Need to pass a numpy array as values") - for val in values: - if not isinstance(val, self.dtype.type) and not pd.isna(val): - raise TypeError("All values must be of type " + str(self.dtype.type)) - self.data = values - - @classmethod - def _from_sequence(cls, scalars, *, dtype=None, copy=False): - data = np.empty(len(scalars), dtype=object) - data[:] = scalars - return cls(data) - - def __getitem__(self, item): - if isinstance(item, numbers.Integral): - return self.data[item] - else: - # slice, list-like, mask - return type(self)(self.data[item]) - - def __len__(self) -> int: - return len(self.data) - - def isna(self): - return np.array( - [not isinstance(x, list) and np.isnan(x) for x in self.data], dtype=bool - ) - - def take(self, indexer, allow_fill=False, fill_value=None): - # re-implement here, since NumPy has trouble setting - # sized objects like UserDicts into scalar slots of - # an ndarary. - indexer = np.asarray(indexer) - msg = ( - "Index is out of bounds or cannot do a " - "non-empty take from an empty array." - ) - - if allow_fill: - if fill_value is None: - fill_value = self.dtype.na_value - # bounds check - if (indexer < -1).any(): - raise ValueError - try: - output = [ - self.data[loc] if loc != -1 else fill_value for loc in indexer - ] - except IndexError as err: - raise IndexError(msg) from err - else: - try: - output = [self.data[loc] for loc in indexer] - except IndexError as err: - raise IndexError(msg) from err - - return self._from_sequence(output) - - def copy(self): - return type(self)(self.data[:]) - - def astype(self, dtype, copy=True): - if isinstance(dtype, type(self.dtype)) and dtype == self.dtype: - if copy: - return self.copy() - return self - elif is_string_dtype(dtype) and not is_object_dtype(dtype): - # numpy has problems with astype(str) for nested elements - return np.array([str(x) for x in self.data], dtype=dtype) - elif not copy: - return np.asarray(self.data, dtype=dtype) - else: - return np.array(self.data, dtype=dtype, copy=copy) - - @classmethod - def _concat_same_type(cls, to_concat): - data = np.concatenate([x.data for x in to_concat]) - return cls(data) - - -def make_data(): - # TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer - rng = np.random.default_rng(2) - data = np.empty(100, dtype=object) - data[:] = [ - [rng.choice(list(string.ascii_letters)) for _ in range(rng.integers(0, 10))] - for _ in range(100) - ] - return data diff --git a/pandas/tests/extension/list/test_list.py b/pandas/tests/extension/list/test_list.py index ac396cd3c60d4..33d6303796e04 100644 --- a/pandas/tests/extension/list/test_list.py +++ b/pandas/tests/extension/list/test_list.py @@ -1,27 +1,279 @@ +import itertools +import operator + +import pyarrow as pa import pytest import pandas as pd -from pandas.tests.extension.list.array import ( +import pandas._testing as tm +from pandas.core.arrays.list_ import ( ListArray, ListDtype, - make_data, ) +from pandas.tests.extension import base @pytest.fixture def dtype(): - return ListDtype() + return ListDtype(pa.large_string()) @pytest.fixture def data(): """Length-100 ListArray for semantics test.""" - data = make_data() + # TODO: make better random data + data = [list("a"), list("ab"), list("abc")] * 33 + [list("a")] + return ListArray(data) - while len(data[0]) == len(data[1]): - data = make_data() - return ListArray(data) +@pytest.fixture +def data_missing(dtype): + """Length 2 array with [NA, Valid]""" + arr = dtype.construct_array_type()._from_sequence([pd.NA, [1, 2, 3]], dtype=dtype) + return arr + + +@pytest.fixture +def data_for_sorting(data_for_grouping): + """ + Length-3 array with a known sort order. + + This should be three items [B, C, A] with + A < B < C + """ + pytest.skip("ListArray does not support sorting") + + +@pytest.fixture +def data_missing_for_sorting(data_for_grouping): + """ + Length-3 array with a known sort order. + + This should be three items [B, NA, A] with + A < B and NA missing. + """ + pytest.skip("ListArray does not support sorting") + + +@pytest.fixture +def data_for_grouping(dtype): + A = ["a"] + B = ["a", "b"] + NA = None + C = ["a", "b", "c"] + return ListArray([B, B, NA, NA, A, A, B, C]) + + +class TestListArray(base.ExtensionTests): + def test_fillna_no_op_returns_copy(self, data): + # TODO(wayd): This test is copied from test_arrow.py + # It appears the TestArrowArray class has different expectations around + # when copies should be made then the base.ExtensionTests + # Assuming intentional, maybe in the long term this should just + # inherit from TestArrowArray + data = data[~data.isna()] + + valid = data[0] + result = data.fillna(valid) + assert result is not data + tm.assert_extension_array_equal(result, data) + + def test_kind(self, dtype): + assert dtype.kind == "+L" + + @pytest.mark.parametrize("as_index", [True, False]) + def test_groupby_extension_agg(self, as_index, data_for_grouping): + pytest.skip(reason="ListArray does not implement mean") + + def test_groupby_extension_no_sort(self, data_for_grouping): + pytest.skip(reason="ListArray does not implement mean") + + def test_groupby_extension_transform(self, data_for_grouping): + pytest.skip(reason="ListArray does not implement dictionary_encode") + + def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op): + pytest.skip(reason="ListArray does not implement dictionary_encode") + + def test_array_interface(self, data): + pytest.skip(reason="ListArrayScalar does not compare to numpy object-dtype") + + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, request): + pytest.skip(reason="ListArray has not implemented parsing from string") + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators): + if all_arithmetic_operators in ("__mod__", "__rmod__"): + pytest.skip("ListArray does not implement __mod__ or __rmod__") + + super().test_arith_series_with_scalar(data, all_arithmetic_operators) + + def test_arith_series_with_array(self, data, all_arithmetic_operators, request): + if all_arithmetic_operators in ("__mod__", "__rmod__"): + pytest.skip("ListArray does not implement __mod__ or __rmod__") + + super().test_arith_series_with_array(data, all_arithmetic_operators) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): + if all_arithmetic_operators in ("__mod__", "__rmod__"): + pytest.skip("ListArray does not implement __mod__ or __rmod__") + + super().test_arith_frame_with_scalar(data, all_arithmetic_operators) + + def test_divmod(self, data): + pytest.skip("ListArray does not implement divmod") + + def test_compare_scalar(self, data, comparison_op): + if comparison_op in (operator.eq, operator.ne): + pytest.skip("Series.combine does not properly handle missing values") + + super().test_compare_scalar(data, comparison_op) + + def test_compare_array(self, data, comparison_op): + pytest.skip("ListArray comparison ops are not implemented") + + def test_invert(self, data): + pytest.skip("ListArray does not implement invert") + + def test_merge_on_extension_array(self, data): + pytest.skip("ListArray cannot be factorized") + + def test_merge_on_extension_array_duplicates(self, data): + pytest.skip("ListArray cannot be factorized") + + @pytest.mark.parametrize( + "index", + [ + # Two levels, uniform. + pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"]), + # non-uniform + pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "b")]), + # three levels, non-uniform + pd.MultiIndex.from_product([("A", "B"), ("a", "b", "c"), (0, 1, 2)]), + pd.MultiIndex.from_tuples( + [ + ("A", "a", 1), + ("A", "b", 0), + ("A", "a", 0), + ("B", "a", 0), + ("B", "c", 1), + ] + ), + ], + ) + @pytest.mark.parametrize("obj", ["series", "frame"]) + def test_unstack(self, data, index, obj): + # TODO: the base class test casts everything to object + # If you remove the object casts, these tests pass... + # Check if still needed in base class + data = data[: len(index)] + if obj == "series": + ser = pd.Series(data, index=index) + else: + ser = pd.DataFrame({"A": data, "B": data}, index=index) + + n = index.nlevels + levels = list(range(n)) + # [0, 1, 2] + # [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + combinations = itertools.chain.from_iterable( + itertools.permutations(levels, i) for i in range(1, n) + ) + + for level in combinations: + result = ser.unstack(level=level) + assert all( + isinstance(result[col].array, type(data)) for col in result.columns + ) + + if obj == "series": + # We should get the same result with to_frame+unstack+droplevel + df = ser.to_frame() + + alt = df.unstack(level=level).droplevel(0, axis=1) + tm.assert_frame_equal(result, alt) + + # obj_ser = ser.astype(object) + + expected = ser.unstack(level=level, fill_value=data.dtype.na_value) + # if obj == "series": + # assert (expected.dtypes == object).all() + + # result = result.astype(object) + tm.assert_frame_equal(result, expected) + + def test_getitem_ellipsis_and_slice(self, data): + pytest.skip("ListArray does not support NumPy style ellipsis slicing nor 2-D") + + def test_hash_pandas_object(self, data): + pytest.skip("ListArray does not support this") + + @pytest.mark.parametrize("dropna", [True, False]) + def test_value_counts(self, all_data, dropna): + pytest.skip("ListArray does not support this") + + def test_value_counts_with_normalize(self, data): + pytest.skip("ListArray does not support this") + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data_missing, na_action): + pytest.skip("ListArray does not support this") + + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_duplicated(self, data, keep): + pytest.skip("ListArray does not support this") + + @pytest.mark.parametrize("box", [pd.Series, lambda x: x]) + @pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique]) + def test_unique(self, data, box, method): + pytest.skip("ListArray does not support this") + + def test_factorize(self, data_for_grouping): + pytest.skip("ListArray does not support this") + + def test_factorize_equivalence(self, data_for_grouping): + pytest.skip("ListArray does not support this") + + def test_factorize_empty(self, data): + pytest.skip("ListArray does not support this") + + def test_fillna_limit_frame(self, data_missing): + pytest.skip("Needs review - can assignment be avoided?") + + def test_fillna_limit_series(self, data_missing): + pytest.skip("Needs review - can assignment be avoided?") + + def test_fillna_copy_frame(self, data_missing): + pytest.skip("Needs review - can assignment be avoided?") + + def test_fillna_copy_series(self, data_missing): + pytest.skip("Needs review - can assignment be avoided?") + + def test_combine_le(self, data_repeated): + pytest.skip("Needs review - can assignment be avoided?") + + def test_combine_first(self, data): + pytest.skip("Needs review - can assignment be avoided?") + + def test_shift_0_periods(self, data): + pytest.skip("Needs review - can assignment be avoided?") + + def test_hash_pandas_object_works(self, data, as_frame): + pytest.skip("ListArray does not support this") + + def test_where_series(self, data, na_value, as_frame): + pytest.skip("Needs review - can assignment be avoided?") + + def test_argsort(self, data_for_sorting): + pytest.skip("ListArray does not support this") + + def test_argsort_missing_array(self, data_missing_for_sorting): + pytest.skip("ListArray does not support this") + + def test_argsort_missing(self, data_missing_for_sorting): + pytest.skip("ListArray does not support this") + + def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value): + pytest.skip("ListArray does not support this") def test_to_csv(data): diff --git a/pandas/tests/series/accessors/test_list_accessor.py b/pandas/tests/series/accessors/test_list_accessor.py index bec8ca13a2f5f..909af8ee7c1d9 100644 --- a/pandas/tests/series/accessors/test_list_accessor.py +++ b/pandas/tests/series/accessors/test_list_accessor.py @@ -4,6 +4,7 @@ from pandas import ( ArrowDtype, + ListDtype, Series, ) import pandas._testing as tm @@ -16,15 +17,16 @@ @pytest.mark.parametrize( "list_dtype", ( - pa.list_(pa.int64()), - pa.list_(pa.int64(), list_size=3), - pa.large_list(pa.int64()), + ArrowDtype(pa.list_(pa.int64())), + ArrowDtype(pa.list_(pa.int64(), list_size=3)), + ArrowDtype(pa.large_list(pa.int64())), + ListDtype(pa.int64()), ), ) def test_list_getitem(list_dtype): ser = Series( [[1, 2, 3], [4, None, 5], None], - dtype=ArrowDtype(list_dtype), + dtype=list_dtype, name="a", ) actual = ser.list[1] @@ -36,7 +38,7 @@ def test_list_getitem_index(): # GH 58425 ser = Series( [[1, 2, 3], [4, None, 5], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), index=[1, 3, 7], name="a", ) @@ -53,7 +55,7 @@ def test_list_getitem_index(): def test_list_getitem_slice(): ser = Series( [[1, 2, 3], [4, None, 5], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), index=[1, 3, 7], name="a", ) @@ -66,7 +68,7 @@ def test_list_getitem_slice(): actual = ser.list[1:None:None] expected = Series( [[2, 3], [None, 5], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), index=[1, 3, 7], name="a", ) @@ -76,18 +78,18 @@ def test_list_getitem_slice(): def test_list_len(): ser = Series( [[1, 2, 3], [4, None], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), name="a", ) actual = ser.list.len() - expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()), name="a") + expected = Series([3, 2, None], dtype=ArrowDtype(pa.int64()), name="a") tm.assert_series_equal(actual, expected) def test_list_flatten(): ser = Series( [[1, 2, 3], None, [4, None], [], [7, 8]], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), name="a", ) actual = ser.list.flatten() @@ -103,7 +105,7 @@ def test_list_flatten(): def test_list_getitem_slice_invalid(): ser = Series( [[1, 2, 3], [4, None, 5], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), ) if pa_version_under11p0: with pytest.raises( @@ -133,15 +135,16 @@ def test_list_accessor_non_list_dtype(): @pytest.mark.parametrize( "list_dtype", ( - pa.list_(pa.int64()), - pa.list_(pa.int64(), list_size=3), - pa.large_list(pa.int64()), + ArrowDtype(pa.list_(pa.int64())), + ArrowDtype(pa.list_(pa.int64(), list_size=3)), + ArrowDtype(pa.large_list(pa.int64())), + ListDtype(pa.int64()), ), ) def test_list_getitem_invalid_index(list_dtype): ser = Series( [[1, 2, 3], [4, None, 5], None], - dtype=ArrowDtype(list_dtype), + dtype=list_dtype, ) with pytest.raises(pa.lib.ArrowInvalid, match="Index -1 is out of bounds"): ser.list[-1] @@ -154,7 +157,7 @@ def test_list_getitem_invalid_index(list_dtype): def test_list_accessor_not_iterable(): ser = Series( [[1, 2, 3], [4, None], None], - dtype=ArrowDtype(pa.list_(pa.int64())), + dtype=ListDtype(pa.int64()), ) with pytest.raises(TypeError, match="'ListAccessor' object is not iterable"): iter(ser.list)