Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ._utils import (
_deprecated_scale,
_dk,
_obs_vector_compat,
check_colornorm,
scatter_base,
scatter_group,
Expand All @@ -55,13 +56,7 @@
from seaborn.matrix import ClusterGrid

from .._utils import Empty
from ._utils import (
ColorLike,
DensityNorm,
_FontSize,
_FontWeight,
_LegendLoc,
)
from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc

# TODO: is that all?
type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"]
Expand Down Expand Up @@ -324,27 +319,18 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
# ignore the '0th' diffusion component
if basis == "diffmap":
components += 1
xy = adata.obsm["X_" + basis][:, components]
xy = adata.obsm[f"X_{basis}"][:, components]
# correct the component vector for use in labeling etc.
if basis == "diffmap":
components -= 1
except KeyError:
msg = f"compute coordinates using visualization tool {basis} first"
raise KeyError(msg) from None
elif x is not None and y is not None:
if use_raw:
if x in adata.obs.columns:
x_arr = adata.obs_vector(x)
else:
x_arr = adata.raw.obs_vector(x)
if y in adata.obs.columns:
y_arr = adata.obs_vector(y)
else:
y_arr = adata.raw.obs_vector(y)
else:
x_arr = adata.obs_vector(x, layer=layers[0])
y_arr = adata.obs_vector(y, layer=layers[1])

x_arr, y_arr = (
_obs_vector_compat(adata, k, use_raw=use_raw, layer=layer)
for layer, k in zip(layers, [x, y], strict=False)
)
xy = np.c_[x_arr, y_arr]
else:
msg = "Either provide a `basis` or `x` and `y`."
Expand Down Expand Up @@ -399,10 +385,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
else:
c = adata.obs[key].to_numpy()
# coloring according to gene expression
elif use_raw and adata.raw is not None and key in adata.raw.var_names:
c = adata.raw.obs_vector(key)
elif key in adata.var_names:
c = adata.obs_vector(key, layer=layers[2])
elif (use_raw and adata.raw is not None and key in adata.raw.var_names) or (
key in adata.var_names
):
c = _obs_vector_compat(adata, key, use_raw=use_raw, layer=layers[2])
elif is_color_like(key): # a flat color
c = key
colorbar = False
Expand Down
11 changes: 6 additions & 5 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
doc_scatter_spatial,
doc_show_save_ax,
)
from .._utils import check_colornorm, check_projection, circles
from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles

if TYPE_CHECKING:
from collections.abc import Callable, Collection, Mapping
Expand Down Expand Up @@ -266,6 +266,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
# ]
for count, (value_to_plot, dims) in enumerate(zip(color, dimensions, strict=True)):
kwargs_scatter = kwargs.copy() # is potentially mutated for each plot
# TODO: It might be worth not returning `NumpyExtensionArray` objects out of the dataframes via accessors because we have a lot of np.ndarray checks.
# Setting np.array here prevents the `NumpyExtensionArray` from propagating.
color_source_vector = _get_color_source_vector(
adata,
value_to_plot,
Expand All @@ -275,6 +277,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
gene_symbols=gene_symbols,
groups=groups,
)
if isinstance(color_source_vector, pd.arrays.NumpyExtensionArray):
color_source_vector = color_source_vector.to_numpy()
color_vector, color_type = _color_vector(
adata,
value_to_plot,
Expand Down Expand Up @@ -1221,10 +1225,7 @@ def _get_color_source_vector(
# We should probably just make an index for this, and share it over runs
# TODO: Throw helpful error if this doesn't work
value_to_plot = adata.var.index[adata.var[gene_symbols] == value_to_plot][0]
if use_raw and value_to_plot not in adata.obs.columns:
values = adata.raw.obs_vector(value_to_plot)
else:
values = adata.obs_vector(value_to_plot, layer=layer)
values = _obs_vector_compat(adata, value_to_plot, use_raw=use_raw, layer=layer)
if mask_obs is not None:
values = values.copy()
values[~mask_obs] = np.nan
Expand Down
22 changes: 22 additions & 0 deletions src/scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from matplotlib.figure import Figure
from matplotlib.typing import MarkerType
from numpy.typing import ArrayLike
from pandas.api.extensions import ExtensionArray
from PIL.Image import Image

from .._utils import Empty
Expand All @@ -44,6 +45,7 @@
"_create_white_to_color_gradient",
"_deprecated_scale",
"_dk",
"_obs_vector_compat",
"add_colors_for_categorical_sample_annotation",
"check_colornorm",
"check_projection",
Expand Down Expand Up @@ -1167,3 +1169,23 @@ def _create_white_to_color_gradient(
return ListedColormap(
clipped_rgb, name=color if isinstance(color, str) else hex_color
)


def _obs_vector_compat(
adata: AnnData, k: str, *, use_raw: bool, layer: str | None
) -> np.ndarray | ExtensionArray:
try:
from anndata.acc import A
except ImportError:
return (
adata.raw.obs_vector(k)
if use_raw and k not in adata.obs.columns
else adata.obs_vector(k, layer=layer)
)

if k in adata.obs.columns:
return adata[A.obs[k]]
elif not use_raw:
return adata[A.layers[layer][:, k]]
else:
return adata.raw[A.X[:, k]]
12 changes: 9 additions & 3 deletions tests/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,20 +204,26 @@ def test_column_content():
adata = pbmc68k_reduced()

# test that columns content is correct for obs_df
query = ["CST3", "NKG7", "GNLY", "louvain", "n_counts", "n_genes"]
cols = ["louvain", "n_counts", "n_genes"]
query = [*cols, "CST3", "NKG7", "GNLY"]
df = sc.get.obs_df(adata, query)
for col in query:
assert col in df
np.testing.assert_array_equal(query, df.columns)
np.testing.assert_array_equal(df[col].values, adata.obs_vector(col))
np.testing.assert_array_equal(
df[col].values, adata.obs[col] if col in cols else adata[:, col].X.ravel()
)

# test that columns content is correct for var_df
cell_ids = list(adata.obs.sample(5).index)
query = [*cell_ids, "highly_variable", "dispersions_norm", "dispersions"]
df = sc.get.var_df(adata, query)
np.testing.assert_array_equal(query, df.columns)
for col in query:
np.testing.assert_array_equal(df[col].values, adata.var_vector(col))
np.testing.assert_array_equal(
df[col].values,
adata[col, :].X.ravel() if col in cell_ids else adata.var[col],
)


def test_var_df(adata: AnnData):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_consistency(metric) -> None:
)

all_genes = metric(pbmc, layer="raw")
first_gene = metric(pbmc, vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw"))
first_gene = metric(
pbmc, vals=pbmc[:, pbmc.var_names[0]].layers["raw"].toarray().ravel()
)

np.testing.assert_allclose(all_genes[0], first_gene, rtol=1e-9)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_plotting_embedded/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def adata():
adata.obs["label_missing"] = adata.obs["label"].copy()
adata.obs.loc[::2, "label_missing"] = np.nan

adata.obs["1_missing"] = adata.obs_vector("1")
# TODO: If we don't `copy`, something about this being an ArrayView means that all values get set to nan?
# https://github.com/scverse/anndata/issues/2348
adata.obs["1_missing"] = adata[:, "1"].X.flatten().copy()

adata.obs.loc[
adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing"
] = np.nan
Expand Down
Loading