Skip to content
Merged
36 changes: 11 additions & 25 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ._utils import (
_deprecated_scale,
_dk,
_obs_vector_compat,
check_colornorm,
scatter_base,
scatter_group,
Expand All @@ -55,13 +56,7 @@
from seaborn.matrix import ClusterGrid

from .._utils import Empty
from ._utils import (
ColorLike,
DensityNorm,
_FontSize,
_FontWeight,
_LegendLoc,
)
from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc

# TODO: is that all?
type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"]
Expand Down Expand Up @@ -324,27 +319,18 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
# ignore the '0th' diffusion component
if basis == "diffmap":
components += 1
xy = adata.obsm["X_" + basis][:, components]
xy = adata.obsm[f"X_{basis}"][:, components]
# correct the component vector for use in labeling etc.
if basis == "diffmap":
components -= 1
except KeyError:
msg = f"compute coordinates using visualization tool {basis} first"
raise KeyError(msg) from None
elif x is not None and y is not None:
if use_raw:
if x in adata.obs.columns:
x_arr = adata.obs_vector(x)
else:
x_arr = adata.raw.obs_vector(x)
if y in adata.obs.columns:
y_arr = adata.obs_vector(y)
else:
y_arr = adata.raw.obs_vector(y)
else:
x_arr = adata.obs_vector(x, layer=layers[0])
y_arr = adata.obs_vector(y, layer=layers[1])

x_arr, y_arr = (
_obs_vector_compat(adata, k, use_raw=use_raw, layer=layer)
for layer, k in zip(layers, [x, y], strict=False)
)
xy = np.c_[x_arr, y_arr]
else:
msg = "Either provide a `basis` or `x` and `y`."
Expand Down Expand Up @@ -399,10 +385,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
else:
c = adata.obs[key].to_numpy()
# coloring according to gene expression
elif use_raw and adata.raw is not None and key in adata.raw.var_names:
c = adata.raw.obs_vector(key)
elif key in adata.var_names:
c = adata.obs_vector(key, layer=layers[2])
elif (use_raw and adata.raw is not None and key in adata.raw.var_names) or (
key in adata.var_names
):
c = _obs_vector_compat(adata, key, use_raw=use_raw, layer=layers[2])
elif is_color_like(key): # a flat color
c = key
colorbar = False
Expand Down
11 changes: 6 additions & 5 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
doc_scatter_spatial,
doc_show_save_ax,
)
from .._utils import check_colornorm, check_projection, circles
from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles

