Skip to content

Commit 52bb5b7

Browse files
PauBadiaMclaude
andcommitted
Integrate decoupler Plotter class into pl functions
All three plotting functions (heatmap, links, ranking) now use the decoupler Plotter base class for consistent figure management. This adds save parameter support to all functions and standardizes the figsize/dpi/return_fig interface via **kwargs. Fix pre-existing test fixture bug: rename 'dts' column to 'dataset' in test_pl_ranking.py to match _compute_aggregations expectation. Add save parameter tests for heatmap, links, and ranking. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ab9a33b commit 52bb5b7

5 files changed

Lines changed: 105 additions & 56 deletions

File tree

src/gretapy/pl/_heatmap.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib.pyplot as plt
44
import numpy as np
55
import pandas as pd
6+
from decoupler._Plotter import Plotter
67

78

89
def _make_sim_mat(df: pd.DataFrame, col: str) -> pd.DataFrame:
@@ -24,7 +25,7 @@ def heatmap(
2425
vmax: float = 1,
2526
width: float = 2,
2627
height: float = 2,
27-
return_fig: bool = False,
28+
**kwargs,
2829
) -> plt.Figure | None:
2930
"""
3031
Plot overlap coefficient heatmap.
@@ -46,16 +47,17 @@ def heatmap(
4647
vmax
4748
Maximum value for colormap. Default is 1.
4849
width
49-
Width of the heatmap. Default is 2.
50+
Width of the heatmap in marsilea units. Default is 2.
5051
height
51-
Height of the heatmap. Default is 2.
52-
return_fig
53-
Whether to return the figure. Default is False.
52+
Height of the heatmap in marsilea units. Default is 2.
53+
**kwargs
54+
Additional arguments passed to ``decoupler.Plotter`` (e.g. ``figsize``,
55+
``dpi``, ``return_fig``, ``save``).
5456
5557
Returns
5658
-------
5759
plt.Figure or None
58-
Figure if return_fig is True.
60+
Figure if ``return_fig=True``.
5961
"""
6062
if level not in {"source", "cre", "target", "edge"}:
6163
raise ValueError(f'level must be "source", "cre", "target", or "edge", got {level}')
@@ -70,14 +72,20 @@ def heatmap(
7072
if title is None:
7173
title = level.capitalize()
7274

75+
kwargs["ax"] = None
76+
bp = Plotter(**kwargs)
77+
bp.fig.delaxes(bp.ax)
78+
plt.close(bp.fig)
79+
7380
h = ma.Heatmap(mat, cmap=cmap, width=width, height=height, label="Overlap\nCoefficient", vmin=vmin, vmax=vmax)
7481
h.add_bottom(mp.Labels(mat.columns))
7582
h.add_left(mp.Labels(mat.index))
7683
h.add_top(mp.Title(title))
7784
h.add_legends()
7885
h.render()
7986

80-
if return_fig:
81-
fig = h.figure
82-
plt.close()
83-
return fig
87+
bp.fig = h.figure
88+
bp.fig.set_figwidth(bp.figsize[0])
89+
bp.fig.set_figheight(bp.figsize[1])
90+
bp.fig.set_dpi(bp.dpi)
91+
return bp._return()

src/gretapy/pl/_links.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
import pyranges as pr
8+
from decoupler._Plotter import Plotter
89

910
import gretapy as gt
1011

@@ -212,9 +213,7 @@ def links(
212213
agg_mode: str = "mean",
213214
palette: dict[str, str] | None = None,
214215
expr_cmap: None | str | list[str] = None,
215-
figsize: tuple[float, float] | None = None,
216-
dpi: int = 125,
217-
return_fig: bool = False,
216+
**kwargs,
218217
) -> plt.Figure | None:
219218
"""
220219
Plot CRE-to-gene links for TFs in a genomic region.
@@ -256,17 +255,15 @@ def links(
256255
Color palette mapping GRN names to colors. If None, uses default colors.
257256
expr_cmap
258257
Colormap for expression heatmap. Default is white to purple.
259-
figsize
260-
Figure size (width, height). If None, auto-calculated.
261-
dpi
262-
Figure DPI. Default is 150.
263-
return_fig
264-
Whether to return the figure. Default is False.
258+
**kwargs
259+
Additional arguments passed to ``decoupler.Plotter`` (e.g. ``figsize``,
260+
``dpi``, ``return_fig``, ``save``). ``figsize`` defaults to
261+
``(3, auto)`` and ``dpi`` defaults to ``125``.
265262
266263
Returns
267264
-------
268265
plt.Figure or None
269-
Figure if return_fig is True.
266+
Figure if ``return_fig=True``.
270267
271268
Examples
272269
--------
@@ -334,11 +331,16 @@ def links(
334331
gannot_height = max(n_genes * 0.3, 1.0)
335332

336333
# Set up figure with dynamic sizing
337-
if figsize is None:
338-
fig_height = n_tfs * 1.0 + omic_height + gannot_height
339-
figsize = (3, fig_height)
334+
fig_height = n_tfs * 1.0 + omic_height + gannot_height
335+
kwargs.setdefault("figsize", (3, fig_height))
336+
kwargs.setdefault("dpi", 125)
337+
kwargs["ax"] = None
338+
bp = Plotter(**kwargs)
339+
bp.fig.delaxes(bp.ax)
340+
plt.close(bp.fig)
341+
340342
height_ratios = [1] * n_tfs + [omic_height, gannot_height]
341-
fig, axes = plt.subplots(2 + n_tfs, 1, figsize=figsize, dpi=dpi, sharex=True, height_ratios=height_ratios)
343+
fig, axes = plt.subplots(2 + n_tfs, 1, figsize=bp.figsize, dpi=bp.dpi, sharex=True, height_ratios=height_ratios)
342344
axes = axes.ravel()
343345

344346
# Expression colormap
@@ -366,6 +368,5 @@ def links(
366368

367369
fig.subplots_adjust(wspace=0, hspace=0.0)
368370

369-
if return_fig:
370-
plt.close()
371-
return fig
371+
bp.fig = fig
372+
return bp._return()

src/gretapy/pl/_ranking.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib.pyplot as plt
44
import numpy as np
55
import pandas as pd
6+
from decoupler._Plotter import Plotter
67
from matplotlib.gridspec import GridSpec
78
from scipy.stats import rankdata
89

@@ -33,12 +34,9 @@
3334
CLASS_ORDER = ["Predictive", "Genomic", "Literature", "Mechanistic"]
3435

3536

36-
EXCLUDE_DATASETS = {"Synthetic Pituitary", "Unpaired Pituitary"}
37-
38-
3937
def _compute_aggregations(df):
4038
"""Compute hierarchical mean F0.1 at class and overall level."""
41-
s = df.groupby(["name", "class", "task", "db", "dts"])["f01"].mean()
39+
s = df.groupby(["name", "class", "task", "db", "dataset"])["f01"].mean()
4240
s = s.groupby(["name", "class", "task", "db"]).mean()
4341
s = s.groupby(["name", "class", "task"]).mean()
4442
class_mean = s.groupby(["name", "class"]).mean().unstack()
@@ -338,25 +336,24 @@ def ranking(
338336
df,
339337
level="class",
340338
palette=None,
341-
figsize=None,
342-
return_fig=False,
339+
**kwargs,
343340
):
344341
"""
345342
Plot a ranking figure with a barplot of mean F0.1 and a heatmap of rankings.
346343
347344
Parameters
348345
----------
349346
df : pd.DataFrame
350-
Metrics dataframe with columns: class, task, db, dts, name, f01.
347+
Metrics dataframe with columns: class, task, db, dataset, name, f01.
351348
level : str
352349
``'class'`` for summary heatmap at class level (Predictive, Genomic, etc.),
353350
``'task'`` for the fine-grained (db, task) heatmap with hierarchical headers.
354351
palette : dict or None
355352
Method name -> color mapping. Uses default palette if None.
356-
figsize : tuple or None
357-
Figure size. Auto-computed if None.
358-
return_fig : bool
359-
If True, return the figure instead of calling plt.show().
353+
**kwargs
354+
Additional arguments passed to ``decoupler.Plotter`` (e.g. ``figsize``,
355+
``dpi``, ``return_fig``, ``save``). ``figsize`` defaults to auto-computed
356+
based on data dimensions. ``dpi`` defaults to ``100``.
360357
361358
Returns
362359
-------
@@ -365,23 +362,31 @@ def ranking(
365362
if palette is None:
366363
palette = PALETTE
367364

368-
# Filter excluded datasets
369-
if "dts" in df.columns:
370-
df = df[~df["dts"].isin(EXCLUDE_DATASETS)].copy()
371-
372365
# Compute overall mean and class-level aggregation (needed for method ordering)
373366
overall_mean, class_mean = _compute_aggregations(df)
374367
method_order = overall_mean.sort_values(ascending=False).index.tolist()
375368
overall_mean = overall_mean.loc[method_order]
376369
n_methods = len(method_order)
377370

371+
# Extract user-provided figsize before Plotter (sub-functions auto-calc when None)
372+
user_figsize = kwargs.get("figsize", None)
373+
kwargs.setdefault("figsize", (4, 3))
374+
kwargs.setdefault("dpi", 100)
375+
kwargs["ax"] = None
376+
bp = Plotter(**kwargs)
377+
bp.fig.delaxes(bp.ax)
378+
plt.close(bp.fig)
379+
378380
if level == "class":
379-
fig = _ranking_class(df, overall_mean, class_mean, method_order, n_methods, palette, figsize)
381+
fig = _ranking_class(df, overall_mean, class_mean, method_order, n_methods, palette, user_figsize)
380382
elif level == "task":
381-
fig = _ranking_task(df, method_order, n_methods, figsize)
383+
fig = _ranking_task(df, method_order, n_methods, user_figsize)
382384
else:
383385
raise ValueError(f"level must be 'class' or 'task', got '{level}'")
384386

385-
if return_fig:
386-
return fig
387-
plt.show()
387+
bp.fig = fig
388+
if user_figsize is not None:
389+
bp.fig.set_figwidth(bp.figsize[0])
390+
bp.fig.set_figheight(bp.figsize[1])
391+
bp.fig.set_dpi(bp.dpi)
392+
return bp._return()

tests/test_pl.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,15 @@ def test_all_parameters_combined(self, ocoeff_df):
257257
assert isinstance(result, plt.Figure)
258258
plt.close(result)
259259

260+
def test_save_writes_file(self, ocoeff_df, tmp_path):
261+
"""Test that save parameter writes a file."""
262+
out = tmp_path / "heatmap.png"
263+
result = heatmap(ocoeff_df, save=str(out))
264+
265+
assert result is None
266+
assert out.exists()
267+
plt.close("all")
268+
260269

261270
# ============================================================================
262271
# Fixtures for links tests
@@ -789,6 +798,23 @@ def test_plot_links_cre_overlapping_tss(self, mdata_pseudobulk, gannot_pyranges)
789798
assert isinstance(result, plt.Figure)
790799
plt.close(result)
791800

801+
def test_save_writes_file(self, mdata_pseudobulk, grn_single, gannot_pyranges, tmp_path):
802+
"""Test that save parameter writes a file."""
803+
out = tmp_path / "links.png"
804+
result = links(
805+
mdata_pseudobulk,
806+
grn_single,
807+
target="GENE_A",
808+
tfs=["TF1"],
809+
gannot=gannot_pyranges,
810+
w_size=500000,
811+
save=str(out),
812+
)
813+
814+
assert result is None
815+
assert out.exists()
816+
plt.close("all")
817+
792818
def test_plot_links_minus_strand_cre_positions(self, mdata_pseudobulk, gannot_pyranges):
793819
"""Test arc computation for minus-strand gene (lines 99-107)."""
794820
# GENE_B is - strand, TSS at End=1250000

tests/test_pl_ranking.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,25 @@ def sample_ranking_df():
3333
methods = ["MethodA", "MethodB", "MethodC"]
3434
rows = [
3535
# Literature - 2 different tasks (triggers line 83 in _build_task_list)
36-
{"name": m, "class": "Literature", "task": "TF Markers", "db": "HPA", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
36+
{"name": m, "class": "Literature", "task": "TF Markers", "db": "HPA", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
3737
for m in methods
3838
] + [
39-
{"name": m, "class": "Literature", "task": "TF Pairs", "db": "Europe PMC", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
39+
{"name": m, "class": "Literature", "task": "TF Pairs", "db": "Europe PMC", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
4040
for m in methods
4141
] + [
4242
# Genomic - 2 different tasks (triggers line 259 in _ranking_task span update)
43-
{"name": m, "class": "Genomic", "task": "TF Binding", "db": "ChIP-Atlas", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
43+
{"name": m, "class": "Genomic", "task": "TF Binding", "db": "ChIP-Atlas", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
4444
for m in methods
4545
] + [
46-
{"name": m, "class": "Genomic", "task": "CREs", "db": "ENCODE CREs", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
46+
{"name": m, "class": "Genomic", "task": "CREs", "db": "ENCODE CREs", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
4747
for m in methods
4848
] + [
4949
# Predictive
50-
{"name": m, "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
50+
{"name": m, "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
5151
for m in methods
5252
] + [
5353
# Mechanistic
54-
{"name": m, "class": "Mechanistic", "task": "TF Scoring", "db": "KnockTF", "dts": "DatasetX", "f01": np.random.uniform(0, 1)}
54+
{"name": m, "class": "Mechanistic", "task": "TF Scoring", "db": "KnockTF", "dataset": "DatasetX", "f01": np.random.uniform(0, 1)}
5555
for m in methods
5656
]
5757
return pd.DataFrame(rows)
@@ -63,10 +63,10 @@ def sample_ranking_df_with_nan():
6363
methods = ["MethodA", "MethodB"]
6464
rows = [
6565
# MethodA has both Predictive and Literature data
66-
{"name": "MethodA", "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dts": "DatasetX", "f01": 0.5},
67-
{"name": "MethodA", "class": "Literature", "task": "TF Markers", "db": "HPA", "dts": "DatasetX", "f01": 0.6},
66+
{"name": "MethodA", "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dataset": "DatasetX", "f01": 0.5},
67+
{"name": "MethodA", "class": "Literature", "task": "TF Markers", "db": "HPA", "dataset": "DatasetX", "f01": 0.6},
6868
# MethodB only has Predictive data → NaN for Literature class
69-
{"name": "MethodB", "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dts": "DatasetX", "f01": 0.4},
69+
{"name": "MethodB", "class": "Predictive", "task": "Gene Sets", "db": "Hallmarks", "dataset": "DatasetX", "f01": 0.4},
7070
]
7171
return pd.DataFrame(rows)
7272

@@ -381,3 +381,12 @@ def test_ranking_task_multiple_tasks_per_class(self, sample_ranking_df):
381381
fig = ranking(sample_ranking_df, level="task", return_fig=True)
382382
assert isinstance(fig, plt.Figure)
383383
plt.close(fig)
384+
385+
def test_save_writes_file(self, sample_ranking_df, tmp_path):
386+
"""Test that save parameter writes a file."""
387+
out = tmp_path / "ranking.png"
388+
result = ranking(sample_ranking_df, level="class", save=str(out))
389+
390+
assert result is None
391+
assert out.exists()
392+
plt.close("all")

0 commit comments

Comments
 (0)