Skip to content

Commit cdfb772

Browse files
added MCC entropy score
1 parent 9942b71 commit cdfb772

1 file changed

Lines changed: 139 additions & 0 deletions

File tree

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
2+
3+
"""
4+
Calculates the entropy of MCC vectors (microenvironment cell type composition).
5+
Kozachenko-Leonenko estimate of entropy is calculated.
6+
"""
7+
import anndata
8+
import squidpy as sq
9+
import torch
10+
from torch_geometric.utils.convert import from_scipy_sparse_matrix
11+
import torch_geometric as pyg
12+
import pandas as pd
13+
import numpy as np
14+
import cupy as cp
15+
from cuml.neighbors import NearestNeighbors
16+
from scipy.special import psi, gamma
17+
from tqdm.autonotebook import tqdm
18+
19+
20+
from .. import modules
21+
22+
23+
# Grabbed/modified from google (AI mode)
24+
def _gpu_kl_entropy(np_data, k_calcentropy):
    """
    Computes the Kozachenko-Leonenko entropy estimate of a point cloud on GPU.

    The returned value is
    ``psi(N) - psi(k) + log(V_d) + (d / N) * sum_i log(eps_i)``,
    where ``eps_i`` is the distance from sample ``i`` to its k-th nearest
    neighbour and ``V_d`` is the volume of the d-dimensional unit ball.
    Natural logarithms are used throughout.

    :param np_data: Samples as a host array of shape (N, d). It is copied to the GPU internally.
    :type np_data: numpy.ndarray
    :param k_calcentropy: Which nearest neighbour to use (k). Must satisfy ``1 <= k < N``.
    :type k_calcentropy: int
    :return: The entropy estimate.
    :rtype: float
    :raises TypeError: If `np_data` is not a numpy array.
    :raises ValueError: If `np_data` is not 2-D, or `k_calcentropy` is outside ``[1, N - 1]``.
    """
    # Function-scope import: `math` is only needed for the log-volume constant below.
    import math

    # Validate on the host before anything is moved to the GPU, and raise
    # real exceptions (asserts disappear under `python -O`).
    if not isinstance(np_data, np.ndarray):
        raise TypeError(
            f"np_data must be a numpy.ndarray, got {type(np_data).__name__}."
        )
    if np_data.ndim != 2:
        raise ValueError(
            f"np_data must have shape (N, d), got an array with shape {np_data.shape}."
        )

    k = k_calcentropy
    N, d = np_data.shape

    if k < 1:
        raise ValueError(f"k_calcentropy must be >= 1, got {k}.")
    if N <= k:
        # k + 1 neighbours (the point itself plus k others) are requested below.
        raise ValueError(
            f"Need more than k_calcentropy={k} samples to estimate entropy, got N={N}."
        )

    data = cp.asarray(np_data)

    # 1. Use cuML to find the distance to the k-th neighbour.
    # n_neighbors=k+1 because the query points are the fitted points themselves,
    # so the first neighbour returned is the point itself (dist=0).
    nn = NearestNeighbors(n_neighbors=k + 1)
    nn.fit(data)
    distances, _ = nn.kneighbors(data)

    # The k-th neighbour distance is the last column. Promote to float64 so the
    # sum of logs below does not accumulate in reduced precision, and clamp
    # away exact zeros (duplicate points) to avoid log(0).
    eps = cp.asarray(distances)[:, -1].astype(cp.float64)
    eps = cp.maximum(eps, 1e-15)

    # 2. Constants, computed on the CPU.
    # log(V_d) = (d/2) * log(pi) - log(Gamma(1 + d/2)), evaluated in log-space:
    # forming V_d first overflows Gamma for large d and turns the log into -inf.
    log_v_d = (d / 2) * math.log(math.pi) - math.lgamma(1 + d / 2)
    term1 = float(psi(N) - psi(k))

    # 3. Final summation on the GPU; float() brings the scalar back to the host.
    sum_log_eps = float(cp.sum(cp.log(eps)))

    return float(term1 + log_v_d + (d / N) * sum_log_eps)
