Skip to content

Commit 2b0279e

Browse files
MarcoGorelliDr-Irv
andauthored
Introduce UnknownSeries and UnknownIndex, type core.strings.pyi using them (#1146)
* make typing in pandas_stubs.core.strings.pyi strict, add UnknownSeries and UnknownIndex * undo pyproject.toml changes * use class, use pyright: strict * update pyright * reduce diff * fixup * fixup * include UnknownSeries in str.cat * move UnknownSeries and UnknownIndex location * use typealias * use Series[str] as .cat return type * use -> T so it matches other .str methods like .str.uppercase * use _TS2 for findall * add test to cover passing UnknownSeries to cat * preserve type in series.str * simplify * use Mapping instead of dict as it is invariant * fixup * split out into separate file * split out into separate file * type check boolean return values * integer return type * integer return type * strings and bytes * list * expanding * fixup * keep fixing * keep fixing * overloads cat * fixup str.extract * rename for clarity * lint * annotate idx2 as per mypys request * return _T_STR, except for `slice` because that one preserves the input types * mypy fixup * disallow .str on certain series types * Revert "disallow .str on certain series types" This reverts commit b2d4657. * use Index of list[str] --------- Co-authored-by: Irv Lustig <[email protected]>
1 parent a0313da commit 2b0279e

File tree

5 files changed

+540
-222
lines changed

5 files changed

+540
-222
lines changed

pandas-stubs/core/indexes/base.pyi

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ from typing import (
1313
Any,
1414
ClassVar,
1515
Literal,
16+
TypeAlias,
1617
final,
1718
overload,
1819
)
@@ -263,7 +264,16 @@ class Index(IndexOpsMixin[S1]):
263264
@property
264265
def str(
265266
self,
266-
) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ...
267+
) -> StringMethods[
268+
Self,
269+
MultiIndex,
270+
np_ndarray_bool,
271+
Index[list[str]],
272+
Index[int],
273+
Index[bytes],
274+
Index[str],
275+
Index[type[object]],
276+
]: ...
267277
def is_(self, other) -> bool: ...
268278
def __len__(self) -> int: ...
269279
def __array__(self, dtype=...) -> np.ndarray: ...
@@ -455,6 +465,8 @@ class Index(IndexOpsMixin[S1]):
455465
),
456466
) -> Self: ...
457467

468+
UnknownIndex: TypeAlias = Index[Any]
469+
458470
def ensure_index_from_sequences(
459471
sequences: Sequence[Sequence[Dtype]], names: list[str] = ...
460472
) -> Index: ...

pandas-stubs/core/series.pyi

+12-1
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,16 @@ class Series(IndexOpsMixin[S1], NDFrame):
11791179
@property
11801180
def str(
11811181
self,
1182-
) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ...
1182+
) -> StringMethods[
1183+
Self,
1184+
DataFrame,
1185+
Series[bool],
1186+
Series[list[str]],
1187+
Series[int],
1188+
Series[bytes],
1189+
Series[str],
1190+
Series[type[object]],
1191+
]: ...
11831192
@property
11841193
def dt(self) -> CombinedDatetimelikeProperties: ...
11851194
@property
@@ -2318,3 +2327,5 @@ class IntervalSeries(Series[Interval[_OrderableT]], Generic[_OrderableT]):
23182327
@property
23192328
def array(self) -> IntervalArray: ...
23202329
def diff(self, periods: int = ...) -> Never: ...
2330+
2331+
UnknownSeries: TypeAlias = Series[Any]

pandas-stubs/core/strings.pyi