if TYPE_CHECKING:
from collections.abc import Callable, Collection, Mapping
Expand Down Expand Up @@ -266,6 +266,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
# ]
for count, (value_to_plot, dims) in enumerate(zip(color, dimensions, strict=True)):
kwargs_scatter = kwargs.copy() # is potentially mutated for each plot
# TODO: Consider not returning `NumpyExtensionArray` objects out of the dataframes via accessors, because we have a lot of `np.ndarray` checks.
# Until then, the `to_numpy()` conversion below prevents the `NumpyExtensionArray` from propagating.
color_source_vector = _get_color_source_vector(
adata,
value_to_plot,
Expand All @@ -275,6 +277,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
gene_symbols=gene_symbols,
groups=groups,
)
if isinstance(color_source_vector, pd.arrays.NumpyExtensionArray):
color_source_vector = color_source_vector.to_numpy()
color_vector, color_type = _color_vector(
adata,
value_to_plot,
Expand Down Expand Up @@ -1221,10 +1225,7 @@ def _get_color_source_vector(
# We should probably just make an index for this, and share it over runs
# TODO: Throw helpful error if this doesn't work
value_to_plot = adata.var.index[adata.var[gene_symbols] == value_to_plot][0]
if use_raw and value_to_plot not in adata.obs.columns:
values = adata.raw.obs_vector(value_to_plot)
else:
values = adata.obs_vector(value_to_plot, layer=layer)
values = _obs_vector_compat(adata, value_to_plot, use_raw=use_raw, layer=layer)
if mask_obs is not None:
values = values.copy()
values[~mask_obs] = np.nan
Expand Down
22 changes: 22 additions & 0 deletions src/scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from matplotlib.figure import Figure
from matplotlib.typing import MarkerType
from numpy.typing import ArrayLike
from pandas.api.extensions import ExtensionArray
from PIL.Image import Image

from .._utils import Empty
Expand All @@ -44,6 +45,7 @@
"_create_white_to_color_gradient",
"_deprecated_scale",
"_dk",
"_obs_vector_compat",
"add_colors_for_categorical_sample_annotation",
"check_colornorm",
"check_projection",
Expand Down Expand Up @@ -1167,3 +1169,23 @@ def _create_white_to_color_gradient(
return ListedColormap(
clipped_rgb, name=color if isinstance(color, str) else hex_color
)


def _obs_vector_compat(
    adata: AnnData, k: str, *, use_raw: bool, layer: str | None
) -> np.ndarray | ExtensionArray:
    """Fetch the per-observation vector stored under key `k`.

    Uses the accessor API (``anndata.acc.A``) when it can be imported and
    falls back to the ``obs_vector`` methods otherwise.

    Lookup order (identical in both code paths):

    1. If `k` is a column of ``adata.obs``, that column is returned,
       regardless of `use_raw`.
    2. Otherwise, if `use_raw` is true, `k` is looked up in ``adata.raw``
       (`layer` is ignored in this case).
    3. Otherwise `k` is looked up in ``adata`` using `layer`; ``None``
       selects ``X``.

    Parameters
    ----------
    adata
        Annotated data object to read from.
    k
        Key of an ``obs`` column or of a variable.
    use_raw
        Whether variables are read from ``adata.raw``.
        This function does not check that ``adata.raw`` is set.
    layer
        Layer to read variables from when `use_raw` is false.

    Returns
    -------
    The values for `k`, one entry per observation.
    """
    in_obs = k in adata.obs.columns

    try:
        from anndata.acc import A
    except ImportError:
        # Accessor API not available: use the `obs_vector` methods instead.
        if use_raw and not in_obs:
            return adata.raw.obs_vector(k)
        return adata.obs_vector(k, layer=layer)

    if in_obs:
        return adata[A.obs[k]]
    if use_raw:
        return adata.raw[A.X[:, k]]
    # Address the default matrix through `A.X` explicitly instead of
    # indexing `A.layers` with `None`.
    matrix = A.X if layer is None else A.layers[layer]
    return adata[matrix[:, k]]
12 changes: 9 additions & 3 deletions tests/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,20 +204,26 @@ def test_column_content():
adata = pbmc68k_reduced()

# test that columns content is correct for obs_df
query = ["CST3", "NKG7", "GNLY", "louvain", "n_counts", "n_genes"]
cols = ["louvain", "n_counts", "n_genes"]
query = [*cols, "CST3", "NKG7", "GNLY"]
df = sc.get.obs_df(adata, query)
for col in query:
assert col in df
np.testing.assert_array_equal(query, df.columns)
np.testing.assert_array_equal(df[col].values, adata.obs_vector(col))
np.testing.assert_array_equal(
df[col].values, adata.obs[col] if col in cols else adata[:, col].X.ravel()
)

# test that columns content is correct for var_df
cell_ids = list(adata.obs.sample(5).index)
query = [*cell_ids, "highly_variable", "dispersions_norm", "dispersions"]
df = sc.get.var_df(adata, query)
np.testing.assert_array_equal(query, df.columns)
for col in query:
np.testing.assert_array_equal(df[col].values, adata.var_vector(col))
np.testing.assert_array_equal(
df[col].values,
adata[col, :].X.ravel() if col in cell_ids else adata.var[col],
)


def test_var_df(adata: AnnData):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_consistency(metric) -> None:
)

all_genes = metric(pbmc, layer="raw")
first_gene = metric(pbmc, vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw"))
first_gene = metric(
pbmc, vals=pbmc[:, pbmc.var_names[0]].layers["raw"].toarray().ravel()
)

np.testing.assert_allclose(all_genes[0], first_gene, rtol=1e-9)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_plotting_embedded/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def adata():
adata.obs["label_missing"] = adata.obs["label"].copy()
adata.obs.loc[::2, "label_missing"] = np.nan

adata.obs["1_missing"] = adata.obs_vector("1")
# TODO: Without the `.copy()`, all values end up set to nan, apparently because the result is an `ArrayView`?
# https://github.com/scverse/anndata/issues/2348
adata.obs["1_missing"] = adata[:, "1"].X.flatten().copy()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to get rid of views


adata.obs.loc[
adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing"
] = np.nan
Expand Down
Loading