Skip to content

Commit b6bf638

Browse files
authored
Merge pull request #238 from Intron7/update-PV
Make make adjust PV faster
2 parents 1ab7c0d + a8e7fa6 commit b6bf638

3 files changed

Lines changed: 89 additions & 2 deletions

File tree

src/decoupler/mt/_pv.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import numba as nb
2+
import numpy as np
3+
4+
5+
@nb.njit(cache=True)
6+
def _fdr_bh_single_row(ps_row, m):
7+
"""Apply Benjamini-Hochberg correction to a single row."""
8+
# Sort the row and get indices
9+
order = np.argsort(ps_row)
10+
ps_sorted = ps_row[order]
11+
12+
# BH scale: p_(i) * m / i
13+
ps_bh = np.empty_like(ps_sorted, dtype=np.float64)
14+
for i in range(m):
15+
ps_bh[i] = ps_sorted[i] * (m / (i + 1))
16+
17+
# Reverse cumulative min
18+
ps_rev = np.empty_like(ps_bh, dtype=np.float64)
19+
for i in range(m):
20+
ps_rev[i] = ps_bh[m - 1 - i]
21+
22+
for j in range(1, m):
23+
ps_rev[j] = min(ps_rev[j], ps_rev[j - 1])
24+
25+
# Reverse back
26+
ps_monotone = np.empty_like(ps_rev, dtype=np.float64)
27+
for i in range(m):
28+
ps_monotone[i] = ps_rev[m - 1 - i]
29+
30+
# Unsort back to original order
31+
ps_adj = np.empty_like(ps_monotone, dtype=np.float64)
32+
for i in range(m):
33+
ps_adj[order[i]] = ps_monotone[i]
34+
35+
# Clip to [0, 1]
36+
for i in range(m):
37+
ps_adj[i] = max(0.0, min(1.0, ps_adj[i]))
38+
39+
return ps_adj
40+
41+
42+
@nb.njit(parallel=True, cache=True)
43+
def _fdr_bh_parallel(ps, m):
44+
"""Apply Benjamini-Hochberg correction to all rows in parallel."""
45+
n_rows = ps.shape[0]
46+
result = np.empty_like(ps, dtype=np.float64)
47+
48+
for i in nb.prange(n_rows):
49+
result[i] = _fdr_bh_single_row(ps[i], m)
50+
51+
return result
52+
53+
54+
def _fdr_bh_axis1_numba(ps):
55+
"""Benjamini–Hochberg adjusted p-values along axis=1 (rows)."""
56+
ps = np.asarray(ps, dtype=np.float64)
57+
if ps.ndim != 2:
58+
raise ValueError("ps must be 2D (n_rows, n_tests) for axis=1.")
59+
if not np.issubdtype(ps.dtype, np.number):
60+
raise ValueError("`ps` must be numeric.")
61+
if not np.all((ps >= 0) & (ps <= 1)):
62+
raise ValueError("`ps` must be within [0, 1].")
63+
64+
n_rows, m = ps.shape
65+
if m <= 1:
66+
return ps.copy().astype(np.float32)
67+
68+
# Process each row in parallel
69+
result = _fdr_bh_parallel(ps, m)
70+
return result.astype(np.float32)

src/decoupler/mt/_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import numpy as np
44
import pandas as pd
55
import scipy.sparse as sps
6-
import scipy.stats as sts
76
from anndata import AnnData
87
from tqdm.auto import tqdm
98

109
from decoupler._datatype import DataType
1110
from decoupler._log import _log
11+
from decoupler.mt._pv import _fdr_bh_axis1_numba
1212
from decoupler.pp.data import extract
1313
from decoupler.pp.net import adjmat, idxmat, prune
1414

@@ -115,7 +115,7 @@ def _run(
115115
pv = pd.DataFrame(pv, index=obs, columns=sources)
116116
if name != "mlm":
117117
_log(f"{name} - adjusting p-values by FDR", level="info", verbose=verbose)
118-
pv.loc[:, :] = sts.false_discovery_control(pv.values, axis=1, method="bh")
118+
pv.loc[:, :] = _fdr_bh_axis1_numba(pv.values)
119119
else:
120120
pv = None
121121
_log(f"{name} - done", level="info", verbose=verbose)

tests/mt/test_adj.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import scipy.stats as sts
5+
6+
import decoupler as dc
7+
from decoupler.mt._pv import _fdr_bh_axis1_numba
8+
9+
10+
def test_func_mlm(
11+
adata,
12+
net,
13+
):
14+
dc.mt.mlm(data=adata, net=net, tmin=3)
15+
dc_pv = adata.obsm["padj_mlm"]
16+
adj = _fdr_bh_axis1_numba(dc_pv.values)
17+
np.testing.assert_allclose(adj, sts.false_discovery_control(dc_pv.values, axis=1, method="bh"))

0 commit comments

Comments
 (0)