Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning][].
### Changes

- `tl.rankby_obsm` now accepts `AnnData.obs` column names specified in the `obs_keys` argument
- `pl.source_targets` now accepts extra arguments through `kw_scatter` and its `top` can be `None`, not displaying any label
- Most plotting functions now accept extra arguments through `kw_` arguments.
- p-values now are corrected using a custom numba-optimized version of `scipy.stats.false_discovery_control` called `_fdr_bh_axis1_numba`

## 2.1.1
Expand Down
7 changes: 6 additions & 1 deletion src/decoupler/pl/_barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def barplot(
vmin: float | None = None,
vcenter: float | None = 0,
vmax: float | None = None,
kw_barplot: dict | None = None,
**kwargs,
) -> None | Figure:
"""
Expand All @@ -57,6 +58,8 @@ def barplot(
%(vmin)s
%(vcenter)s
%(vmax)s
kw_barplot
Keyword arguments passed to ``seaborn.barplot``.
%(plot)s

Example
Expand All @@ -78,6 +81,8 @@ def barplot(
assert isinstance(name, str) and name in data.index, "name must be str and in data.index"
assert isinstance(top, int) and top > 0, "top must be int and > 0"
assert isinstance(vertical, bool), "vertical must be bool"
if kw_barplot is None:
kw_barplot = {}
# Process df
df = data.loc[[name]]
df.index.name = None
Expand All @@ -93,7 +98,7 @@ def barplot(
# Instance
bp = Plotter(**kwargs)
# Plot
sns.barplot(data=df, x=x, y=y, ax=bp.ax)
sns.barplot(data=df, x=x, y=y, ax=bp.ax, **kw_barplot)
if not vertical:
sizes = np.array([bar.get_width() for bar in bp.ax.containers[0]])
bp.ax.set_xlabel("Score")
Expand Down
14 changes: 6 additions & 8 deletions src/decoupler/pl/_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def dotplot(
scale: int | float = 0.15,
cmap: str = "RdBu_r",
vcenter: int | float | None = None,
kw_scatter: dict | None = None,
**kwargs,
) -> None | Figure:
"""
Expand All @@ -41,6 +42,8 @@ def dotplot(
Scale of the dots.
%(cmap)s
%(vcenter)s
kw_scatter
Keyword arguments passed to ``matplotlib.pyplot.scatter``.
%(plot)s

Example
Expand Down Expand Up @@ -77,6 +80,8 @@ def dotplot(
assert isinstance(top, int | float) and top > 0, "top must be numerical and > 0"
assert isinstance(scale, int | float), "scale must be numerical"
assert isinstance(vcenter, int | float) or vcenter is None, "vcenter must be numeric or None"
if kw_scatter is None:
kw_scatter = {}
# Filter by top
df = df.copy()
df["abs_x_col"] = df[x].abs()
Expand All @@ -101,14 +106,7 @@ def dotplot(
norm = TwoSlopeNorm(vmin=None, vcenter=vcenter, vmax=None)
else:
norm = None
scatter = bp.ax.scatter(
x=x_vals,
y=y_vals,
c=c_vals,
s=ns,
cmap=cmap,
norm=norm,
)
scatter = bp.ax.scatter(x=x_vals, y=y_vals, c=c_vals, s=ns, cmap=cmap, norm=norm, **kw_scatter)
bp.ax.set_axisbelow(True)
bp.ax.set_xlabel(x)
# Add legend
Expand Down
6 changes: 6 additions & 0 deletions src/decoupler/pl/_filter_by_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def filter_by_expr(
large_n: int = 10,
min_prop: float = 0.7,
cmap: str = "viridis",
kw_histplot: dict | None = None,
**kwargs,
) -> None | Figure:
"""
Expand All @@ -34,6 +35,8 @@ def filter_by_expr(
%(min_total_count)s
%(large_n)s
%(min_prop_expr)s
kw_histplot
Keyword arguments passed to ``seaborn.histplot``.
%(plot)s

Example
Expand All @@ -50,6 +53,8 @@ def filter_by_expr(
X, _, _ = extract(adata, empty=False)
isbacked = isinstance(X, tuple)
assert not isbacked, "adata is in backed mode, reload adata without backed='r'"
if kw_histplot is None:
kw_histplot = {}
obs = adata.obs
# Minimum sample size cutoff
min_sample_size = _min_sample_size(
Expand Down Expand Up @@ -77,6 +82,7 @@ def filter_by_expr(
cbar_kws={"shrink": 0.75, "label": "Number of genes"},
discrete=(False, True),
ax=bp.ax,
**kw_histplot,
)
bp.ax.axhline(y=min_sample_size - 0.5, c="gray", ls="--")
bp.ax.axvline(x=np.log10(min_total_count), c="gray", ls="--")
Expand Down
23 changes: 10 additions & 13 deletions src/decoupler/pl/_filter_by_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@docs.dedent
def filter_by_prop(
adata: AnnData, min_prop: float = 0.1, min_smpls: int = 2, log: bool = True, color="gray", **kwargs
adata: AnnData, min_prop: float = 0.1, min_smpls: int = 2, kw_hist: dict | None = None, **kwargs
) -> None | Figure:
"""
Plot to help determining the thresholds of the ``decoupler.pp.filter_by_prop`` function.
Expand All @@ -19,10 +19,8 @@ def filter_by_prop(
%(adata)s
%(min_prop_prop)s
%(min_smpls)s
log
Whether to log-scale the y axis.
color
Color to use in ``matplotlib.pyplot.hist``.
kw_hist
Keyword arguments passed to ``matplotlib.pyplot.hist``.
%(plot)s

Example
Expand All @@ -40,21 +38,20 @@ def filter_by_prop(
assert "psbulk_props" in adata.layers.keys(), (
"psbulk_props must be in adata.layers, use this function afer running decoupler.pp.pseudobulk"
)
if kw_hist is None:
kw_hist = {}
kw_hist.setdefault("color", "gray")
kw_hist.setdefault("align", "left")
kw_hist.setdefault("rwidth", 0.95)
kw_hist.setdefault("log", True)
props = adata.layers["psbulk_props"]
if isinstance(props, pd.DataFrame):
props = props.values
nsmpls = np.sum(props >= min_prop, axis=0)
# Instance
bp = Plotter(**kwargs)
# Plot
_ = bp.ax.hist(
nsmpls,
bins=range(min(nsmpls), max(nsmpls) + 2),
log=log,
color=color,
align="left",
rwidth=0.95,
)
_ = bp.ax.hist(nsmpls, bins=range(min(nsmpls), max(nsmpls) + 2), **kw_hist)
bp.ax.axvline(x=min_smpls - 0.5, c="black", ls="--")
bp.ax.set_xlabel("Samples (≥ min_prop)")
bp.ax.set_ylabel("Number of genes")
Expand Down
9 changes: 7 additions & 2 deletions src/decoupler/pl/_filter_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def filter_samples(
log: bool = True,
min_cells: int | float = 10,
min_counts: int | float = 1000,
kw_scatterplot: dict | None = None,
**kwargs,
) -> None | Figure:
"""
Expand All @@ -30,6 +31,8 @@ def filter_samples(
If set, log10 transform the ``psbulk_n_cells`` and ``psbulk_counts`` columns during visualization.
%(min_cells)s
%(min_counts)s
kw_scatterplot
Keyword arguments passed to ``seaborn.scatterplot``.
%(plot)s

Example
Expand All @@ -56,6 +59,8 @@ def filter_samples(
groupby = [groupby]
if groupby:
assert all(col in adata.obs for col in groupby), "columns in groupby must be in adata.obs"
if kw_scatterplot is None:
kw_scatterplot = {}
# Extract obs
df = adata.obs.copy()
# Transform to log10
Expand All @@ -78,7 +83,7 @@ def filter_samples(
for ax, grp in zip(axes, groupby, strict=False):
ax.grid(zorder=0)
ax.set_axisbelow(True)
sns.scatterplot(x="psbulk_cells", y="psbulk_counts", hue=grp, ax=ax, data=df, zorder=1)
sns.scatterplot(x="psbulk_cells", y="psbulk_counts", hue=grp, ax=ax, data=df, zorder=1, **kw_scatterplot)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False, title=grp)
ax.set_xlabel(label_x)
ax.set_ylabel(label_y)
Expand All @@ -91,7 +96,7 @@ def filter_samples(
bp = Plotter(**kwargs)
bp.ax.grid(zorder=0)
bp.ax.set_axisbelow(True)
sns.scatterplot(x="psbulk_cells", y="psbulk_counts", hue=groupby, ax=bp.ax, data=df, zorder=1)
sns.scatterplot(x="psbulk_cells", y="psbulk_counts", hue=groupby, ax=bp.ax, data=df, zorder=1, **kw_scatterplot)
if groupby:
bp.ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False, title=groupby)
bp.ax.set_xlabel(label_x)
Expand Down
38 changes: 19 additions & 19 deletions src/decoupler/pl/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,18 @@ def network(
score: pd.DataFrame = None,
sources: int | list | str = 5,
targets: int | list | str = 10,
by_abs=True,
size_node=5,
size_label=2.5,
s_cmap="RdBu_r",
t_cmap="viridis",
vcenter=False,
c_pos_w="darkgreen",
c_neg_w="darkred",
s_label="Enrichment\nscore",
t_label="Gene\nexpression",
layout="kk",
by_abs: bool = True,
size_node: int = 5,
size_label: float | int = 2.5,
s_cmap: str = "RdBu_r",
t_cmap: str = "viridis",
vcenter: bool = False,
c_pos_w: str = "darkgreen",
c_neg_w: str = "darkred",
s_label: str = "Enrichment\nscore",
t_label: str = "Gene\nexpression",
layout: str = "kk",
kw_igraph: dict | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -228,6 +229,8 @@ def network(
Label to place in the target colorbar.
layout
Layout to use to order the nodes. Check ``igraph`` documentation for more options.
kw_igraph
Keyword arguments passed to ``igraph.plot``.
%(plot)s

Example
Expand All @@ -249,6 +252,8 @@ def network(
data = pd.DataFrame(np.ones((1, trgs.size)), index=["0"], columns=trgs)
s_cmap = "white"
t_cmap = "white"
if kw_igraph is None:
kw_igraph = {}
# Filter
fdata, fscore, fnet = _filter(
data=data,
Expand Down Expand Up @@ -283,14 +288,9 @@ def network(
ax2 = bp.fig.add_subplot(gs[-1, 0])
ax3 = bp.fig.add_subplot(gs[-1, 1])
ax4 = bp.fig.add_subplot(gs[-1, -1])
ig.plot(
g,
target=ax1,
layout=layout,
vertex_size=(size_node * bp.dpi) / (bp.figsize[0] * bp.figsize[0]),
vertex_size_label=(size_label * bp.dpi) / (bp.figsize[0] * bp.figsize[0]),
bbox_inches="tight",
)
kw_igraph.setdefault("layout", layout)
kw_igraph.setdefault("vertex_size", (size_node * bp.dpi) / (bp.figsize[0] * bp.figsize[0]))
ig.plot(g, target=ax1, bbox_inches="tight", **kw_igraph)
if is_cmap:
sm = matplotlib.cm.ScalarMappable(norm=s_norm, cmap=s_cmap)
bp.fig.colorbar(sm, cax=ax2, orientation="horizontal", label=s_label)
Expand Down
7 changes: 6 additions & 1 deletion src/decoupler/pl/_volcano.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def volcano(
color_pos: str = "#D62728",
color_neg: str = "#1F77B4",
color_null: str = "gray",
kw_scatter: dict | None = None,
**kwargs,
) -> None | Figure:
"""
Expand Down Expand Up @@ -54,6 +55,8 @@ def volcano(
Color to plot significant negative features.
color_null
Color to plot rest of the genes.
kw_scatter
Keyword arguments passed to ``matplotlib.pyplot.scatter``.
%(plot)s

Example
Expand Down Expand Up @@ -85,6 +88,8 @@ def volcano(
assert isinstance(color_pos, str), "color_pos must be str"
assert isinstance(color_neg, str), "color_neg must be str"
assert isinstance(color_null, str), "color_null must be str"
if kw_scatter is None:
kw_scatter = {}
# Instance
bp = Plotter(**kwargs)
# Transform thr_sign
Expand All @@ -111,7 +116,7 @@ def volcano(
df.loc[up_msk, "weight"] = color_pos
df.loc[dw_msk, "weight"] = color_neg
# Plot
df.plot.scatter(x="stat", y="pval", c="weight", sharex=False, ax=bp.ax)
df.plot.scatter(x="stat", y="pval", c="weight", sharex=False, ax=bp.ax, **kw_scatter)
bp.ax.set_axisbelow(True)
# Draw thr lines
bp.ax.axvline(x=thr_stat, linestyle="--", color="black")
Expand Down
2 changes: 1 addition & 1 deletion tests/pl/test_barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ def test_barplot(
vertical,
vcenter,
):
fig = dc.pl.barplot(data=df, name=name, top=top, vertical=vertical, return_fig=True)
fig = dc.pl.barplot(data=df, name=name, top=top, vertical=vertical, return_fig=True, kw_barplot={"alpha": 0.5})
assert isinstance(fig, Figure)
plt.close(fig)
11 changes: 10 additions & 1 deletion tests/pl/test_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def test_dotplot(
df,
vcenter,
):
fig = dc.pl.dotplot(df=df, x="x", y="y", c="c", s="s", vcenter=vcenter, return_fig=True)
fig = dc.pl.dotplot(
df=df,
x="x",
y="y",
c="c",
s="s",
vcenter=vcenter,
return_fig=True,
kw_scatter={"alpha": 0.5},
)
assert isinstance(fig, Figure)
plt.close(fig)
2 changes: 1 addition & 1 deletion tests/pl/test_filter_by_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
def test_filter_by_expr(
pdata,
):
fig = dc.pl.filter_by_expr(adata=pdata, return_fig=True)
fig = dc.pl.filter_by_expr(adata=pdata, return_fig=True, kw_histplot={"alpha": 0.5})
assert isinstance(fig, Figure)
plt.close(fig)
2 changes: 1 addition & 1 deletion tests/pl/test_filter_by_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
def test_filter_by_prop(
pdata,
):
fig = dc.pl.filter_by_prop(adata=pdata, return_fig=True)
fig = dc.pl.filter_by_prop(adata=pdata, return_fig=True, kw_hist={"alpha": 0.5})
assert isinstance(fig, Figure)
plt.close(fig)
2 changes: 1 addition & 1 deletion tests/pl/test_filter_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def test_filter_samples(
groupby,
log,
):
fig = dc.pl.filter_samples(adata=pdata, groupby=groupby, log=log, return_fig=True)
fig = dc.pl.filter_samples(adata=pdata, groupby=groupby, log=log, return_fig=True, kw_scatterplot={"alpha": 0.5})
assert isinstance(fig, Figure)
plt.close(fig)
1 change: 1 addition & 0 deletions tests/pl/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_network(
vcenter=vcenter,
s_cmap=s_cmap,
figsize=(5, 5),
kw_igraph={"vertex_color": "green"},
return_fig=True,
)
assert isinstance(fig, Figure)
Expand Down
Loading