Skip to content

Commit 253825d

Browse files
Backport PR #3974 on branch 1.12.x (chore: add in accessors to silence deprecation warnings) (#3976)
Co-authored-by: Ilan Gold <ilanbassgold@gmail.com>
1 parent d9a8b6a commit 253825d

6 files changed

Lines changed: 55 additions & 35 deletions

File tree

src/scanpy/plotting/_anndata.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ._utils import (
3737
_deprecated_scale,
3838
_dk,
39+
_obs_vector_compat,
3940
check_colornorm,
4041
scatter_base,
4142
scatter_group,
@@ -55,13 +56,7 @@
5556
from seaborn.matrix import ClusterGrid
5657

5758
from .._utils import Empty
58-
from ._utils import (
59-
ColorLike,
60-
DensityNorm,
61-
_FontSize,
62-
_FontWeight,
63-
_LegendLoc,
64-
)
59+
from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc
6560

6661
# TODO: is that all?
6762
type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"]
@@ -324,27 +319,18 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
324319
# ignore the '0th' diffusion component
325320
if basis == "diffmap":
326321
components += 1
327-
xy = adata.obsm["X_" + basis][:, components]
322+
xy = adata.obsm[f"X_{basis}"][:, components]
328323
# correct the component vector for use in labeling etc.
329324
if basis == "diffmap":
330325
components -= 1
331326
except KeyError:
332327
msg = f"compute coordinates using visualization tool {basis} first"
333328
raise KeyError(msg) from None
334329
elif x is not None and y is not None:
335-
if use_raw:
336-
if x in adata.obs.columns:
337-
x_arr = adata.obs_vector(x)
338-
else:
339-
x_arr = adata.raw.obs_vector(x)
340-
if y in adata.obs.columns:
341-
y_arr = adata.obs_vector(y)
342-
else:
343-
y_arr = adata.raw.obs_vector(y)
344-
else:
345-
x_arr = adata.obs_vector(x, layer=layers[0])
346-
y_arr = adata.obs_vector(y, layer=layers[1])
347-
330+
x_arr, y_arr = (
331+
_obs_vector_compat(adata, k, use_raw=use_raw, layer=layer)
332+
for layer, k in zip(layers, [x, y], strict=False)
333+
)
348334
xy = np.c_[x_arr, y_arr]
349335
else:
350336
msg = "Either provide a `basis` or `x` and `y`."
@@ -399,10 +385,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
399385
else:
400386
c = adata.obs[key].to_numpy()
401387
# coloring according to gene expression
402-
elif use_raw and adata.raw is not None and key in adata.raw.var_names:
403-
c = adata.raw.obs_vector(key)
404-
elif key in adata.var_names:
405-
c = adata.obs_vector(key, layer=layers[2])
388+
elif (use_raw and adata.raw is not None and key in adata.raw.var_names) or (
389+
key in adata.var_names
390+
):
391+
c = _obs_vector_compat(adata, key, use_raw=use_raw, layer=layers[2])
406392
elif is_color_like(key): # a flat color
407393
c = key
408394
colorbar = False

src/scanpy/plotting/_tools/scatterplots.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
doc_scatter_spatial,
3232
doc_show_save_ax,
3333
)
34-
from .._utils import check_colornorm, check_projection, circles
34+
from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles
3535