+99-78
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: strict
12
from collections.abc import (
23
Callable,
34
Sequence,
@@ -12,6 +13,7 @@ from typing import (
1213
)
1314

1415
import numpy as np
16+
import numpy.typing as npt
1517
import pandas as pd
1618
from pandas import (
1719
DataFrame,
@@ -21,23 +23,36 @@ from pandas import (
2123
)
2224
from pandas.core.base import NoNewAttributesMixin
2325

26+
from pandas._libs.tslibs.nattype import NaTType
2427
from pandas._typing import (
2528
JoinHow,
29+
Scalar,
2630
T,
2731
np_ndarray_bool,
2832
)
2933

30-
# The _TS type is what is used for the result of str.split with expand=True
31-
_TS = TypeVar("_TS", bound=DataFrame | MultiIndex)
32-
# The _TS2 type is what is used for the result of str.split with expand=False
33-
_TS2 = TypeVar("_TS2", bound=Series[list[str]] | Index[list[str]])
34-
# The _TM type is what is used for the result of str.match
35-
_TM = TypeVar("_TM", bound=Series[bool] | np_ndarray_bool)
34+
# Used for the result of str.split with expand=True
35+
_T_EXPANDING = TypeVar("_T_EXPANDING", bound=DataFrame | MultiIndex)
36+
# Used for the result of str.split with expand=False
37+
_T_LIST_STR = TypeVar("_T_LIST_STR", bound=Series[list[str]] | Index[list[str]])
38+
# Used for the result of str.match
39+
_T_BOOL = TypeVar("_T_BOOL", bound=Series[bool] | np_ndarray_bool)
40+
# Used for the result of str.index / str.find
41+
_T_INT = TypeVar("_T_INT", bound=Series[int] | Index[int])
42+
# Used for the result of str.encode
43+
_T_BYTES = TypeVar("_T_BYTES", bound=Series[bytes] | Index[bytes])
44+
# Used for the result of str.decode
45+
_T_STR = TypeVar("_T_STR", bound=Series[str] | Index[str])
46+
# Used for the result of str.partition
47+
_T_OBJECT = TypeVar("_T_OBJECT", bound=Series[type[object]] | Index[type[object]])
3648

37-
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
49+
class StringMethods(
50+
NoNewAttributesMixin,
51+
Generic[T, _T_EXPANDING, _T_BOOL, _T_LIST_STR, _T_INT, _T_BYTES, _T_STR, _T_OBJECT],
52+
):
3853
def __init__(self, data: T) -> None: ...
39-
def __getitem__(self, key: slice | int) -> T: ...
40-
def __iter__(self) -> T: ...
54+
def __getitem__(self, key: slice | int) -> _T_STR: ...
55+
def __iter__(self) -> _T_STR: ...
4156
@overload
4257
def cat(
4358
self,
@@ -58,15 +73,17 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
5873
@overload
5974
def cat(
6075
self,
61-
others: Series | pd.Index | pd.DataFrame | np.ndarray | list[Any],
76+
others: (
77+
Series[str] | Index[str] | pd.DataFrame | npt.NDArray[np.str_] | list[str]
78+
),
6279
sep: str = ...,
6380
na_rep: str | None = ...,
6481
join: JoinHow = ...,
65-
) -> T: ...
82+
) -> _T_STR: ...
6683
@overload
6784
def split(
6885
self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ...
69-
) -> _TS: ...
86+
) -> _T_EXPANDING: ...
7087
@overload
7188
def split(
7289
self,
@@ -75,77 +92,79 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
7592
n: int = ...,
7693
expand: Literal[False] = ...,
7794
regex: bool = ...,
78-
) -> _TS2: ...
95+
) -> _T_LIST_STR: ...
7996
@overload
80-
def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ...
97+
def rsplit(
98+
self, pat: str = ..., *, n: int = ..., expand: Literal[True]
99+
) -> _T_EXPANDING: ...
81100
@overload
82101
def rsplit(
83102
self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ...
84-
) -> _TS2: ...
103+
) -> _T_LIST_STR: ...
85104
@overload
86-
def partition(self, sep: str = ...) -> pd.DataFrame: ...
105+
def partition(self, sep: str = ...) -> _T_EXPANDING: ...
87106
@overload
88-
def partition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
107+
def partition(self, *, expand: Literal[True]) -> _T_EXPANDING: ...
89108
@overload
90-
def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
109+
def partition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ...
91110
@overload
92-
def partition(self, sep: str, expand: Literal[False]) -> T: ...
111+
def partition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ...
93112
@overload
94-
def partition(self, *, expand: Literal[False]) -> T: ...
113+
def partition(self, *, expand: Literal[False]) -> _T_OBJECT: ...
95114
@overload
96-
def rpartition(self, sep: str = ...) -> pd.DataFrame: ...
115+
def rpartition(self, sep: str = ...) -> _T_EXPANDING: ...
97116
@overload
98-
def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
117+
def rpartition(self, *, expand: Literal[True]) -> _T_EXPANDING: ...
99118
@overload
100-
def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
119+
def rpartition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ...
101120
@overload
102-
def rpartition(self, sep: str, expand: Literal[False]) -> T: ...
121+
def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ...
103122
@overload
104-
def rpartition(self, *, expand: Literal[False]) -> T: ...
105-
def get(self, i: int) -> T: ...
106-
def join(self, sep: str) -> T: ...
123+
def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ...
124+
def get(self, i: int) -> _T_STR: ...
125+
def join(self, sep: str) -> _T_STR: ...
107126
def contains(
108127
self,
109-
pat: str | re.Pattern,
128+
pat: str | re.Pattern[str],
110129
case: bool = ...,
111130
flags: int = ...,
112-
na=...,
131+
na: Scalar | NaTType | None = ...,
113132
regex: bool = ...,
114-
) -> Series[bool]: ...
133+
) -> _T_BOOL: ...
115134
def match(
116135
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
117-
) -> _TM: ...
136+
) -> _T_BOOL: ...
118137
def replace(
119138
self,
120139
pat: str,
121-
repl: str | Callable[[re.Match], str],
140+
repl: str | Callable[[re.Match[str]], str],
122141
n: int = ...,
123142
case: bool | None = ...,
124143
flags: int = ...,
125144
regex: bool = ...,
126-
) -> T: ...
127-
def repeat(self, repeats: int | Sequence[int]) -> T: ...
145+
) -> _T_STR: ...
146+
def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ...
128147
def pad(
129148
self,
130149
width: int,
131150
side: Literal["left", "right", "both"] = ...,
132151
fillchar: str = ...,
133-
) -> T: ...
134-
def center(self, width: int, fillchar: str = ...) -> T: ...
135-
def ljust(self, width: int, fillchar: str = ...) -> T: ...
136-
def rjust(self, width: int, fillchar: str = ...) -> T: ...
137-
def zfill(self, width: int) -> T: ...
152+
) -> _T_STR: ...
153+
def center(self, width: int, fillchar: str = ...) -> _T_STR: ...
154+
def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ...
155+
def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ...
156+
def zfill(self, width: int) -> _T_STR: ...
138157
def slice(
139158
self, start: int | None = ..., stop: int | None = ..., step: int | None = ...
140159
) -> T: ...
141160
def slice_replace(
142161
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
143-
) -> T: ...
144-
def decode(self, encoding: str, errors: str = ...) -> T: ...
145-
def encode(self, encoding: str, errors: str = ...) -> T: ...
146-
def strip(self, to_strip: str | None = ...) -> T: ...
147-
def lstrip(self, to_strip: str | None = ...) -> T: ...
148-
def rstrip(self, to_strip: str | None = ...) -> T: ...
162+
) -> _T_STR: ...
163+
def decode(self, encoding: str, errors: str = ...) -> _T_STR: ...
164+
def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ...
165+
def strip(self, to_strip: str | None = ...) -> _T_STR: ...
166+
def lstrip(self, to_strip: str | None = ...) -> _T_STR: ...
167+
def rstrip(self, to_strip: str | None = ...) -> _T_STR: ...
149168
def wrap(
150169
self,
151170
width: int,
@@ -154,45 +173,47 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
154173
drop_whitespace: bool | None = ...,
155174
break_long_words: bool | None = ...,
156175
break_on_hyphens: bool | None = ...,
157-
) -> T: ...
158-
def get_dummies(self, sep: str = ...) -> pd.DataFrame: ...
159-
def translate(self, table: dict[int, int | str | None] | None) -> T: ...
160-
def count(self, pat: str, flags: int = ...) -> Series[int]: ...
161-
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ...
162-
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ...
163-
def findall(self, pat: str, flags: int = ...) -> Series: ...
176+
) -> _T_STR: ...
177+
def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ...
178+
def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ...
179+
def count(self, pat: str, flags: int = ...) -> _T_INT: ...
180+
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
181+
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
182+
def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ...
164183
@overload
165184
def extract(
166185
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
167186
) -> pd.DataFrame: ...
168187
@overload
169-
def extract(self, pat: str, flags: int, expand: Literal[False]) -> T: ...
188+
def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ...
170189
@overload
171-
def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> T: ...
190+
def extract(
191+
self, pat: str, flags: int = ..., *, expand: Literal[False]
192+
) -> _T_OBJECT: ...
172193
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
173-
def find(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
174-
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
175-
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...
176-
def index(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
177-
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
178-
def len(self) -> Series[int]: ...
179-
def lower(self) -> T: ...
180-
def upper(self) -> T: ...
181-
def title(self) -> T: ...
182-
def capitalize(self) -> T: ...
183-
def swapcase(self) -> T: ...
184-
def casefold(self) -> T: ...
185-
def isalnum(self) -> Series[bool]: ...
186-
def isalpha(self) -> Series[bool]: ...
187-
def isdigit(self) -> Series[bool]: ...
188-
def isspace(self) -> Series[bool]: ...
189-
def islower(self) -> Series[bool]: ...
190-
def isupper(self) -> Series[bool]: ...
191-
def istitle(self) -> Series[bool]: ...
192-
def isnumeric(self) -> Series[bool]: ...
193-
def isdecimal(self) -> Series[bool]: ...
194+
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
195+
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
196+
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ...
197+
def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
198+
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
199+
def len(self) -> _T_INT: ...
200+
def lower(self) -> _T_STR: ...
201+
def upper(self) -> _T_STR: ...
202+
def title(self) -> _T_STR: ...
203+
def capitalize(self) -> _T_STR: ...
204+
def swapcase(self) -> _T_STR: ...
205+
def casefold(self) -> _T_STR: ...
206+
def isalnum(self) -> _T_BOOL: ...
207+
def isalpha(self) -> _T_BOOL: ...
208+
def isdigit(self) -> _T_BOOL: ...
209+
def isspace(self) -> _T_BOOL: ...
210+
def islower(self) -> _T_BOOL: ...
211+
def isupper(self) -> _T_BOOL: ...
212+
def istitle(self) -> _T_BOOL: ...
213+
def isnumeric(self) -> _T_BOOL: ...
214+
def isdecimal(self) -> _T_BOOL: ...
194215
def fullmatch(
195216
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
196-
) -> Series[bool]: ...
197-
def removeprefix(self, prefix: str) -> T: ...
198-
def removesuffix(self, suffix: str) -> T: ...
217+
) -> _T_BOOL: ...
218+
def removeprefix(self, prefix: str) -> _T_STR: ...
219+
def removesuffix(self, suffix: str) -> _T_STR: ...

0 commit comments

Comments
 (0)