99import numpy as np
1010import pandas as pd
1111import seaborn as sns
12- import spatialdata as sd
13- import spatialdata_plot # noqa: F401
1412from anndata import AnnData
1513from squidpy ._docs import d
1614
@@ -99,10 +97,26 @@ def boundaries(
9997 kwargs
10098 Additional arguments to pass to the `spatialdata.pl.render_shapes()` function.
10199
100+ Notes
101+ -----
102+ To visualize boundaries with this function, install `spatialdata` and
103+ `spatialdata-plot`, or install the optional extra:
104+
105+ - `pip install "cellcharter[shape]"`
106+
102107 Returns
103108 -------
104109 %(plotting_returns)s
105110 """
111+ # Optional dependency check
112+ try :
113+ import spatialdata as sd # type: ignore
114+ import spatialdata_plot # noqa: F401
115+ except ImportError as err :
116+ raise ImportError (
117+ "pl.boundaries requires 'spatialdata' and 'spatialdata-plot'. Install them with\n "
118+ " pip install spatialdata spatialdata-plot"
119+ ) from err
106120 adata = adata [adata .obs [library_key ] == sample ].copy ()
107121 del adata .raw
108122 clusters = adata .obs [component_key ].unique ()
@@ -234,8 +248,9 @@ def plot_shape_metrics(
234248 component_key : str = "component" ,
235249 metrics : str | tuple [str ] | list [str ] = ("linearity" , "curl" ),
236250 figsize : tuple [float , float ] = (8 , 7 ),
237- title : str | None = None ,
238- ) -> None :
251+ save : str | Path | None = None ,
252+ return_fig : bool = False ,
253+ ):
239254 """
240255 Boxplots of the shape metrics between two conditions.
241256
@@ -270,7 +285,7 @@ def plot_shape_metrics(
270285 FutureWarning ,
271286 stacklevel = 2 ,
272287 )
273- shape_metrics (
288+ return shape_metrics (
274289 adata = adata ,
275290 condition_key = condition_key ,
276291 condition_groups = condition_groups ,
@@ -279,7 +294,8 @@ def plot_shape_metrics(
279294 component_key = component_key ,
280295 metrics = metrics ,
281296 figsize = figsize ,
282- title = title ,
297+ save = save ,
298+ return_fig = return_fig ,
283299 )
284300
285301
@@ -364,7 +380,9 @@ def shape_metrics(
364380 fontsize : str | int = 14 ,
365381 figsize : tuple [float , float ] = (10 , 7 ),
366382 ncols : int = 2 ,
367- ) -> None :
383+ save : str | Path | None = None ,
384+ return_fig : bool = False ,
385+ ):
368386 """
369387 Boxplots of the shape metrics between two conditions.
370388
@@ -387,12 +405,27 @@ def shape_metrics(
387405 Figure size.
388406 ncols
389407 Number of columns in the subplot grid when plotting multiple metrics.
390- title
391- Title of the plot.
408+ save
409+ Path to save the plot. If provided, the plot will be saved using default parameters (``bbox_inches='tight'``).
410+ For more control over saving parameters, use ``return_fig=True`` and call ``savefig()`` manually.
411+ return_fig
412+ If ``True``, return the figure object for further customization. Default is ``False``.
392413
393414 Returns
394415 -------
395- %(plotting_returns)s
416+ If ``return_fig=True``, returns the :class:`matplotlib.figure.Figure` object.
417+ Otherwise returns ``None``.
418+
419+ Examples
420+ --------
421+ Basic usage with automatic saving:
422+
423+ >>> cc.pl.shape_metrics(adata, condition_key='condition', save='plot.pdf')
424+
425+ Advanced usage with custom save parameters:
426+
427+ >>> fig = cc.pl.shape_metrics(adata, condition_key='condition', return_fig=True)
428+ >>> fig.savefig('plot.pdf', dpi=300, bbox_inches='tight', transparent=True)
396429 """
397430 if isinstance (metrics , str ):
398431 metrics = [metrics ]
@@ -427,6 +460,9 @@ def shape_metrics(
427460
428461 metrics_df = adata .obs [[component_key ] + keys ].drop_duplicates ().dropna ().set_index (component_key )
429462
463+ # Initialize fig to None - will be set in one of the code paths below
464+ fig = None
465+
430466 for metric in metrics :
431467 metrics_df [metric ] = metrics_df .index .map (adata .uns [f"shape_{ component_key } " ][metric ])
432468
@@ -539,3 +575,14 @@ def shape_metrics(
539575 # Hide any unused subplots
540576 for j in range (i + 1 , len (axes )):
541577 axes [j ].set_visible (False )
578+
579+ # If fig wasn't set in any code path (edge case), get the current figure
580+ if fig is None :
581+ fig = plt .gcf ()
582+
583+ if save is not None :
584+ fig .savefig (save , bbox_inches = "tight" )
585+
586+ if return_fig :
587+ return fig
588+ return None
0 commit comments