Skip to content

Commit 461165f

Browse files
authored
Merge pull request #104 from CSOgroup/100-save-shape-plot
Allow saving shape_metrics plot
2 parents 5787208 + dd22310 commit 461165f

File tree

1 file changed

+57
-10
lines changed

1 file changed

+57
-10
lines changed

src/cellcharter/pl/_shape.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import numpy as np
1010
import pandas as pd
1111
import seaborn as sns
12-
import spatialdata as sd
13-
import spatialdata_plot # noqa: F401
1412
from anndata import AnnData
1513
from squidpy._docs import d
1614

@@ -99,10 +97,26 @@ def boundaries(
9997
kwargs
10098
Additional arguments to pass to the `spatialdata.pl.render_shapes()` function.
10199
100+
Notes
101+
-----
102+
To visualize boundaries with this function, install `spatialdata` and
103+
`spatialdata-plot`, or install the optional extra:
104+
105+
- `pip install "cellcharter[shape]"`
106+
102107
Returns
103108
-------
104109
%(plotting_returns)s
105110
"""
111+
# Optional dependency check
112+
try:
113+
import spatialdata as sd # type: ignore
114+
import spatialdata_plot # noqa: F401
115+
except ImportError as err:
116+
raise ImportError(
117+
"pl.boundaries requires 'spatialdata' and 'spatialdata-plot'. Install them with\n"
118+
" pip install spatialdata spatialdata-plot"
119+
) from err
106120
adata = adata[adata.obs[library_key] == sample].copy()
107121
del adata.raw
108122
clusters = adata.obs[component_key].unique()
@@ -234,8 +248,9 @@ def plot_shape_metrics(
234248
component_key: str = "component",
235249
metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
236250
figsize: tuple[float, float] = (8, 7),
237-
title: str | None = None,
238-
) -> None:
251+
save: str | Path | None = None,
252+
return_fig: bool = False,
253+
):
239254
"""
240255
Boxplots of the shape metrics between two conditions.
241256
@@ -270,7 +285,7 @@ def plot_shape_metrics(
270285
FutureWarning,
271286
stacklevel=2,
272287
)
273-
shape_metrics(
288+
return shape_metrics(
274289
adata=adata,
275290
condition_key=condition_key,
276291
condition_groups=condition_groups,
@@ -279,7 +294,8 @@ def plot_shape_metrics(
279294
component_key=component_key,
280295
metrics=metrics,
281296
figsize=figsize,
282-
title=title,
297+
save=save,
298+
return_fig=return_fig,
283299
)
284300

285301

@@ -364,7 +380,9 @@ def shape_metrics(
364380
fontsize: str | int = 14,
365381
figsize: tuple[float, float] = (10, 7),
366382
ncols: int = 2,
367-
) -> None:
383+
save: str | Path | None = None,
384+
return_fig: bool = False,
385+
):
368386
"""
369387
Boxplots of the shape metrics between two conditions.
370388
@@ -387,12 +405,27 @@ def shape_metrics(
387405
Figure size.
388406
ncols
389407
Number of columns in the subplot grid when plotting multiple metrics.
390-
title
391-
Title of the plot.
408+
save
409+
Path to save the plot. If provided, the plot will be saved using default parameters (``bbox_inches='tight'``).
410+
For more control over saving parameters, use ``return_fig=True`` and call ``savefig()`` manually.
411+
return_fig
412+
If ``True``, return the figure object for further customization. Default is ``False``.
392413
393414
Returns
394415
-------
395-
%(plotting_returns)s
416+
If ``return_fig=True``, returns the :class:`matplotlib.figure.Figure` object.
417+
Otherwise returns ``None``.
418+
419+
Examples
420+
--------
421+
Basic usage with automatic saving:
422+
423+
>>> cc.pl.shape_metrics(adata, condition_key='condition', save='plot.pdf')
424+
425+
Advanced usage with custom save parameters:
426+
427+
>>> fig = cc.pl.shape_metrics(adata, condition_key='condition', return_fig=True)
428+
>>> fig.savefig('plot.pdf', dpi=300, bbox_inches='tight', transparent=True)
396429
"""
397430
if isinstance(metrics, str):
398431
metrics = [metrics]
@@ -427,6 +460,9 @@ def shape_metrics(
427460

428461
metrics_df = adata.obs[[component_key] + keys].drop_duplicates().dropna().set_index(component_key)
429462

463+
# Initialize fig to None - will be set in one of the code paths below
464+
fig = None
465+
430466
for metric in metrics:
431467
metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric])
432468

@@ -539,3 +575,14 @@ def shape_metrics(
539575
# Hide any unused subplots
540576
for j in range(i + 1, len(axes)):
541577
axes[j].set_visible(False)
578+
579+
# If fig wasn't set in any code path (edge case), get the current figure
580+
if fig is None:
581+
fig = plt.gcf()
582+
583+
if save is not None:
584+
fig.savefig(save, bbox_inches="tight")
585+
586+
if return_fig:
587+
return fig
588+
return None

0 commit comments

Comments
 (0)