1212@docs .dedent
1313def filter_samples (
1414 adata : AnnData ,
15- groupby : str | list ,
15+ groupby : str | list | None = None ,
1616 log : bool = True ,
1717 min_cells : int | float = 10 ,
1818 min_counts : int | float = 1000 ,
@@ -51,10 +51,11 @@ def filter_samples(
5151 assert all (col in adata .obs .columns for col in ["psbulk_cells" , "psbulk_counts" ]), (
5252 "psbulk_* columns not present in adata.obs, this function should be used after running decoupler.pp.pseudobulk"
5353 )
54- assert isinstance (groupby , str | list ), "groupby must be str or list "
54+ assert isinstance (groupby , str | list ) or groupby is None , "groupby must be str, list or None "
5555 if isinstance (groupby , str ):
5656 groupby = [groupby ]
57- assert all (col in adata .obs for col in groupby ), "columns in groupby must be in adata.obs"
57+ if groupby :
58+ assert all (col in adata .obs for col in groupby ), "columns in groupby must be in adata.obs"
5859 # Extract obs
5960 df = adata .obs .copy ()
6061 # Transform to log10
@@ -65,7 +66,7 @@ def filter_samples(
6566 label_x , label_y = r"$\log_{10}$ " + label_x , r"$\log_{10}$ " + label_y
6667 min_cells , min_counts = np .log10 (min_cells ), np .log10 (min_counts )
6768 # Plot
68- if len (groupby ) > 1 :
69+ if groupby is not None and ( len (groupby ) > 1 ) :
6970 # Instance
7071 assert kwargs .get ("ax" ) is None , "when groupby is list, ax must be None"
7172 kwargs ["ax" ] = None
@@ -85,12 +86,14 @@ def filter_samples(
8586 ax .axhline (y = min_counts , linestyle = "--" , color = "black" )
8687 else :
8788 # Instance
88- groupby = groupby [0 ]
89+ if isinstance (groupby , list ):
90+ groupby = groupby [0 ]
8991 bp = Plotter (** kwargs )
9092 bp .ax .grid (zorder = 0 )
9193 bp .ax .set_axisbelow (True )
9294 sns .scatterplot (x = "psbulk_cells" , y = "psbulk_counts" , hue = groupby , ax = bp .ax , data = df , zorder = 1 )
93- bp .ax .legend (loc = "center left" , bbox_to_anchor = (1 , 0.5 ), frameon = False , title = groupby )
95+ if groupby :
96+ bp .ax .legend (loc = "center left" , bbox_to_anchor = (1 , 0.5 ), frameon = False , title = groupby )
9497 bp .ax .set_xlabel (label_x )
9598 bp .ax .set_ylabel (label_y )
9699 bp .ax .axvline (x = min_cells , linestyle = "--" , color = "black" )
0 commit comments