diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index 30156bf482..97d82ab7c1 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -36,6 +36,7 @@ from ._utils import ( _deprecated_scale, _dk, + _obs_vector_compat, check_colornorm, scatter_base, scatter_group, @@ -55,13 +56,7 @@ from seaborn.matrix import ClusterGrid from .._utils import Empty - from ._utils import ( - ColorLike, - DensityNorm, - _FontSize, - _FontWeight, - _LegendLoc, - ) + from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc # TODO: is that all? type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"] @@ -324,7 +319,7 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915 # ignore the '0th' diffusion component if basis == "diffmap": components += 1 - xy = adata.obsm["X_" + basis][:, components] + xy = adata.obsm[f"X_{basis}"][:, components] # correct the component vector for use in labeling etc. if basis == "diffmap": components -= 1 @@ -332,19 +327,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915 msg = f"compute coordinates using visualization tool {basis} first" raise KeyError(msg) from None elif x is not None and y is not None: - if use_raw: - if x in adata.obs.columns: - x_arr = adata.obs_vector(x) - else: - x_arr = adata.raw.obs_vector(x) - if y in adata.obs.columns: - y_arr = adata.obs_vector(y) - else: - y_arr = adata.raw.obs_vector(y) - else: - x_arr = adata.obs_vector(x, layer=layers[0]) - y_arr = adata.obs_vector(y, layer=layers[1]) - + x_arr, y_arr = ( + _obs_vector_compat(adata, k, use_raw=use_raw, layer=layer) + for layer, k in zip(layers, [x, y], strict=False) + ) xy = np.c_[x_arr, y_arr] else: msg = "Either provide a `basis` or `x` and `y`." @@ -399,10 +385,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915 else: c = adata.obs[key].to_numpy() # coloring according to gene expression - elif use_raw and adata.raw is not None and key in adata.raw.var_names: - c = adata.raw.obs_vector(key) - elif key in adata.var_names: - c = adata.obs_vector(key, layer=layers[2]) + elif (use_raw and adata.raw is not None and key in adata.raw.var_names) or ( + key in adata.var_names + ): + c = _obs_vector_compat(adata, key, use_raw=use_raw, layer=layers[2]) elif is_color_like(key): # a flat color c = key colorbar = False diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 5ac7c1cda3..f4044a9fa0 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -31,7 +31,7 @@ doc_scatter_spatial, doc_show_save_ax, ) -from .._utils import check_colornorm, check_projection, circles +from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles if TYPE_CHECKING: from collections.abc import Callable, Collection, Mapping @@ -266,6 +266,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 # ] for count, (value_to_plot, dims) in enumerate(zip(color, dimensions, strict=True)): kwargs_scatter = kwargs.copy() # is potentially mutated for each plot + # TODO: It might be worth not returning `NumpyExtensionArray` objects out of the dataframes via accessors because we have a lot of np.ndarray checks. + # Setting np.array here prevents the `NumpyExtensionArray` from propagating. color_source_vector = _get_color_source_vector( adata, value_to_plot, @@ -275,6 +277,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 gene_symbols=gene_symbols, groups=groups, ) + if isinstance(color_source_vector, pd.arrays.NumpyExtensionArray): + color_source_vector = color_source_vector.to_numpy() color_vector, color_type = _color_vector( adata, value_to_plot, @@ -1221,10 +1225,7 @@ def _get_color_source_vector( # We should probably just make an index for this, and share it over runs # TODO: Throw helpful error if this doesn't work value_to_plot = adata.var.index[adata.var[gene_symbols] == value_to_plot][0] - if use_raw and value_to_plot not in adata.obs.columns: - values = adata.raw.obs_vector(value_to_plot) - else: - values = adata.obs_vector(value_to_plot, layer=layer) + values = _obs_vector_compat(adata, value_to_plot, use_raw=use_raw, layer=layer) if mask_obs is not None: values = values.copy() values[~mask_obs] = np.nan diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index d648e3569b..91c79ccd3b 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -29,6 +29,7 @@ from matplotlib.figure import Figure from matplotlib.typing import MarkerType from numpy.typing import ArrayLike + from pandas.api.extensions import ExtensionArray from PIL.Image import Image from .._utils import Empty @@ -44,6 +45,7 @@ "_create_white_to_color_gradient", "_deprecated_scale", "_dk", + "_obs_vector_compat", "add_colors_for_categorical_sample_annotation", "check_colornorm", "check_projection", @@ -1167,3 +1169,23 @@ def _create_white_to_color_gradient( return ListedColormap( clipped_rgb, name=color if isinstance(color, str) else hex_color ) + + +def _obs_vector_compat( + adata: AnnData, k: str, *, use_raw: bool, layer: str | None +) -> np.ndarray | ExtensionArray: + try: + from anndata.acc import A + except ImportError: + return ( + adata.raw.obs_vector(k) + if use_raw and k not in adata.obs.columns + else adata.obs_vector(k, layer=layer) + ) + + if k in adata.obs.columns: + return adata[A.obs[k]] + elif not use_raw: + return adata[A.layers[layer][:, k]] + else: + return adata.raw[A.X[:, k]] diff --git a/tests/test_get.py b/tests/test_get.py index 20a36dee24..0f8d28c22f 100644 --- a/tests/test_get.py +++ b/tests/test_get.py @@ -204,12 +204,15 @@ def test_column_content(): adata = pbmc68k_reduced() # test that columns content is correct for obs_df - query = ["CST3", "NKG7", "GNLY", "louvain", "n_counts", "n_genes"] + cols = ["louvain", "n_counts", "n_genes"] + query = [*cols, "CST3", "NKG7", "GNLY"] df = sc.get.obs_df(adata, query) for col in query: assert col in df np.testing.assert_array_equal(query, df.columns) - np.testing.assert_array_equal(df[col].values, adata.obs_vector(col)) + np.testing.assert_array_equal( + df[col].values, adata.obs[col] if col in cols else adata[:, col].X.ravel() + ) # test that columns content is correct for var_df cell_ids = list(adata.obs.sample(5).index) @@ -217,7 +220,10 @@ def test_column_content(): df = sc.get.var_df(adata, query) np.testing.assert_array_equal(query, df.columns) for col in query: - np.testing.assert_array_equal(df[col].values, adata.var_vector(col)) + np.testing.assert_array_equal( + df[col].values, + adata[col, :].X.ravel() if col in cell_ids else adata.var[col], + ) def test_var_df(adata: AnnData): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 40c6cbfa58..fb44e7cb9f 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -68,7 +68,9 @@ def test_consistency(metric) -> None: ) all_genes = metric(pbmc, layer="raw") - first_gene = metric(pbmc, vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw")) + first_gene = metric( + pbmc, vals=pbmc[:, pbmc.var_names[0]].layers["raw"].toarray().ravel() + ) np.testing.assert_allclose(all_genes[0], first_gene, rtol=1e-9) diff --git a/tests/test_plotting_embedded/conftest.py b/tests/test_plotting_embedded/conftest.py index 42be824899..7b30bad49a 100644 --- a/tests/test_plotting_embedded/conftest.py +++ b/tests/test_plotting_embedded/conftest.py @@ -58,7 +58,10 @@ def adata(): adata.obs["label_missing"] = adata.obs["label"].copy() adata.obs.loc[::2, "label_missing"] = np.nan - adata.obs["1_missing"] = adata.obs_vector("1") + # TODO: If we don't `copy`, something about this being an ArrayView means that all values get set to nan? + # https://github.com/scverse/anndata/issues/2348 + adata.obs["1_missing"] = adata[:, "1"].X.flatten().copy() + adata.obs.loc[ adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing" ] = np.nan