Skip to content

Commit 0270f1c

Browse files
committed
Add values_to_plot support to matrixplot and heatmap (issue #3842)
1 parent 16eee94 commit 0270f1c

2 files changed

Lines changed: 174 additions & 0 deletions

File tree

src/scanpy/plotting/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,9 @@
110110
# Deprecated aliases: each name below is rebound to the result of passing it
# through the callable returned by `deprecated("Use `dpt_timeseries`.")`.
# The message directs users to `dpt_timeseries` instead.
timeseries = deprecated("Use `dpt_timeseries`.")(timeseries)
111111
timeseries_as_heatmap = deprecated("Use `dpt_timeseries`.")(timeseries_as_heatmap)
112112
timeseries_subplot = deprecated("Use `dpt_timeseries`.")(timeseries_subplot)
113+
114+
from ._rank_genes_groups import (
115+
rank_genes_groups_matrixplot,
116+
rank_genes_groups_heatmap,
117+
rank_genes_groups_dotplot,
118+
)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal, Sequence
4+
5+
import numpy as np
6+
import pandas as pd
7+
8+
from anndata import AnnData
9+
10+
from ._anndata import heatmap, matrixplot, dotplot
11+
12+
13+
# Alias for numpy arrays; kept as part of the module namespace.
ArrayLike = np.ndarray
# Keys of `adata.uns["rank_genes_groups"]` that can be selected for plotting.
ValuesToPlot = Literal["scores", "logfoldchanges", "pvals", "pvals_adj"]
15+
16+
17+
def _extract_rgg_values(
    adata: AnnData,
    values_to_plot: ValuesToPlot,
    groups: Sequence[str] | str | None,
    n_genes: int,
) -> tuple[pd.DataFrame, list, list]:
    """Build a groups × genes table of one `rank_genes_groups` metric.

    Parameters
    ----------
    adata
        Object whose ``uns["rank_genes_groups"]`` entry holds the results.
    values_to_plot
        Key of the metric to extract (e.g. ``"scores"``).
    groups
        A single group name, a sequence of group names, or ``None`` / ``"all"``
        to use every group found in the ``"names"`` table.
    n_genes
        Number of leading genes taken from each selected group.

    Returns
    -------
    A tuple ``(df, selected_genes, groups)``: ``df`` is a float dataframe
    indexed by group with one column per selected gene (``NaN`` where a gene
    is not listed for a group), ``selected_genes`` is the de-duplicated union
    of the per-group leading genes in first-seen order, and ``groups`` is the
    resolved list of group names.

    Raises
    ------
    ValueError
        If ``adata.uns["rank_genes_groups"]`` is missing or a requested group
        is not among the ranked groups.
    KeyError
        If the requested metric is not stored in the results.
    """
    rgg = adata.uns.get("rank_genes_groups", None)
    if rgg is None:
        raise ValueError("`adata.uns['rank_genes_groups']` not found.")

    names_table = rgg["names"]
    try:
        metric_table = rgg[values_to_plot]
    except KeyError:
        msg = (
            f"`adata.uns['rank_genes_groups']` has no {values_to_plot!r} entry; "
            "it may not have been computed by the chosen method."
        )
        raise KeyError(msg) from None

    # Group names are the field names of the structured "names" array.
    groups_order = names_table.dtype.names

    # Only compare against "all" for real strings: `==` on an ndarray is
    # elementwise and has no unambiguous truth value.
    if groups is None or (isinstance(groups, str) and groups == "all"):
        groups = list(groups_order)
    elif isinstance(groups, str):
        groups = [groups]
    else:
        groups = list(groups)

    unknown = [g for g in groups if g not in groups_order]
    if unknown:
        msg = (
            f"Groups {unknown!r} not found in `rank_genes_groups`. "
            f"Available groups: {list(groups_order)!r}"
        )
        raise ValueError(msg)

    # Union of the top-N genes of every group, de-duplicated in first-seen order.
    selected_genes = list(
        dict.fromkeys(
            gene for g in groups for gene in names_table[g][:n_genes]
        )
    )

    # Fill one row per group; genes absent from a group's ranking stay NaN.
    values = np.full((len(groups), len(selected_genes)), np.nan, dtype=float)
    for row, g in enumerate(groups):
        gene_to_metric = dict(zip(names_table[g], metric_table[g]))
        values[row] = [gene_to_metric.get(gene, np.nan) for gene in selected_genes]

    df = pd.DataFrame(values, index=groups, columns=selected_genes)
    return df, selected_genes, groups
56+
57+
58+
# ------------------------------------------------------------------------------
59+
# Matrixplot
60+
# ------------------------------------------------------------------------------
61+
62+
def rank_genes_groups_matrixplot(
    adata: AnnData,
    *,
    values_to_plot: ValuesToPlot | None = None,
    groups: Sequence[str] | str | None = None,
    n_genes: int = 20,
    **kwargs,
):
    """MatrixPlot wrapper for rank_genes_groups with DE metric selection.

    Parameters
    ----------
    adata
        Object holding ``uns["rank_genes_groups"]``.
    values_to_plot
        Metric to colour the matrix by; required.
    groups
        Group name(s) whose top genes are shown; ``None`` or ``"all"`` selects
        every ranked group.
    n_genes
        Number of top genes taken per group.
    **kwargs
        Forwarded unchanged to `matrixplot`.

    Returns
    -------
    Whatever `matrixplot` returns.

    Raises
    ------
    ValueError
        If `values_to_plot` is ``None`` or no ranking results are stored.

    Example:
        sc.pl.rank_genes_groups_matrixplot(
            adata,
            values_to_plot="logfoldchanges",
            groups=["0","1"],
            n_genes=20,
        )
    """
    if values_to_plot is None:
        raise ValueError(
            "`values_to_plot` must be provided. Options: "
            "['scores', 'logfoldchanges', 'pvals', 'pvals_adj']"
        )

    values_df, genes, groups = _extract_rgg_values(
        adata, values_to_plot, groups, n_genes
    )

    # `groupby` must name the obs column used for the ranking, not the group
    # labels themselves; the ranking stores that key under its parameters.
    groupby = str(adata.uns["rank_genes_groups"]["params"]["groupby"])

    return matrixplot(
        adata,
        var_names=genes,
        groupby=groupby,
        values_df=values_df,
        **kwargs,
    )
98+
99+
100+
# ------------------------------------------------------------------------------
101+
# Heatmap
102+
# ------------------------------------------------------------------------------
103+
104+
def rank_genes_groups_heatmap(
    adata: AnnData,
    *,
    values_to_plot: ValuesToPlot | None = None,
    groups: Sequence[str] | str | None = None,
    n_genes: int = 20,
    **kwargs,
):
    """Heatmap wrapper for rank_genes_groups with DE metric selection.

    Parameters
    ----------
    adata
        Object holding ``uns["rank_genes_groups"]``.
    values_to_plot
        Metric to plot; required.
    groups
        Group name(s) whose top genes are shown; ``None`` or ``"all"`` selects
        every ranked group.
    n_genes
        Number of top genes taken per group.
    **kwargs
        Forwarded unchanged to `heatmap`.

    Returns
    -------
    Whatever `heatmap` returns.

    Raises
    ------
    ValueError
        If `values_to_plot` is ``None`` or no ranking results are stored.
    """
    if values_to_plot is None:
        raise ValueError(
            "`values_to_plot` must be provided. Options: "
            "['scores', 'logfoldchanges', 'pvals', 'pvals_adj']"
        )

    values_df, genes, groups = _extract_rgg_values(
        adata, values_to_plot, groups, n_genes
    )

    # `groupby` must name the obs column used for the ranking, not the group
    # labels themselves; the ranking stores that key under its parameters.
    groupby = str(adata.uns["rank_genes_groups"]["params"]["groupby"])

    # NOTE(review): confirm that `heatmap` accepts a `values_df` keyword; if it
    # forwards unknown keywords to the image call, this argument will fail there.
    return heatmap(
        adata,
        var_names=genes,
        groupby=groupby,
        values_df=values_df,
        **kwargs,
    )
132+
133+
134+
# ------------------------------------------------------------------------------
135+
# Dotplot (for completeness and parity with the issue text)
136+
# ------------------------------------------------------------------------------
137+
138+
def rank_genes_groups_dotplot(
    adata: AnnData,
    *,
    values_to_plot: ValuesToPlot | None = None,
    groups: Sequence[str] | str | None = None,
    n_genes: int = 20,
    **kwargs,
):
    """DotPlot wrapper for rank_genes_groups with DE metric selection.

    This adds parity with the existing sc.pl.rank_genes_groups_dotplot API.

    Parameters
    ----------
    adata
        Object holding ``uns["rank_genes_groups"]``.
    values_to_plot
        Metric used to colour the dots; required.
    groups
        Group name(s) whose top genes are shown; ``None`` or ``"all"`` selects
        every ranked group.
    n_genes
        Number of top genes taken per group.
    **kwargs
        Forwarded unchanged to `dotplot`.

    Returns
    -------
    Whatever `dotplot` returns.

    Raises
    ------
    ValueError
        If `values_to_plot` is ``None`` or no ranking results are stored.
    """
    if values_to_plot is None:
        raise ValueError(
            "`values_to_plot` must be provided. Options: "
            "['scores', 'logfoldchanges', 'pvals', 'pvals_adj']"
        )

    values_df, genes, groups = _extract_rgg_values(
        adata, values_to_plot, groups, n_genes
    )

    # `groupby` must name the obs column used for the ranking, not the group
    # labels themselves; the ranking stores that key under its parameters.
    groupby = str(adata.uns["rank_genes_groups"]["params"]["groupby"])

    # The metric table is handed over as the dot colours.
    # NOTE(review): when only a subset of groups is requested, `values_df` has
    # fewer rows than there are categories — confirm `dotplot` tolerates that.
    return dotplot(
        adata,
        var_names=genes,
        groupby=groupby,
        dot_color_df=values_df,
        **kwargs,
    )
168+

0 commit comments

Comments
 (0)