1+
2+
3+ """
4+ Calculates the entropy of MCC vectors (microenvironment cell type composition).
5+ Kozachenko-Leonenko estimate of entropy is calculated.
6+ """
7+ import anndata
8+ import squidpy as sq
9+ import torch
10+ from torch_geometric .utils .convert import from_scipy_sparse_matrix
11+ import torch_geometric as pyg
12+ import pandas as pd
13+ import numpy as np
14+ import cupy as cp
15+ from cuml .neighbors import NearestNeighbors
16+ from scipy .special import psi , gamma
17+ from tqdm .autonotebook import tqdm
18+
19+
20+ from .. import modules
21+
22+
23+ # Grabbed/modified from google (AI mode)
24+ def _gpu_kl_entropy (np_data , k_calcentropy ):
25+ """
26+ Computes Shannon entropy for high-dim continuous variables on GPU.
27+ data: cupy.ndarray of shape (N, d)
28+ """
29+ assert isinstance (np_data , np .ndarray )
30+ data = cp .asarray (np_data )
31+ k = k_calcentropy
32+
33+ N , d = data .shape
34+
35+ # 1. Use cuML to find the distance to the k-th neighbor
36+ # We set n_neighbors=k+1 because the first neighbor is the point itself (dist=0)
37+ nn = NearestNeighbors (n_neighbors = k + 1 )
38+ nn .fit (data )
39+ distances , _ = nn .kneighbors (data )
40+
41+ # Extract the k-th neighbor distance (last column)
42+ # Ensure no zero distances to avoid log(0)
43+ eps = distances [:, - 1 ]
44+ eps = cp .maximum (eps , 1e-15 )
45+
46+ # 2. Constants (calculated on CPU or GPU)
47+ v_d = (cp .pi ** (d / 2 )) / gamma (1 + d / 2 )
48+ term1 = psi (N ) - psi (k )
49+ term2 = cp .log (v_d )
50+
51+ # 3. Final summation using CuPy
52+ sum_log_eps = cp .sum (cp .log (eps ))
53+ entropy = term1 + term2 + (d / N ) * sum_log_eps
54+
55+ return float (entropy )
56+
57+
58+ def get_MCC_entropy (
59+ adata :anndata .AnnData ,
60+ kwargs_neighbourhood_graph :dict ,
61+ obskey_celltype :str ,
62+ device ,
63+ k_calcentropy :int ,
64+ batch_size_computeMCC :int = 10 ,
65+ ):
66+ """
67+ Calculates the entropy-esimate of MCC vectors per cell type.
68+
69+ :param adata: The input anndata object.
70+ :type adata: anndata.AnnData
71+ :param kwargs_neighbourhood_graph: kwargs to create the neighbourhood graph. This function recreates the neighbourhood graph internally.
72+ :type kwargs_neighbourhood_graph: dict
73+ :param obskey_celltype: The column in `.obs` containig cell type annotations.
74+ :type obskey_celltype: str
75+ :param device: device, e.g., cpu or gpu (recommeneded)
76+ :param k_calcentropy: The number of nearest neighbours used by the Kozachenko-Leonenko estimator. Default is 1, while one can use, e.g., 3 or 5.
77+ :type k_calcentropy: int
78+ :param batch_size_computeMCC: The batch size of pyg neighbourloader to calculate the MCC vectors.
79+ :type batch_size_computeMCC: int
80+ """
81+
82+ # compute the neighrborhood graph
83+ adata .uns = {}
84+ adata .obsp = {}
85+ sq .gr .spatial_neighbors (
86+ adata = adata ,
87+ ** kwargs_neighbourhood_graph
88+ )
89+
90+ # get `edge_index`
91+ with torch .no_grad ():
92+ edge_index , _ = from_scipy_sparse_matrix (adata .obsp ['spatial_connectivities' ]) # [2, num_edges]
93+ edge_index = torch .Tensor (pyg .utils .remove_self_loops (pyg .utils .to_undirected (edge_index ))[0 ])
94+
95+ df_CT = pd .get_dummies (adata .obs [obskey_celltype ])
96+ ten_CT = torch .tensor (np .array (df_CT ) + 0.0 , requires_grad = False )
97+
98+ # compute MCC
99+ module_compMCC = modules .gnn .KhopAvgPoolWithoutselfloop (
100+ num_hops = 1 ,
101+ dim_input = None ,
102+ dim_output = None
103+ )
104+ module_compMCC = module_compMCC .to (device )
105+ ten_MCC = module_compMCC .evaluate_layered (
106+ x = ten_CT ,
107+ edge_index = edge_index ,
108+ kwargs_dl = {
109+ 'batch_size' :batch_size_computeMCC ,
110+ 'num_workers' :0 ,
111+ 'num_neighbors' :[- 1 ]
112+ }
113+ )
114+
115+ # compute the entropy values ct by ct
116+ dict_ct_to_MCCentropy = {}
117+ tmp_assert_rowsel = 0.0
118+ for ct in tqdm (set (adata .obs [obskey_celltype ].tolist ()), desc = "Computing MCC entropy for different cell types" ):
119+ list_rowsel = (adata .obs [obskey_celltype ] == ct ).tolist ()
120+
121+ tmp_assert_rowsel = tmp_assert_rowsel + np .array (list_rowsel ) + 0.0
122+
123+ dict_ct_to_MCCentropy [ct ] = _gpu_kl_entropy (
124+ ten_MCC [list_rowsel , :].detach ().cpu ().numpy (),
125+ k_calcentropy = k_calcentropy
126+ )
127+
128+ assert np .allclose (
129+ tmp_assert_rowsel ,
130+ np .ones_like (tmp_assert_rowsel )
131+ )
132+
133+ # create the df toret, to be used for, e.g., visualisation
134+ df = pd .DataFrame (
135+ {'cell_type' :[k for k in dict_ct_to_MCCentropy .keys ()], 'MCC_entropy' :[v for _ , v in dict_ct_to_MCCentropy .items ()]}
136+ )
137+ df = df .sort_values (by = 'MCC_entropy' , ascending = False )
138+
139+ return dict_ct_to_MCCentropy , df
0 commit comments