Skip to content

Commit e679d32

Browse files
committed
Added order process functions
1 parent 2d97793 commit e679d32

1 file changed

Lines changed: 146 additions & 6 deletions

File tree

decoupler/utils_anndata.py

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import numpy as np
7+
import scipy.stats as sts
78
from scipy.sparse import csr_matrix, issparse
89
import pandas as pd
910
import sys
@@ -12,7 +13,7 @@
1213
from tqdm import tqdm
1314

1415
from .utils import melt, p_adjust_fdr
15-
from .pre import rename_net
16+
from .pre import rename_net, extract
1617

1718

1819
def get_acts(adata, obsm_key, dtype=np.float32):
@@ -835,8 +836,6 @@ def rank_sources_groups(adata, groupby, reference='rest', method='t-test_overest
835836
results: DataFrame with changes in source activity score between groups.
836837
"""
837838

838-
from scipy.stats import ranksums, ttest_ind_from_stats
839-
840839
# Get tf names
841840
features = adata.var.index.values
842841

@@ -874,9 +873,9 @@ def rank_sources_groups(adata, groupby, reference='rest', method='t-test_overest
874873
assert np.all(np.isfinite(v_group)) and np.all(np.isfinite(v_rest)), \
875874
"adata contains not finite values, please remove them."
876875
if method == 'wilcoxon':
877-
stat, pval = ranksums(v_group, v_rest)
876+
stat, pval = sts.ranksums(v_group, v_rest)
878877
elif method == 't-test':
879-
stat, pval = ttest_ind_from_stats(
878+
stat, pval = sts.ttest_ind_from_stats(
880879
mean1=np.mean(v_group),
881880
std1=np.std(v_group, ddof=1),
882881
nobs1=v_group.size,
@@ -886,7 +885,7 @@ def rank_sources_groups(adata, groupby, reference='rest', method='t-test_overest
886885
equal_var=False, # Welch's
887886
)
888887
elif method == 't-test_overestim_var':
889-
stat, pval = ttest_ind_from_stats(
888+
stat, pval = sts.ttest_ind_from_stats(
890889
mean1=np.mean(v_group),
891890
std1=np.std(v_group, ddof=1),
892891
nobs1=v_group.size,
@@ -1118,3 +1117,144 @@ def get_metadata_associations(data, obs_keys=None, obsm_key=None, use_X=False, l
11181117
stats_df = stats_df.reset_index()
11191118

11201119
return format_assoc_results(data, stats_df, inplace, obsm_key, uns_key, use_X, layer)
1120+
1121+
1122+
def rank_sources_ordered(adata, order, thr_padj=0.05, use_raw=False, seed=42):
    """
    Rank sources along a continuous, ordered process such as pseudotime.

    Parameters
    ----------
    adata : AnnData
        AnnData obtained after running ``decoupler.utils_anndata.get_acts``.
    order : str
        The name of the column in ``.obs`` to consider for ordering.
    thr_padj : float
        Threshold used to assign significance after FDR correction.
    use_raw : bool
        Use raw attribute of mat if present.
    seed : int
        Random seed to use.

    Returns
    -------
    DataFrame with sources associated with the ordering variable, sorted by decreasing importance.
    For each source the following statistics are reported:
    - ``name``: name of the source
    - ``impr``: feature importance of the fitted ``XGBRegressor``
    - ``corr``: Pearson correlation coefficient with the ordering variable
    - ``pval``: p-value of the correlation
    - ``padj``: FDR-adjusted p-value, ``NaN`` for sources whose correlation could not be tested
    - ``sign``: the sign of the association, 0 if the correlation is non-significant, and +1 or -1
      depending on the correlation sign.

    Raises
    ------
    ImportError
        If ``xgboost`` cannot be imported.
    """

    # Any failure while importing is reported as ImportError (as before); the
    # original error is chained so the real cause stays visible.
    try:
        from xgboost import XGBRegressor
    except Exception as err:
        raise ImportError('xgboost is not installed. Please install it with: pip install xgboost') from err

    # Get vars and ordinal variable
    X, _, names = extract(adata, use_raw=use_raw)
    if issparse(X):
        X = X.toarray()
    y = adata.obs[order].values

    # Fit
    reg = XGBRegressor(random_state=seed).fit(X, y)

    # Correlate every source (column of X) with the ordering variable
    corr, pval = sts.pearsonr(X, y.reshape(-1, 1), axis=0)
    corr = np.asarray(corr, dtype=float)
    pval = np.asarray(pval, dtype=float)

    # Sources with constant activity have an undefined correlation (NaN p-value).
    # false_discovery_control rejects NaN input, so only adjust the testable
    # sources and leave the rest as NaN instead of failing the whole ranking.
    padj = np.full(pval.shape, np.nan)
    tested = np.isfinite(pval)
    if tested.any():
        padj[tested] = sts.false_discovery_control(pval[tested])

    # Find direction of change (NaN comparisons evaluate to False -> sign 0)
    with np.errstate(invalid='ignore'):
        significant = padj < thr_padj
        sign = np.where(significant, np.where(corr > 0, 1, -1), 0)

    df = pd.DataFrame({
        'name': names,
        'impr': reg.feature_importances_,
        'corr': corr,
        'pval': pval,
        'padj': padj,
        'sign': sign,
    })

    return df.sort_values('impr', ascending=False).reset_index(drop=True)
1181+
1182+
1183+
def bin_sources_ordered(adata, order, names, label=None, nbins=100, use_raw=False):
    """
    Bins given sources along a continuous, ordered process such as pseudotime.
    Used before ``decoupler.plot_sources_ordered``.

    Parameters
    ----------
    adata : AnnData
        AnnData obtained after running ``decoupler.utils_anndata.get_acts``.
    order : str
        The name of the column in ``.obs`` to consider for ordering.
    names : str, list
        Names of the sources to bin.
    label : str, None
        The name of the column in ``.obs`` to consider for coloring the grouping. By default ``None``.
        When given, the column is converted to categorical in ``adata.obs`` and, if
        ``adata.uns['{label}_colors']`` is missing, a ``tab10`` palette is stored there.
    nbins : int
        Number of bins to use.
    use_raw : bool
        Use raw attribute of mat if present.

    Returns
    -------
    DataFrame with sources binned along a continuous ordered process. It has one row per
    observation and source, with columns ``name``, ``order`` (midpoint of the assigned bin,
    rescaled to the 0-1 range) and ``value``, plus ``label`` and ``color`` when ``label`` is given.

    Raises
    ------
    ValueError
        If the ordering variable contains non-finite values or is constant.
    """

    # Check inputs (type check first so a non-integer nbins fails the assertion
    # instead of raising on the comparison)
    if isinstance(names, str):
        names = [names]
    assert isinstance(nbins, int) and nbins > 1, 'nbins should be higher than 1 and be an integer'

    # Get vars and ordinal variable
    X, _, cnames = extract(adata, use_raw=use_raw)
    if issparse(X):
        X = X.toarray()
    assert np.isin(names, cnames).all(), 'names must be inside adata.var_names'
    y = np.asarray(adata.obs[order].values, dtype=float)

    # Normalize to 0 and 1 using the true range of the variable, so that
    # negative ordering values are also mapped inside the bins
    if not np.all(np.isfinite(y)):
        raise ValueError(f'adata.obs["{order}"] contains non-finite values, please remove them.')
    ymin, ymax = y.min(), y.max()
    if ymax == ymin:
        raise ValueError(f'adata.obs["{order}"] is constant, it cannot be binned.')
    y = (y - ymin) / (ymax - ymin)

    # Make windows; the assignment only depends on the ordering, so it is
    # computed once and shared by every source
    bin_edges = np.linspace(0, 1, nbins + 1)
    bin_midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
    window = pd.cut(y, bins=bin_edges, labels=False, include_lowest=True, right=True)
    midpoint = bin_midpoints[np.asarray(window, dtype=int)]

    # Prepare label colors
    cols = ['name', 'midpoint', 'value']
    labels, colors = None, None
    if label is not None:
        adata.obs[label] = pd.Categorical(adata.obs[label])
        color_key = f'{label}_colors'
        if color_key not in adata.uns or adata.uns[color_key] is None:
            from matplotlib.colors import to_hex
            import matplotlib.pyplot as plt
            cmap = plt.get_cmap('tab10')
            # One color per category so that every category code is a valid index
            n_categories = len(adata.obs[label].cat.categories)
            adata.uns[color_key] = [to_hex(cmap(i)) for i in range(n_categories)]
        palette = adata.uns[color_key]
        labels = adata.obs[label].values
        colors = [palette[i] for i in adata.obs[label].cat.codes]
        cols += ['label', 'color']

    dfs = []
    for name in names:
        # Collect the values of this source together with the shared binning
        df = pd.DataFrame({
            'value': X[:, cnames == name].ravel(),
            'name': name,
            'order': y,
            'midpoint': midpoint,
        })
        if label is not None:
            df['label'] = labels
            df['color'] = colors
        dfs.append(df.sort_values('order'))
    df = pd.concat(dfs)[cols]

    # Report the bin midpoint as the order, rescaled so it spans exactly 0 to 1
    df = df.rename(columns={'midpoint': 'order'}).reset_index(drop=True)
    omin, omax = df['order'].min(), df['order'].max()
    df['order'] = (df['order'] - omin) / (omax - omin)
    return df

0 commit comments

Comments
 (0)