Skip to content

Commit f61c367

Browse files
committed
typ: resolve typing for pandas-stubs>2.3.3
1 parent 82f3550 commit f61c367

File tree

9 files changed

+27
-23
lines changed

9 files changed

+27
-23
lines changed

plotnine/_utils/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
FloatArray,
3636
FloatArrayLike,
3737
HorizontalJustification,
38-
IntArray,
3938
Side,
4039
VerticalJustification,
4140
)
@@ -309,7 +308,7 @@ def ninteraction(df: pd.DataFrame, drop: bool = False) -> list[int]:
309308
def len_unique(x):
310309
return len(np.unique(x))
311310

312-
ndistinct: IntArray = ids.apply(len_unique, axis=0).to_numpy()
311+
ndistinct = ids.apply(len_unique, axis=0).to_numpy()
313312

314313
combs = np.array(np.hstack([1, np.cumprod(ndistinct[:-1])]))
315314
mat = np.array(ids)
@@ -743,7 +742,7 @@ def ungroup(data: DataLike) -> DataLike:
743742
"""Return an ungrouped DataFrame, or pass the original data back."""
744743

745744
if isinstance(data, DataFrameGroupBy):
746-
return data.obj
745+
return data.obj # pyright: ignore[reportReturnType]
747746

748747
return data
749748

plotnine/facets/facet_grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame:
189189

190190
n = len(base)
191191
panel = ninteraction(base, drop=True)
192-
panel = pd.Categorical(panel, categories=range(1, n + 1))
192+
panel = pd.Categorical(panel, categories=range(1, n + 1)) # pyright: ignore[reportArgumentType]
193193

194194
if self.rows:
195195
rows = ninteraction(base[self.rows], drop=True)

