Skip to content

Commit 2092219

Browse files
committed
Refactored return data for all methods
1 parent 6fc3916 commit 2092219

15 files changed

+85
-104
lines changed

decoupler/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__version__ = '1.6.3' # noqa: F401
22
__version_info__ = tuple([int(num) for num in __version__.split('.')]) # noqa: F401
33

4-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, mask_features # noqa: F401
4+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, mask_features, return_data # noqa: F401
55
from .utils import (
66
melt, show_methods, check_corr, get_toy_data, summarize_acts, assign_groups, dense_run, p_adjust_fdr, shuffle_net,
77
read_gmt # noqa: F401

decoupler/method_aucell.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numpy.random import default_rng
1111
from tqdm import tqdm
1212

13-
from .pre import extract, rename_net, filt_min_n
13+
from .pre import extract, rename_net, filt_min_n, return_data
1414

1515
from anndata import AnnData
1616
import numba as nb
@@ -151,9 +151,4 @@ def run_aucell(mat, net, source='source', target='target', n_up=None, min_n=5, s
151151
estimate = pd.DataFrame(estimate, index=r, columns=net.index)
152152
estimate.name = 'aucell_estimate'
153153

154-
# AnnData support
155-
if isinstance(mat, AnnData):
156-
# Update obsm AnnData object
157-
mat.obsm[estimate.name] = estimate
158-
else:
159-
return estimate
154+
return return_data(mat=mat, results=(estimate, ))

decoupler/method_gsea.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numpy.random import default_rng
1010
from scipy.sparse import csr_matrix
1111

12-
from .pre import extract, rename_net, filt_min_n
12+
from .pre import extract, rename_net, filt_min_n, return_data
1313
from .utils import p_adjust_fdr
1414

1515
from anndata import AnnData
@@ -369,15 +369,4 @@ def run_gsea(mat, net, source='source', target='target', times=1000, batch_size=
369369
pvals = pd.DataFrame(pvals, index=r, columns=net.index)
370370
pvals.name = 'gsea_pvals'
371371

372-
# AnnData support
373-
if isinstance(mat, AnnData):
374-
# Update obsm AnnData object
375-
mat.obsm[estimate.name] = estimate
376-
if norm_e is not None:
377-
mat.obsm[norm_e.name] = norm_e
378-
mat.obsm[pvals.name] = pvals
379-
else:
380-
if pvals is not None:
381-
return estimate, norm_e, pvals
382-
else:
383-
return estimate
372+
return return_data(mat=mat, results=(estimate, norm_e, pvals))

decoupler/method_gsva.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from scipy.sparse import csr_matrix
1111
from numpy.random import default_rng
1212

13-
from .pre import extract, rename_net, filt_min_n
13+
from .pre import extract, rename_net, filt_min_n, return_data
1414
from .method_gsea import std
1515

1616
from anndata import AnnData
@@ -232,9 +232,4 @@ def run_gsva(mat, net, source='source', target='target', kcdf=False, mx_diff=Tru
232232
estimate = pd.DataFrame(estimate, index=r, columns=net.index)
233233
estimate.name = 'gsva_estimate'
234234

235-
# AnnData support
236-
if isinstance(mat, AnnData):
237-
# Update obsm AnnData object
238-
mat.obsm[estimate.name] = estimate
239-
else:
240-
return estimate
235+
return return_data(mat=mat, results=(estimate, ))

decoupler/method_mdt.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from scipy.sparse import csr_matrix
99

10-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
10+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1111

1212
from anndata import AnnData
1313
from tqdm import tqdm
@@ -117,9 +117,4 @@ def run_mdt(mat, net, source='source', target='target', weight='weight', trees=1
117117
estimate = pd.DataFrame(estimate, index=r, columns=sources)
118118
estimate.name = 'mdt_estimate'
119119

120-
# AnnData support
121-
if isinstance(mat, AnnData):
122-
# Update obsm AnnData object
123-
mat.obsm[estimate.name] = estimate
124-
else:
125-
return estimate
120+
return return_data(mat=mat, results=(estimate, ))

decoupler/method_mlm.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from scipy.sparse import csr_matrix
99

10-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
10+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1111

1212
from anndata import AnnData
1313
from scipy import stats
@@ -131,10 +131,4 @@ def run_mlm(mat, net, source='source', target='target', weight='weight', batch_s
131131
pvals = pd.DataFrame(pvals, index=r, columns=sources)
132132
pvals.name = 'mlm_pvals'
133133

134-
# AnnData support
135-
if isinstance(mat, AnnData):
136-
# Update obsm AnnData object
137-
mat.obsm[estimate.name] = estimate
138-
mat.obsm[pvals.name] = pvals
139-
else:
140-
return estimate, pvals
134+
return return_data(mat=mat, results=(estimate, pvals))

decoupler/method_ora.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from scipy.stats import rankdata
1313
from math import log, exp, lgamma
1414

15-
from .pre import extract, rename_net, filt_min_n
15+
from .pre import extract, rename_net, filt_min_n, return_data
1616
from .utils import p_adjust_fdr
1717

1818
from anndata import AnnData
@@ -315,10 +315,4 @@ def run_ora(mat, net, source='source', target='target', n_up=None, n_bottom=0, n
315315
estimate = pd.DataFrame(-np.log10(pvals), index=r, columns=pvals.columns)
316316
estimate.name = 'ora_estimate'
317317

318-
# AnnData support
319-
if isinstance(mat, AnnData):
320-
# Update obsm AnnData object
321-
mat.obsm[estimate.name] = estimate
322-
mat.obsm[pvals.name] = pvals
323-
else:
324-
return estimate, pvals
318+
return return_data(mat=mat, results=(estimate, pvals))

decoupler/method_udt.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scipy.sparse import csr_matrix
88
import pandas as pd
99

10-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
10+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1111

1212
from anndata import AnnData
1313
from tqdm import tqdm
@@ -114,9 +114,4 @@ def run_udt(mat, net, source='source', target='target', weight='weight', min_lea
114114
estimate = pd.DataFrame(estimate, index=r, columns=sources)
115115
estimate.name = 'udt_estimate'
116116

117-
# AnnData support
118-
if isinstance(mat, AnnData):
119-
# Update obsm AnnData object
120-
mat.obsm[estimate.name] = estimate
121-
else:
122-
return estimate
117+
return return_data(mat=mat, results=(estimate, ))

decoupler/method_ulm.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from scipy.stats import t
1111

12-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
12+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1313

1414
from anndata import AnnData
1515
from tqdm import tqdm
@@ -124,10 +124,4 @@ def run_ulm(mat, net, source='source', target='target', weight='weight', batch_s
124124
pvals = pd.DataFrame(pvals, index=r, columns=sources)
125125
pvals.name = 'ulm_pvals'
126126

127-
# AnnData support
128-
if isinstance(mat, AnnData):
129-
# Update obsm AnnData object
130-
mat.obsm[estimate.name] = estimate
131-
mat.obsm[pvals.name] = pvals
132-
else:
133-
return estimate, pvals
127+
return return_data(mat=mat, results=(estimate, pvals))

decoupler/method_viper.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from scipy.stats import rankdata
1111
from scipy.stats import norm
1212

13-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
13+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1414

1515
from anndata import AnnData
1616
from tqdm import tqdm
@@ -308,10 +308,4 @@ def run_viper(mat, net, source='source', target='target', weight='weight', pleio
308308
pvals = pd.DataFrame(pvals, index=r, columns=sources)
309309
pvals.name = 'viper_pvals'
310310

311-
# AnnData support
312-
if isinstance(mat, AnnData):
313-
# Update obsm AnnData object
314-
mat.obsm[estimate.name] = estimate
315-
mat.obsm[pvals.name] = pvals
316-
else:
317-
return estimate, pvals
311+
return return_data(mat=mat, results=(estimate, pvals))

decoupler/method_wmean.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from scipy.sparse import csr_matrix
99

10-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
10+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1111
from .method_gsea import std
1212

1313
from anndata import AnnData
@@ -177,16 +177,4 @@ def run_wmean(mat, net, source='source', target='target', weight='weight', times
177177
pvals = pd.DataFrame(pvals, index=r, columns=sources)
178178
pvals.name = 'wmean_pvals'
179179

180-
# AnnData support
181-
if isinstance(mat, AnnData):
182-
# Update obsm AnnData object
183-
mat.obsm[estimate.name] = estimate
184-
if pvals is not None:
185-
mat.obsm[norm.name] = norm
186-
mat.obsm[corr.name] = corr
187-
mat.obsm[pvals.name] = pvals
188-
else:
189-
if pvals is not None:
190-
return estimate, norm, corr, pvals
191-
else:
192-
return estimate
180+
return return_data(mat=mat, results=(estimate, norm, corr, pvals))

decoupler/method_wsum.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from scipy.sparse import csr_matrix
99

10-
from .pre import extract, match, rename_net, get_net_mat, filt_min_n
10+
from .pre import extract, match, rename_net, get_net_mat, filt_min_n, return_data
1111
from .method_gsea import std
1212

1313
from anndata import AnnData
@@ -173,16 +173,4 @@ def run_wsum(mat, net, source='source', target='target', weight='weight', times=
173173
pvals = pd.DataFrame(pvals, index=r, columns=sources)
174174
pvals.name = 'wsum_pvals'
175175

176-
# AnnData support
177-
if isinstance(mat, AnnData):
178-
# Update obsm AnnData object
179-
mat.obsm[estimate.name] = estimate
180-
if pvals is not None:
181-
mat.obsm[norm.name] = norm
182-
mat.obsm[corr.name] = corr
183-
mat.obsm[pvals.name] = pvals
184-
else:
185-
if pvals is not None:
186-
return estimate, norm, corr, pvals
187-
else:
188-
return estimate
176+
return return_data(mat=mat, results=(estimate, norm, corr, pvals))

decoupler/pre.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from scipy.sparse import csr_matrix, issparse
88
import pandas as pd
9-
9+
import logging
1010
from anndata import AnnData
1111

1212

@@ -278,3 +278,23 @@ def mask_features(mat, log=False, thr=1, use_raw=False):
278278
else:
279279
raise ValueError("""mat must be a list of [matrix, samples, features], dataframe (samples x features) or an AnnData
280280
instance.""")
281+
282+
283+
def add_to_anndata(mat, results):
    """
    Store every available result in ``mat.obsm``, keyed by the result's ``name`` attribute.

    Parameters
    ----------
    mat : object exposing an ``obsm`` mapping
        Updated in place.
    results : iterable
        Result objects carrying a ``name`` attribute. ``None`` entries are skipped.

    Returns
    -------
    None
    """

    # Optional outputs may be passed as None; only real results are stored.
    available = (res for res in results if res is not None)
    for res in available:
        mat.obsm[res.name] = res
287+
288+
289+
def return_data(mat, results):
    """
    Hand the results of a method back to the caller in the form matching the input type.

    Parameters
    ----------
    mat : AnnData or any other input container
        The input originally given to the method.
    results : tuple
        Result objects in the order the method produces them. Entries may be ``None`` when an optional output was
        not computed. The first entry is expected to expose an ``index`` whenever ``mat`` is an ``AnnData``.

    Returns
    -------
    AnnData, None, single result or tuple
        - ``AnnData`` input whose number of observations differs from the number of rows in the first result: a
          subsetted copy of ``mat`` restricted to the result's index, with the results stored in ``obsm``.
        - ``AnnData`` input with a matching number of observations: ``None``; the results are stored in
          ``mat.obsm`` in place.
        - Any other input: the non-``None`` results. A lone result is returned as is (not wrapped in a 1-tuple),
          so single-output methods keep returning the bare object; several results are returned as a tuple.
    """

    if isinstance(mat, AnnData):
        # A size mismatch means the results cover fewer observations than the input (reported to the user as
        # "empty observations"), so they cannot be stored in the original object's obsm.
        if mat.obs_names.size != results[0].index.size:
            logging.warning('Provided AnnData contains empty observations. Returning repaired object.')
            mat = mat[results[0].index, :].copy()
            add_to_anndata(mat, results)
            return mat

        add_to_anndata(mat, results)
        return None

    available = tuple(result for result in results if result is not None)

    # Keep the historical contract of single-output calls: return the object itself rather than a 1-tuple.
    if len(available) == 1:
        return available[0]
    return available

decoupler/tests/test_pre.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import numpy as np
44
from scipy.sparse import csr_matrix
55
from anndata import AnnData
6-
from ..pre import check_mat, extract, filt_min_n, match, rename_net, get_net_mat, mask_features
6+
from ..pre import (
7+
check_mat, extract, filt_min_n, match, rename_net, get_net_mat, mask_features,
8+
return_data, add_to_anndata
9+
)
710

811

912
def test_check_mat():
@@ -101,3 +104,38 @@ def test_mask_features():
101104
mask_features('asdfg')
102105
with pytest.raises(ValueError):
103106
mask_features(adata, use_raw=True)
107+
108+
109+
def test_add_to_anndata():
    """Check that a non-None result ends up in ``obsm`` under its ``name``."""
    samples = np.array(['S1', 'S2'])
    genes = np.array(['G1', 'G2', 'G3'])
    counts = pd.DataFrame(np.array([[1, 0, 2], [1, 0, 3]]), index=samples, columns=genes)
    adata = AnnData(counts.astype(np.float32))

    estimate = pd.DataFrame(np.array([[1], [4]]), index=samples, columns=np.array(['S1']))
    estimate.name = 'estimate'

    # The None entry stands in for an optional output that was not computed.
    add_to_anndata(mat=adata, results=(estimate, None))
    assert 'estimate' in adata.obsm
121+
122+
123+
def test_return_data():
    """Exercise the three return paths of ``return_data``."""
    samples = np.array(['S1', 'S2', 'S3'])
    genes = np.array(['G1', 'G2', 'G3'])
    # S3 is an all-zero observation that the results below do not cover.
    counts = pd.DataFrame(np.array([[1, 0, 2], [1, 0, 3], [0, 0, 0]]), index=samples, columns=genes)
    adata = AnnData(counts.astype(np.float32))

    estimate = pd.DataFrame(np.array([[1], [4]]), index=samples[:-1], columns=np.array(['S1']))
    estimate.name = 'estimate'
    pvals = pd.DataFrame(np.array([[0.4], [0.01]]), index=estimate.index, columns=estimate.columns)
    pvals.name = 'pvals'
    results = (estimate, pvals)

    # Fewer result rows than observations: an AnnData object is returned.
    ret = return_data(mat=adata, results=results)
    assert isinstance(ret, AnnData)

    # Observations already match the results: nothing is returned.
    ret = return_data(mat=adata[estimate.index, :].copy(), results=results)
    assert ret is None

    # Non-AnnData input with two results: they come back as a tuple.
    ret = return_data(mat=counts, results=results)
    assert isinstance(ret, tuple)

docs/source/release_notes.rst

+2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ Bug fixes
99
- Fixed error in ``get_contrast`` by reverting use of ``copy.deepcopy`` to ``copy``.
1010
- Fixed verbose error regarding the number of unique sources being used in ``benchmark``.
1111
- Added check for minimum version of ``igraph>=0.10.0`` to properly render ``plot_network``.
12+
- Fixed return error of methods triggered when an observation was empty and input was ``AnnData``.
1213

1314
Changes
1415
~~~~~~~
1516
- Resource functions such as ``get_resource`` or ``get_collectri`` now accept different ``genesymbol_resource`` than UniProt for gene translation to other organisms.
17+
- Deprecated ``sklearn`` and switched to ``sklearn`` for ``udt``.
1618

1719
Additions
1820
~~~~~~~~~

0 commit comments

Comments
 (0)