22from typing import TYPE_CHECKING
33
44if TYPE_CHECKING :
5- from typing import Optional
5+ from typing import Optional , Literal
66
77import warnings
88
@@ -144,9 +144,11 @@ def calculate_enrichment_vs_all(
144144 adata : AnnData ,
145145 covariates : Optional [list [str ]] = None ,
146146 subcellular_enrichment_column : str = "subcellular_enrichment" ,
147+ enrichment_method : Literal ["lfc" , "proportion" ] = "lfc" ,
147148 correlation_threshold : float = 1.0 ,
148149 original_intensities_key : str | None = "original_intensities" ,
149150 keep_raw : bool = True ,
151+ min_comparison_warning : int | None = None ,
150152) -> AnnData :
151153 """Calculate enrichment of each subcellular enrichment vs all other samples as the background.
152154
@@ -159,10 +161,16 @@ def calculate_enrichment_vs_all(
159161 If None, uses columns starting with "covariate_"
160162 subcellular_enrichment_column
161163 Column in adata.var containing subcellular enrichment labels
164+ enrichment_method
165+ Calculating enrichment based on Log Fold Change (lfc) or Proportion-based analysis.
166+ Must be either "proportion" or "lfc"
162167 original_intensities_key
163168 If provided, store the original intensities in this layer
164169 keep_raw
165170 Whether to keep the unaggregated data in the .raw attribute of the returned AnnData object
171+ min_comparison_warning
172+ The minimum number of control samples required before issuing a warning about low statistical power.
173+
166174
167175 Returns
168176 -------
@@ -171,27 +179,31 @@ def calculate_enrichment_vs_all(
171179 Raw values are stored in .layers[original_intensities_key] if provided.
172180 """
173181
182+ if enrichment_method not in ["lfc" , "proportion" ]:
183+ raise ValueError ("enrichment_method must be either 'lfc' or 'proportion'" )
184+
174185 data = adata .copy ()
175186
176187 if covariates is None :
177188 covariates = data .var .columns [data .var .columns .str .startswith ("covariate_" )].tolist ()
178- # Check that all covariates are in the data
189+ if not isinstance (covariates , list ):
190+ covariates = [covariates ]
191+
179192 for c in covariates :
180193 if c not in data .var .columns :
181194 raise ValueError (f"Covariate { c } not found in data.var.columns" )
182195
183- if not isinstance (covariates , list ):
184- covariates = [covariates ]
185- # Create aggregated data with the desired output shape
186196 grouping_columns = [subcellular_enrichment_column ] + covariates
187- # Create a temporary column that contains the experimental conditions
197+
188198 data .var ["_experimental_condition" ] = data .var [grouping_columns ].apply (
189199 lambda x : "_" .join (x .dropna ().astype (str )),
190200 axis = 1 ,
191201 )
202+ data .var ["_covariates" ] = data .var [covariates ].apply (
203+ lambda x : "_" .join (x .dropna ().astype (str )), axis = 1
204+ )
192205
193206 data_aggr = aggregate_samples (data , grouping_columns = grouping_columns , keep_raw = False )
194- data_aggr .var_names = data_aggr .var_names .str .replace (r"_\d+" , "" , regex = True )
195207
196208 if original_intensities_key is not None :
197209 data_aggr .layers [original_intensities_key ] = data_aggr .X
@@ -203,22 +215,40 @@ def calculate_enrichment_vs_all(
203215
204216 for experimental_condition in data_aggr .var ["_experimental_condition" ].unique ():
205217 mask = data_aggr .var ["_experimental_condition" ] == experimental_condition
206- corr_mat_sub = corr_matrix [mask , :].mean (axis = 0 )
207- control_mask = ~ mask & (corr_mat_sub < correlation_threshold )
208- if control_mask .sum () < 10 :
209- warnings .warn (
210- f"Less than 10 ({ control_mask .sum ()} ) control samples found for condition: { experimental_condition } "
211- )
212- intensities_control = intensities [:, control_mask ]
218+
213219 intensities_ip = intensities [:, mask ]
220+ covariate = data .var .loc [
221+ data .var ._experimental_condition == experimental_condition , "_covariates"
222+ ].values [0 ]
223+ covariate_mask = data_aggr .var ["_covariates" ] == covariate
224+ control_mask = ~ mask & covariate_mask
225+ corr_mat_sub = corr_matrix [mask , control_mask ].mean (axis = 0 )
226+ control_mask = control_mask & (corr_mat_sub < correlation_threshold )
227+ intensities_control = intensities [:, control_mask ]
228+ if min_comparison_warning is not None :
229+ if control_mask .sum () < min_comparison_warning :
230+ warnings .warn (
231+ f"Less than { min_comparison_warning } ({ control_mask .sum ()} ) control samples found for condition: { experimental_condition } "
232+ ) # Check for statistical power (if fewer than 10 samples selected )
233+
214234 scores , pv = stats .ttest_ind (intensities_ip .T , intensities_control .T )
215- lfc = np .median (intensities_ip , axis = 1 ) - np .median (intensities_control , axis = 1 )
235+
236+ if enrichment_method == "lfc" :
237+ enrichment_values = np .median (intensities_ip , axis = 1 ) - np .median (
238+ intensities_control , axis = 1
239+ )
240+ else :
241+ enrichment_values = np .nansum (intensities_ip , axis = 1 ) / (
242+ np .nansum (intensities_ip , axis = 1 ) + np .nansum (intensities_control , axis = 1 )
243+ )
244+
216245 aggr_mask = data_aggr .var ["_experimental_condition" ] == experimental_condition
217246 data_aggr .layers ["pvals" ][:, aggr_mask ] = pv [:, None ]
218- data_aggr [:, aggr_mask ]. X = lfc [:, None ]
247+ data_aggr . X [:, aggr_mask ] = enrichment_values [:, None ]
219248 data_aggr .var .loc [aggr_mask , "enriched_vs" ] = "," .join (
220249 data_aggr .var_names [control_mask ]
221250 )
251+
222252 data_aggr .var .drop (columns = ["_experimental_condition" ], inplace = True )
223253 if keep_raw :
224254 data_aggr .raw = data .copy ()
0 commit comments