plotnine/facets/facet_wrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def compute_layout(
115115

116116
layout = pd.DataFrame(
117117
{
118-
"PANEL": pd.Categorical(range(1, n + 1)),
118+
"PANEL": pd.Categorical(range(1, n + 1)), # pyright: ignore[reportArgumentType]
119119
"ROW": row.astype(int),
120120
"COL": col.astype(int),
121121
}

plotnine/geoms/geom_map.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,18 +126,18 @@ def draw_panel(
126126
params = self.params
127127
data.loc[data["color"].isna(), "color"] = "none"
128128
data.loc[data["fill"].isna(), "fill"] = "none"
129-
data["fill"] = to_rgba(data["fill"], data["alpha"])
130129

131130
geom_type = data.geometry.iloc[0].geom_type
132131
if geom_type in ("Polygon", "MultiPolygon"):
133132
from matplotlib.collections import PatchCollection
134133

135134
linewidth = data["size"] * SIZE_FACTOR
135+
fill = to_rgba(data["fill"], data["alpha"])
136136
patches = [PolygonPatch(g) for g in data["geometry"]]
137137
coll = PatchCollection(
138138
patches,
139139
edgecolor=data["color"],
140-
facecolor=data["fill"],
140+
facecolor=fill,
141141
linestyle=data["linetype"],
142142
linewidth=linewidth,
143143
zorder=params["zorder"],
@@ -152,7 +152,6 @@ def draw_panel(
152152
data["y"] = arr[:, 1]
153153
for _, gdata in data.groupby("group"):
154154
gdata.reset_index(inplace=True, drop=True)
155-
gdata.is_copy = None
156155
geom_point.draw_group(gdata, panel_params, coord, ax, params)
157156
elif geom_type == "MultiPoint":
158157
# Where n is the length of the dataframe (no. of multipoints),
@@ -173,7 +172,7 @@ def draw_panel(
173172
from matplotlib.collections import LineCollection
174173

175174
linewidth = data["size"] * SIZE_FACTOR
176-
data["color"] = to_rgba(data["color"], data["alpha"])
175+
color = to_rgba(data["color"], data["alpha"])
177176
segments = []
178177
for g in data["geometry"]:
179178
if g.geom_type == "LineString":
@@ -183,7 +182,7 @@ def draw_panel(
183182

184183
coll = LineCollection(
185184
segments,
186-
edgecolor=data["color"],
185+
edgecolor=color,
187186
linewidth=linewidth,
188187
linestyle=data["linetype"],
189188
zorder=params["zorder"],

plotnine/geoms/geom_path.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def draw(
302302
last = self.ends in ("last", "both")
303303

304304
data = data.sort_values("group", kind="mergesort")
305-
data["color"] = to_rgba(data["color"], data["alpha"])
305+
data["color"] = to_rgba(data["color"], data["alpha"]) # pyright: ignore[reportCallIssue,reportArgumentType]
306306

307307
if self.type == "open":
308308
data["facecolor"] = "none"

plotnine/geoms/geom_text.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
import typing
43
from contextlib import suppress
4+
from typing import TYPE_CHECKING, cast
55
from warnings import warn
66

77
import numpy as np
@@ -12,7 +12,7 @@
1212
from ..positions import position_nudge
1313
from .geom import geom
1414

15-
if typing.TYPE_CHECKING:
15+
if TYPE_CHECKING:
1616
from typing import Any, Sequence
1717

1818
import pandas as pd
@@ -244,7 +244,7 @@ def draw_group(
244244
axis=1,
245245
inplace=True,
246246
)
247-
plot_data["color"] = color
247+
plot_data["color"] = color # pyright: ignore[reportCallIssue,reportArgumentType]
248248
plot_data["zorder"] = zorder
249249
plot_data["rasterized"] = params["raster"]
250250
plot_data["clip_on"] = True
@@ -255,7 +255,7 @@ def draw_group(
255255
fill = to_rgba(data.pop("fill"), data["alpha"])
256256
if isinstance(fill, tuple):
257257
fill = [list(fill)] * len(data["x"])
258-
plot_data["facecolor"] = fill
258+
plot_data["facecolor"] = fill # pyright: ignore[reportCallIssue,reportArgumentType]
259259

260260
tokens = [params["boxstyle"], f"pad={params['label_padding']}"]
261261
if params["boxstyle"] in {"round", "round4"}:
@@ -272,7 +272,7 @@ def draw_group(
272272

273273
# For labels add a bbox
274274
for i in range(len(data)):
275-
kw: dict[str, Any] = plot_data.iloc[i].to_dict()
275+
kw = cast("dict[str, Any]", plot_data.iloc[i].to_dict())
276276
if draw_label:
277277
kw["bbox"] = bbox
278278
kw["bbox"]["edgecolor"] = params["boxcolor"] or kw["color"]

plotnine/stats/smoothers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
if TYPE_CHECKING:
1414
import statsmodels.api as sm
1515

16+
from plotnine.typing import FloatArray
17+
1618

1719
def predictdf(data, xseq, params) -> pd.DataFrame:
1820
"""
@@ -454,12 +456,14 @@ def gpr(data, xseq, params):
454456
if params["se"]:
455457
y, stderr = regressor.predict(Xseq, return_std=True)
456458
data["y"] = y
457-
data["se"] = stderr
459+
data["se"] = cast("FloatArray", stderr)
458460
data["ymin"], data["ymax"] = tdist_ci(
459461
y, n - 1, stderr, params["level"]
460462
)
461463
else:
462-
data["y"] = regressor.predict(Xseq, return_std=True)
464+
data["y"] = cast(
465+
"FloatArray", regressor.predict(Xseq, return_std=False)
466+
)
463467

464468
return data
465469

plotnine/stats/stat_qq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from .stat import stat
1212

1313
if TYPE_CHECKING:
14-
from typing import Any, Sequence
14+
from typing import Any
1515

16-
from plotnine.typing import FloatArray
16+
from plotnine.typing import FloatArray, FloatArrayLike
1717

1818

1919
# Note: distribution should be a name from scipy.stat.distribution
@@ -93,7 +93,7 @@ def theoretical_qq(
9393
distribution: str,
9494
alpha: float,
9595
beta: float,
96-
quantiles: Sequence[float] | None,
96+
quantiles: FloatArrayLike | None,
9797
distribution_params: dict[str, Any],
9898
) -> FloatArray:
9999
"""

plotnine/stats/stat_summary.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import cast
2+
13
import numpy as np
24
import pandas as pd
35

@@ -314,8 +316,8 @@ def compute_panel(self, data, scales):
314316
summaries = []
315317
for (group, x), df in data.groupby(["group", "x"]):
316318
summary = func(df)
317-
summary["x"] = x
318-
summary["group"] = group
319+
summary["x"] = x # pyright: ignore[reportCallIssue,reportArgumentType]
320+
summary["group"] = cast("int", group)
319321
summary["n"] = len(df)
320322
unique = uniquecols(df)
321323
if "y" in unique:

0 commit comments

Comments
 (0)