Skip to content

Commit b103908

Browse files
flying-sheep, ilan-gold, and pre-commit-ci[bot]
authored
fix: add compat with pandas 3 (#3929)
Co-authored-by: Ilan Gold <ilanbassgold@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 28a1ed4 commit b103908

19 files changed

Lines changed: 196 additions & 104 deletions

File tree

.github/workflows/ci.yml

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,8 +60,13 @@ jobs:
6060

6161
- name: Install dependencies
6262
run: |
63-
uv tool install --with='click!=8.3.0' hatch
63+
echo "::group::Install hatch"
64+
uv tool install hatch
65+
echo "::endgroup::"
66+
echo "::group::Create environment"
6467
hatch -v env create ${{ matrix.env.name }}
68+
echo "::endgroup::"
69+
hatch run ${{ matrix.env.name }}:session-info scanpy anndata
6570
6671
- name: Run tests
6772
if: matrix.env.test-type == null

docs/release-notes/3929.fix.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Fix compatibility with pandas 3.0 {smaller}`P Angerer`

hatch.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ overrides.matrix.deps.python = [
3535
]
3636
overrides.matrix.deps.extra-dependencies = [
3737
{ if = [ "pre" ], value = "anndata @ git+https://github.com/scverse/anndata.git" },
38+
{ if = [ "pre" ], value = "pandas>=3rc0" },
3839
]
3940
overrides.matrix.deps.dependency-groups = [
4041
{ if = [ "stable", "pre", "low-vers" ], value = "test" },

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ dependencies = [
5454
"numpy>=2",
5555
"fast-array-utils[accel,sparse]>=1.2.1",
5656
"matplotlib>=3.9",
57-
"pandas >=2.2.2, <3.0.0rc0",
57+
"pandas >=2.2.2",
5858
"scipy>=1.13",
5959
"seaborn>=0.13.2",
6060
"h5py>=3.11",

src/scanpy/_utils/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929

3030
import h5py
3131
import numpy as np
32+
import pandas as pd
3233
from anndata._core.sparse_dataset import BaseCompressedSparseDataset
3334
from packaging.version import Version
3435

@@ -44,6 +45,7 @@
4445
from anndata import AnnData
4546
from igraph import Graph
4647
from numpy.typing import ArrayLike, NDArray
48+
from pandas._typing import Dtype as PdDtype
4749

4850
from .._compat import CSRBase
4951
from ..neighbors import NeighborsParams, RPForestDict
@@ -79,6 +81,7 @@
7981
"sanitize_anndata",
8082
"select_groups",
8183
"update_params",
84+
"with_cat_dtype",
8285
]
8386

8487

@@ -287,7 +290,7 @@ def get_igraph_from_adjacency(adjacency: CSBase, *, directed: bool = False) -> G
287290
import igraph as ig
288291

289292
sources, targets = adjacency.nonzero()
290-
weights = dematrix(adjacency[sources, targets]).ravel()
293+
weights = dematrix(adjacency[sources, targets]).ravel() if len(sources) else []
291294
g = ig.Graph(directed=directed)
292295
g.add_vertices(adjacency.shape[0]) # this adds adjacency.shape[0] vertices
293296
g.add_edges(list(zip(sources, targets, strict=True)))
@@ -494,6 +497,23 @@ def moving_average(a: np.ndarray, n: int):
494497
return ret[n - 1 :] / n
495498

496499

500+
@singledispatch
501+
def with_cat_dtype[X: pd.Series | pd.CategoricalIndex | pd.Categorical](
502+
x: X, dtype: PdDtype
503+
) -> X:
504+
raise NotImplementedError
505+
506+
507+
@with_cat_dtype.register(pd.Series)
508+
def _(x: pd.Series, dtype: PdDtype) -> pd.Series:
509+
return x.cat.set_categories(x.cat.categories.astype(dtype))
510+
511+
512+
@with_cat_dtype.register(pd.Categorical | pd.CategoricalIndex)
513+
def _[X: pd.Categorical | pd.CategoricalIndex](x: X, dtype: PdDtype) -> X:
514+
return x.set_categories(x.categories.astype(dtype))
515+
516+
497517
# --------------------------------------------------------------------------------
498518
# Deal with tool parameters
499519
# --------------------------------------------------------------------------------

src/scanpy/external/exporting.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ def spring_project( # noqa: PLR0912, PLR0915
219219
np.save(subplot_dir / "cell_filter.npy", np.arange(x.shape[0]))
220220

221221
# Write 2-D coordinates, after adjusting to roughly match SPRING's default d3js force layout parameters
222-
coords = coords - coords.min(0)[None, :]
222+
coords = coords - coords.min(axis=0)[None, :]
223223
coords = (
224224
coords * (np.array([1000, 1000]) / coords.ptp(0))[None, :]
225225
+ np.array([200, -200])[None, :]
@@ -342,8 +342,8 @@ def _get_color_stats_genes(color_stats, x, gene_list):
342342
means, variances = mean_var(x, axis=0, correction=1)
343343
stdevs = np.zeros(variances.shape, dtype=float)
344344
stdevs[variances > 0] = np.sqrt(variances[variances > 0])
345-
mins = x.min(0).todense().A1
346-
maxes = x.max(0).todense().A1
345+
mins = x.min(axis=0).todense().A1
346+
maxes = x.max(axis=0).todense().A1
347347

348348
pctl = 99.6
349349
pctl_n = (100 - pctl) / 100.0 * x.shape[0]

src/scanpy/get/get.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -259,8 +259,8 @@ def obs_df(
259259
>>> plotdf = sc.get.obs_df(
260260
... pbmc, keys=["CD8B", "n_genes"], obsm_keys=[("X_umap", 0), ("X_umap", 1)]
261261
... )
262-
>>> plotdf.columns
263-
Index(['CD8B', 'n_genes', 'X_umap-0', 'X_umap-1'], dtype='object')
262+
>>> plotdf.columns.astype("string")
263+
Index(['CD8B', 'n_genes', 'X_umap-0', 'X_umap-1'], dtype='string')
264264
>>> plotdf.plot.scatter("X_umap-0", "X_umap-1", c="CD8B") # doctest: +SKIP
265265
<Axes: xlabel='X_umap-0', ylabel='X_umap-1'>
266266

src/scanpy/plotting/_anndata.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -749,7 +749,7 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
749749
layer: str | None = None,
750750
density_norm: DensityNorm = "width",
751751
order: Sequence[str] | None = None,
752-
multi_panel: bool | None = None,
752+
multi_panel: bool = False,
753753
xlabel: str = "",
754754
ylabel: str | Sequence[str] | None = None,
755755
rotation: float | None = None,
@@ -1202,11 +1202,11 @@ def heatmap( # noqa: PLR0912, PLR0913, PLR0915
12021202
).issubset(categories)
12031203

12041204
if standard_scale == "obs":
1205-
obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0)
1206-
obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0)
1205+
obs_tidy = obs_tidy.sub(obs_tidy.min(axis=1), axis=0)
1206+
obs_tidy = obs_tidy.div(obs_tidy.max(axis=1), axis=0).fillna(0)
12071207
elif standard_scale == "var":
1208-
obs_tidy -= obs_tidy.min(0)
1209-
obs_tidy = (obs_tidy / obs_tidy.max(0)).fillna(0)
1208+
obs_tidy -= obs_tidy.min(axis=0)
1209+
obs_tidy = (obs_tidy / obs_tidy.max(axis=0)).fillna(0)
12101210
elif standard_scale is None:
12111211
pass
12121212
else:

src/scanpy/plotting/_dotplot.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -213,11 +213,13 @@ def __init__( # noqa: PLR0913
213213
dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean()
214214

215215
if standard_scale == "group":
216-
dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0)
217-
dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0)
216+
dot_color_df = dot_color_df.sub(dot_color_df.min(axis=1), axis=0)
217+
dot_color_df = dot_color_df.div(
218+
dot_color_df.max(axis=1), axis=0
219+
).fillna(0)
218220
elif standard_scale == "var":
219-
dot_color_df -= dot_color_df.min(0)
220-
dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0)
221+
dot_color_df -= dot_color_df.min(axis=0)
222+
dot_color_df = (dot_color_df / dot_color_df.max(axis=0)).fillna(0)
221223
elif standard_scale is None:
222224
pass
223225
else:
@@ -696,10 +698,10 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915
696698
group_axis = 1
697699
if standard_scale is not None:
698700
dot_color = dot_color.sub(
699-
dot_color.min((group_axis + 1) % 2), axis=group_axis
701+
dot_color.min(axis=1 - group_axis), axis=group_axis
700702
)
701703
dot_color = dot_color.div(
702-
dot_color.max((group_axis + 1) % 2), axis=group_axis
704+
dot_color.max(axis=1 - group_axis), axis=group_axis
703705
).fillna(0)
704706
# make scatter plot in which
705707
# x = var_names

src/scanpy/plotting/_matrixplot.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -180,11 +180,11 @@ def __init__( # noqa: PLR0913
180180
)
181181

182182
if standard_scale == "group":
183-
values_df = values_df.sub(values_df.min(1), axis=0)
184-
values_df = values_df.div(values_df.max(1), axis=0).fillna(0)
183+
values_df = values_df.sub(values_df.min(axis=1), axis=0)
184+
values_df = values_df.div(values_df.max(axis=1), axis=0).fillna(0)
185185
elif standard_scale == "var":
186-
values_df -= values_df.min(0)
187-
values_df = (values_df / values_df.max(0)).fillna(0)
186+
values_df -= values_df.min(axis=0)
187+
values_df = (values_df / values_df.max(axis=0)).fillna(0)
188188
elif standard_scale is None:
189189
pass
190190
else:

0 commit comments

Comments
 (0)