3636
if TYPE_CHECKING:
3737
from collections.abc import Callable, Collection, Mapping
@@ -266,6 +266,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
266266
# ]
267267
for count, (value_to_plot, dims) in enumerate(zip(color, dimensions, strict=True)):
268268
kwargs_scatter = kwargs.copy() # is potentially mutated for each plot
269+
# TODO: Consider not returning `NumpyExtensionArray` objects from dataframe columns via accessors, since much of the code checks for `np.ndarray`.
270+
# Converting the result to a NumPy array below prevents the `NumpyExtensionArray` from propagating.
269271
color_source_vector = _get_color_source_vector(
270272
adata,
271273
value_to_plot,
@@ -275,6 +277,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
275277
gene_symbols=gene_symbols,
276278
groups=groups,
277279
)
280+
if isinstance(color_source_vector, pd.arrays.NumpyExtensionArray):
281+
color_source_vector = color_source_vector.to_numpy()
278282
color_vector, color_type = _color_vector(
279283
adata,
280284
value_to_plot,
@@ -1221,10 +1225,7 @@ def _get_color_source_vector(
12211225
# We should probably just make an index for this, and share it over runs
12221226
# TODO: Throw helpful error if this doesn't work
12231227
value_to_plot = adata.var.index[adata.var[gene_symbols] == value_to_plot][0]
1224-
if use_raw and value_to_plot not in adata.obs.columns:
1225-
values = adata.raw.obs_vector(value_to_plot)
1226-
else:
1227-
values = adata.obs_vector(value_to_plot, layer=layer)
1228+
values = _obs_vector_compat(adata, value_to_plot, use_raw=use_raw, layer=layer)
12281229
if mask_obs is not None:
12291230
values = values.copy()
12301231
values[~mask_obs] = np.nan

src/scanpy/plotting/_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from matplotlib.figure import Figure
3030
from matplotlib.typing import MarkerType
3131
from numpy.typing import ArrayLike
32+
from pandas.api.extensions import ExtensionArray
3233
from PIL.Image import Image
3334

3435
from .._utils import Empty
@@ -44,6 +45,7 @@
4445
"_create_white_to_color_gradient",
4546
"_deprecated_scale",
4647
"_dk",
48+
"_obs_vector_compat",
4749
"add_colors_for_categorical_sample_annotation",
4850
"check_colornorm",
4951
"check_projection",
@@ -1167,3 +1169,23 @@ def _create_white_to_color_gradient(
11671169
return ListedColormap(
11681170
clipped_rgb, name=color if isinstance(color, str) else hex_color
11691171
)
1172+
1173+
1174+
def _obs_vector_compat(
1175+
adata: AnnData, k: str, *, use_raw: bool, layer: str | None
1176+
) -> np.ndarray | ExtensionArray:
1177+
try:
1178+
from anndata.acc import A
1179+
except ImportError:
1180+
return (
1181+
adata.raw.obs_vector(k)
1182+
if use_raw and k not in adata.obs.columns
1183+
else adata.obs_vector(k, layer=layer)
1184+
)
1185+
1186+
if k in adata.obs.columns:
1187+
return adata[A.obs[k]]
1188+
elif not use_raw:
1189+
return adata[A.layers[layer][:, k]]
1190+
else:
1191+
return adata.raw[A.X[:, k]]

tests/test_get.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,20 +204,26 @@ def test_column_content():
204204
adata = pbmc68k_reduced()
205205

206206
# test that columns content is correct for obs_df
207-
query = ["CST3", "NKG7", "GNLY", "louvain", "n_counts", "n_genes"]
207+
cols = ["louvain", "n_counts", "n_genes"]
208+
query = [*cols, "CST3", "NKG7", "GNLY"]
208209
df = sc.get.obs_df(adata, query)
209210
for col in query:
210211
assert col in df
211212
np.testing.assert_array_equal(query, df.columns)
212-
np.testing.assert_array_equal(df[col].values, adata.obs_vector(col))
213+
np.testing.assert_array_equal(
214+
df[col].values, adata.obs[col] if col in cols else adata[:, col].X.ravel()
215+
)
213216

214217
# test that columns content is correct for var_df
215218
cell_ids = list(adata.obs.sample(5).index)
216219
query = [*cell_ids, "highly_variable", "dispersions_norm", "dispersions"]
217220
df = sc.get.var_df(adata, query)
218221
np.testing.assert_array_equal(query, df.columns)
219222
for col in query:
220-
np.testing.assert_array_equal(df[col].values, adata.var_vector(col))
223+
np.testing.assert_array_equal(
224+
df[col].values,
225+
adata[col, :].X.ravel() if col in cell_ids else adata.var[col],
226+
)
221227

222228

223229
def test_var_df(adata: AnnData):

tests/test_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def test_consistency(metric) -> None:
6868
)
6969

7070
all_genes = metric(pbmc, layer="raw")
71-
first_gene = metric(pbmc, vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw"))
71+
first_gene = metric(
72+
pbmc, vals=pbmc[:, pbmc.var_names[0]].layers["raw"].toarray().ravel()
73+
)
7274

7375
np.testing.assert_allclose(all_genes[0], first_gene, rtol=1e-9)
7476

tests/test_plotting_embedded/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ def adata():
5858
adata.obs["label_missing"] = adata.obs["label"].copy()
5959
adata.obs.loc[::2, "label_missing"] = np.nan
6060

61-
adata.obs["1_missing"] = adata.obs_vector("1")
61+
# TODO: Without the `copy`, all values end up set to NaN, apparently because the result is an ArrayView.
62+
# https://github.com/scverse/anndata/issues/2348
63+
adata.obs["1_missing"] = adata[:, "1"].X.flatten().copy()
64+
6265
adata.obs.loc[
6366
adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing"
6467
] = np.nan

0 commit comments

Comments
 (0)