56+
57+
58+
def get_MCC_entropy(
    adata: anndata.AnnData,
    kwargs_neighbourhood_graph: dict,
    obskey_celltype: str,
    device,
    k_calcentropy: int,
    batch_size_computeMCC: int = 10,
):
    """
    Calculates the entropy estimate of MCC vectors per cell type.

    .. warning::
        This function resets `adata.uns` and `adata.obsp` to empty mappings
        before rebuilding the spatial neighbourhood graph, so any entries
        already stored there are discarded. Pass a copy if they are needed.

    :param adata: The input anndata object. It is modified in place (see the warning above).
    :type adata: anndata.AnnData
    :param kwargs_neighbourhood_graph: kwargs forwarded to `squidpy.gr.spatial_neighbors`. This function recreates the neighbourhood graph internally.
    :type kwargs_neighbourhood_graph: dict
    :param obskey_celltype: The column in `.obs` containing cell type annotations.
    :type obskey_celltype: str
    :param device: The device the MCC module is moved to, e.g., cpu or gpu (recommended).
    :param k_calcentropy: The number of nearest neighbours used by the Kozachenko-Leonenko estimator, e.g., 1, 3 or 5. There is no default; a value must be given.
    :type k_calcentropy: int
    :param batch_size_computeMCC: The batch size of the pyg neighbour loader used to calculate the MCC vectors.
    :type batch_size_computeMCC: int
    :return: A tuple ``(dict_ct_to_MCCentropy, df)``. The dict maps each cell type to its entropy estimate. `df` has columns ``cell_type`` and ``MCC_entropy``, sorted by decreasing entropy. Cell types with at most `k_calcentropy` cells get NaN (a warning is issued), because the estimator needs k neighbours besides the point itself.
    :rtype: tuple[dict, pandas.DataFrame]
    """
    # Function-scope import: only used to report cell types that are too small.
    import warnings

    # Compute the neighbourhood graph from scratch.
    # NOTE(review): the two resets below throw away everything the caller had in
    # `.uns`/`.obsp`. They are kept because clearing `.uns` may influence
    # squidpy defaults that inspect it — confirm before removing them.
    adata.uns = {}
    adata.obsp = {}
    sq.gr.spatial_neighbors(
        adata=adata,
        **kwargs_neighbourhood_graph
    )

    # Get `edge_index` ([2, num_edges]), symmetrised and without self loops.
    with torch.no_grad():
        edge_index, _ = from_scipy_sparse_matrix(adata.obsp['spatial_connectivities'])
        edge_index = pyg.utils.to_undirected(edge_index)
        # `remove_self_loops` returns (edge_index, edge_attr); only the former is needed.
        # The result is used directly: wrapping an integer index tensor in the
        # legacy `torch.Tensor(...)` constructor adds nothing.
        edge_index, _ = pyg.utils.remove_self_loops(edge_index)

    # One-hot encode the cell types as a float tensor (one column per cell type).
    df_CT = pd.get_dummies(adata.obs[obskey_celltype])
    ten_CT = torch.tensor(df_CT.to_numpy(dtype=np.float64), requires_grad=False)

    # Compute MCC.
    # NOTE(review): presumably this averages the one-hot vectors over each cell's
    # 1-hop neighbours excluding the cell itself — confirm against `modules.gnn`.
    module_compMCC = modules.gnn.KhopAvgPoolWithoutselfloop(
        num_hops=1,
        dim_input=None,
        dim_output=None
    )
    module_compMCC = module_compMCC.to(device)
    ten_MCC = module_compMCC.evaluate_layered(
        x=ten_CT,
        edge_index=edge_index,
        kwargs_dl={
            'batch_size': batch_size_computeMCC,
            'num_workers': 0,
            'num_neighbors': [-1]
        }
    )

    # Move the MCC matrix to the host once instead of once per cell type.
    np_MCC = ten_MCC.detach().cpu().numpy()

    # Distinct cell types in order of first appearance, so the key order of the
    # returned dict is reproducible across runs (a `set` is not).
    list_ct = list(dict.fromkeys(adata.obs[obskey_celltype].tolist()))

    # Compute the entropy values ct by ct.
    dict_ct_to_MCCentropy = {}
    num_hits_per_row = np.zeros(adata.n_obs, dtype=np.int64)
    for ct in tqdm(list_ct, desc="Computing MCC entropy for different cell types"):
        mask_ct = (adata.obs[obskey_celltype] == ct).to_numpy(dtype=bool)
        num_hits_per_row += mask_ct

        num_cells_ct = int(mask_ct.sum())
        if num_cells_ct <= k_calcentropy:
            # The estimator needs k neighbours besides the point itself; do not
            # let one rare cell type abort the whole computation.
            warnings.warn(
                f"Cell type {ct!r} has only {num_cells_ct} cell(s), which is not more "
                f"than k_calcentropy={k_calcentropy}; its MCC entropy is set to NaN.",
                RuntimeWarning,
                stacklevel=2,
            )
            dict_ct_to_MCCentropy[ct] = float('nan')
            continue

        dict_ct_to_MCCentropy[ct] = _gpu_kl_entropy(
            np_MCC[mask_ct, :],
            k_calcentropy=k_calcentropy
        )

    # Sanity check: every cell must have been assigned to exactly one cell type
    # (this fails, e.g., when the annotation column contains missing values).
    assert np.all(num_hits_per_row == 1), \
        "Each cell must belong to exactly one cell type."

    # Create the df to return, to be used for, e.g., visualisation.
    df = pd.DataFrame(
        {
            'cell_type': list(dict_ct_to_MCCentropy),
            'MCC_entropy': list(dict_ct_to_MCCentropy.values()),
        }
    )
    df = df.sort_values(by='MCC_entropy', ascending=False)

    return dict_ct_to_MCCentropy, df

0 commit comments

Comments
 (0)