Skip to content

Commit 71b2555

Browse files
df.apply should allow pd.NA from Callables (#961)
1 parent e78aaca commit 71b2555

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pandas-stubs/core/frame.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ class DataFrame(NDFrame, OpsMixin):
12241224
@overload
12251225
def apply(
12261226
self,
1227-
f: Callable[..., S1],
1227+
f: Callable[..., S1 | NAType],
12281228
axis: AxisIndex = ...,
12291229
raw: _bool = ...,
12301230
result_type: None = ...,
@@ -1248,7 +1248,7 @@ class DataFrame(NDFrame, OpsMixin):
12481248
@overload
12491249
def apply(
12501250
self,
1251-
f: Callable[..., S1],
1251+
f: Callable[..., S1 | NAType],
12521252
axis: Axis = ...,
12531253
raw: _bool = ...,
12541254
args: Any = ...,
@@ -1309,7 +1309,7 @@ class DataFrame(NDFrame, OpsMixin):
13091309
@overload
13101310
def apply(
13111311
self,
1312-
f: Callable[..., S1],
1312+
f: Callable[..., S1 | NAType],
13131313
raw: _bool = ...,
13141314
result_type: None = ...,
13151315
args: Any = ...,

tests/test_frame.py

+9
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
import xarray as xr
4545

46+
from pandas._libs.missing import NAType
4647
from pandas._typing import Scalar
4748

4849
from tests import (
@@ -578,6 +579,9 @@ def test_types_apply() -> None:
578579
def returns_scalar(x: pd.Series) -> int:
579580
return 2
580581

582+
def returns_scalar_na(x: pd.Series) -> int | NAType:
583+
return 2 if (x < 5).all() else pd.NA
584+
581585
def returns_series(x: pd.Series) -> pd.Series:
582586
return x**2
583587

@@ -604,6 +608,11 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
604608
check(
605609
assert_type(df.apply(returns_scalar), "pd.Series[int]"), pd.Series, np.integer
606610
)
611+
check(
612+
assert_type(df.apply(returns_scalar_na), "pd.Series[int]"),
613+
pd.Series,
614+
int,
615+
)
607616
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
608617
check(assert_type(df.apply(returns_listlike_of_3), pd.DataFrame), pd.DataFrame)
609618
check(assert_type(df.apply(returns_dict), pd.Series), pd.Series)

0 commit comments

Comments
 (0)