Skip to content

Commit d234850

Browse files
committed
chore: add in accessors for deprecation warnings
1 parent 182d863 commit d234850

5 files changed

Lines changed: 51 additions & 23 deletions

File tree

src/scanpy/plotting/_anndata.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from matplotlib import colormaps, gridspec, patheffects, rcParams
1414
from matplotlib import pyplot as plt
1515
from matplotlib.colors import is_color_like
16+
from packaging.version import Version
1617
from pandas.api.types import CategoricalDtype, is_numeric_dtype
1718

1819
from .. import get
1920
from .. import logging as logg
20-
from .._compat import CSBase, old_positionals
21+
from .._compat import CSBase, old_positionals, pkg_version
2122
from .._settings import settings
2223
from .._utils import (
2324
_doc_params,
@@ -55,13 +56,7 @@
5556
from seaborn.matrix import ClusterGrid
5657

5758
from .._utils import Empty
58-
from ._utils import (
59-
ColorLike,
60-
DensityNorm,
61-
_FontSize,
62-
_FontWeight,
63-
_LegendLoc,
64-
)
59+
from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc
6560

6661
# TODO: is that all?
6762
type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"]
@@ -319,6 +314,9 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
319314
if title is not None and isinstance(title, str):
320315
title = [title]
321316
highlights = adata.uns.get("highlights", [])
317+
is_anndata_13 = pkg_version("anndata") >= Version("0.13.0rc0")
318+
if is_anndata_13:
319+
from anndata.acc import A
322320
if basis is not None:
323321
try:
324322
# ignore the '0th' diffusion component
@@ -334,16 +332,28 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
334332
elif x is not None and y is not None:
335333
if use_raw:
336334
if x in adata.obs.columns:
337-
x_arr = adata.obs_vector(x)
335+
x_arr = adata[A.obs[x]] if is_anndata_13 else adata.obs_vector(x)
338336
else:
339-
x_arr = adata.raw.obs_vector(x)
337+
x_arr = (
338+
adata.raw[A.X[:, x]] if is_anndata_13 else adata.raw.obs_vector(x)
339+
)
340340
if y in adata.obs.columns:
341-
y_arr = adata.obs_vector(y)
341+
y_arr = adata[A.obs[y]] if is_anndata_13 else adata.obs_vector(y)
342342
else:
343-
y_arr = adata.raw.obs_vector(y)
343+
y_arr = (
344+
adata.raw[A.X[:, y]] if is_anndata_13 else adata.raw.obs_vector(y)
345+
)
344346
else:
345-
x_arr = adata.obs_vector(x, layer=layers[0])
346-
y_arr = adata.obs_vector(y, layer=layers[1])
347+
x_arr = (
348+
adata[A.layers[layers[0]][:, x]]
349+
if is_anndata_13
350+
else adata.obs_vector(x, layer=layers[0])
351+
)
352+
y_arr = (
353+
adata[A.layers[layers[1]][:, y]]
354+
if is_anndata_13
355+
else adata.obs_vector(y, layer=layers[1])
356+
)
347357

348358
xy = np.c_[x_arr, y_arr]
349359
else:
@@ -400,9 +410,13 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
400410
c = adata.obs[key].to_numpy()
401411
# coloring according to gene expression
402412
elif use_raw and adata.raw is not None and key in adata.raw.var_names:
403-
c = adata.raw.obs_vector(key)
413+
c = adata[A.obs[key]] if is_anndata_13 else adata.raw.obs_vector(key)
404414
elif key in adata.var_names:
405-
c = adata.obs_vector(key, layer=layers[2])
415+
c = (
416+
adata[A.layers[layers[2][:, key]]]
417+
if is_anndata_13
418+
else adata.obs_vector(key, layer=layers[2])
419+
)
406420
elif is_color_like(key): # a flat color
407421
c = key
408422
colorbar = False

src/scanpy/plotting/_tools/scatterplots.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from matplotlib import pyplot as plt
1717
from matplotlib.colors import Normalize
1818
from matplotlib.markers import MarkerStyle
19+
from packaging.version import Version
1920

2021
from ... import logging as logg
21-
from ..._compat import deprecated
22+
from ..._compat import deprecated, pkg_version
2223
from ..._settings import settings
2324
from ..._utils import _doc_params, _empty, sanitize_anndata
2425
from ..._utils._doctests import doctest_internet
@@ -1221,10 +1222,21 @@ def _get_color_source_vector(
12211222
# We should probably just make an index for this, and share it over runs
12221223
# TODO: Throw helpful error if this doesn't work
12231224
value_to_plot = adata.var.index[adata.var[gene_symbols] == value_to_plot][0]
1225+
is_anndata_13 = pkg_version("anndata") >= Version("0.13.0rc0")
1226+
if is_anndata_13:
1227+
from anndata.acc import A
12241228
if use_raw and value_to_plot not in adata.obs.columns:
1225-
values = adata.raw.obs_vector(value_to_plot)
1229+
values = (
1230+
adata.raw[A.X[:, value_to_plot]]
1231+
if is_anndata_13
1232+
else adata.raw.obs_vector(value_to_plot)
1233+
)
12261234
else:
1227-
values = adata.obs_vector(value_to_plot, layer=layer)
1235+
values = (
1236+
adata[A.layers[layer][:, value_to_plot]]
1237+
if is_anndata_13
1238+
else adata.obs_vector(value_to_plot, layer=layer)
1239+
)
12281240
if mask_obs is not None:
12291241
values = values.copy()
12301242
values[~mask_obs] = np.nan

tests/test_get.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ def test_column_content():
209209
for col in query:
210210
assert col in df
211211
np.testing.assert_array_equal(query, df.columns)
212-
np.testing.assert_array_equal(df[col].values, adata.obs_vector(col))
212+
np.testing.assert_array_equal(df[col].values, adata.obs[col].values)
213213

214214
# test that columns content is correct for var_df
215215
cell_ids = list(adata.obs.sample(5).index)
216216
query = [*cell_ids, "highly_variable", "dispersions_norm", "dispersions"]
217217
df = sc.get.var_df(adata, query)
218218
np.testing.assert_array_equal(query, df.columns)
219219
for col in query:
220-
np.testing.assert_array_equal(df[col].values, adata.var_vector(col))
220+
np.testing.assert_array_equal(df[col].values, adata.var[col].values)
221221

222222

223223
def test_var_df(adata: AnnData):

tests/test_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def test_consistency(metric) -> None:
6868
)
6969

7070
all_genes = metric(pbmc, layer="raw")
71-
first_gene = metric(pbmc, vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw"))
71+
first_gene = metric(
72+
pbmc, vals=pbmc[:, pbmc.var_names[0]].layers["raw"].toarray().ravel()
73+
)
7274

7375
np.testing.assert_allclose(all_genes[0], first_gene, rtol=1e-9)
7476

tests/test_plotting_embedded/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def adata():
5858
adata.obs["label_missing"] = adata.obs["label"].copy()
5959
adata.obs.loc[::2, "label_missing"] = np.nan
6060

61-
adata.obs["1_missing"] = adata.obs_vector("1")
61+
adata.obs["1_missing"] = adata.obs["1"]
6262
adata.obs.loc[
6363
adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing"
6464
] = np.nan

0 commit comments

Comments
 (0)