diff --git a/CHANGELOG.md b/CHANGELOG.md index 79fc9ec3af..2a27066455 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo #### Added +- Add {mod}`scvi.external.harreman` for inference of metabolic exchanges in tissues using spatial transcriptomics {pr}`XXXX`. - Add support for Python 3.14, {pr}`3563`. - Add support for Pandas3, {pr}`3638`. diff --git a/docs/tutorials/index_spatial.md b/docs/tutorials/index_spatial.md index c67af7294f..b8fd084d05 100644 --- a/docs/tutorials/index_spatial.md +++ b/docs/tutorials/index_spatial.md @@ -3,6 +3,7 @@ ```{toctree} :maxdepth: 1 +notebooks/spatial/harreman_tutorial notebooks/spatial/resolVI_tutorial notebooks/spatial/scVIVA_tutorial notebooks/spatial/DestVI_tutorial @@ -12,6 +13,13 @@ notebooks/spatial/stereoscope_heart_LV_tutorial notebooks/spatial/cell2location_lymph_node_spatial_tutorial ``` +```{customcard} +:path: notebooks/spatial/harreman_tutorial +:tags: Analysis, Spatial-statistics, Metabolic-inference + +Infer metabolic exchanges in tissues using spatial transcriptomics with Harreman +``` + ```{customcard} :path: notebooks/spatial/resolVI_tutorial :tags: Analysis, Integration, Transfer-learning, Dimensionality-reduction diff --git a/docs/user_guide/models/harreman.md b/docs/user_guide/models/harreman.md new file mode 100644 index 0000000000..fd5caede4c --- /dev/null +++ b/docs/user_guide/models/harreman.md @@ -0,0 +1,125 @@ +# Harreman + +**Harreman** (`scvi.external.harreman`) is a toolkit for inferring metabolic exchanges and cell-cell communication in tissues using spatial transcriptomics data. + +The advantages of Harreman are: + +- Infers spatially-resolved metabolic gene programs using local autocorrelation +- Identifies cell-cell metabolic communication and ligand-receptor interactions using spatial proximity graphs +- Supports multiple spatial technologies (Visium, Slide-seq, and others) +- Scalable to large spatial datasets +- Supports both parametric and non-parametric significance testing + +The limitations of Harreman include: + +- Requires spatial coordinates to be available in `adata.obsm` +- Cell communication inference requires a ligand-receptor or metabolite transporter database + +```{topic} Tutorials: + +- {doc}`/tutorials/notebooks/spatial/harreman_tutorial` +``` + +```{topic} External links: + +- [Harreman documentation](https://harreman.readthedocs.io) +- [Harreman GitHub](https://github.com/YosefLab/Harreman) +``` + +## Overview + +Harreman operates in three main steps: + +1. **Spatial graph construction** ({func}`~scvi.external.harreman.tl.compute_knn_graph`): builds a spatial proximity graph from cell coordinates, supporting both k-nearest neighbors and radius-based neighborhoods, with optional Gaussian kernel weighting. + +2. **Local autocorrelation** ({func}`~scvi.external.harreman.hs.compute_local_autocorrelation`): identifies spatially variable genes using the local autocorrelation statistic from the Hotspot algorithm (DeTomaso and Yosef, *Cell systems*, 2021), supporting DANB, Bernoulli, and normal count models. + +3. **Cell communication** ({func}`~scvi.external.harreman.tl.compute_cell_communication`): infers spatially-resolved metabolic exchanges and ligand-receptor interactions between neighboring cells using HarremanDB and CellChatDB. + +## Generative process + +At the coarsest level, Harreman partitions the tissue into modules of different metabolic functions based on enzyme co-expression. At the following stage, Harreman formulates hypotheses about which metabolites are exchanged across the tissue or within each spatial zone. Moving to a finer resolution, Harreman can also infer which specific cell subsets participate in the exchange of distinct metabolic activities inside each zone. + +For proteins composed of multiple subunits, Harreman computes either an algebraic or geometric mean of the expression values of the corresponding genes: + +```{math} +:nowrap: true + +\begin{align} + X_{ai} &= \frac{\sum_{l \in S_l} X_{a_li}}{|S_l|}; \quad X_{bj} = \frac{\sum_{r \in S_r} X_{b_rj}}{|S_r|} +\end{align} +``` + +### Test statistic 1: Spatial autocorrelation + +Spatially variable genes are identified using the following autocorrelation statistic: + +```{math} +:nowrap: true + +\begin{align} + H_{a} &= \sum_{i}\sum_{j} w_{ij}X_{ai}X_{aj} +\end{align} +``` + +where $w_{ij}$ represents the communication strength between neighboring cells, computed using a Gaussian kernel: + +```{math} +:nowrap: true + +\begin{align} + \hat{w}_{ij} &= e^{-d_{ij}^2/\sigma_{i}^2} +\end{align} +``` + +Significance is assessed by converting $H_a$ to a Z-score and adjusting p-values using the Benjamini-Hochberg procedure. + +### Test statistic 2: Spatial co-localization + +Pairwise spatial correlation between genes is computed as: + +```{math} +:nowrap: true + +\begin{align} + H_{ab} &= \sum_{i}\sum_{j} w_{ij} \left(X_{ai}X_{bj} + X_{bi}X_{aj}\right) +\end{align} +``` + +This statistic is used to group genes into spatial modules and to identify cell-type-agnostic metabolic exchange events. + +### Test statistic 3: Metabolite autocorrelation + +Gene-pair results are integrated at the metabolite level: + +```{math} +:nowrap: true + +\begin{align} + H_{m} &= \sum_{a,b \in m} H_{ab} +\end{align} +``` + +where $m$ is a metabolite exchanged by genes $a$ and $b$. + +## Usage + +```python +import scvi.external.harreman as harreman + +# 1. Build spatial KNN graph +harreman.tl.compute_knn_graph(adata, compute_neighbors_on_key="spatial", n_neighbors=10) + +# 2. Identify spatially variable genes +harreman.hs.compute_local_autocorrelation(adata, model="danb") + +# 3. Compute pairwise local correlation +harreman.hs.compute_local_correlation(adata) + +# 4. Infer cell-cell communication +harreman.tl.compute_cell_communication(adata) +``` + +## API + +Please see {mod}`scvi.external.harreman` for the full API reference. diff --git a/docs/user_guide/models/index.md b/docs/user_guide/models/index.md index 3a75d45ca2..e4e275d73b 100644 --- a/docs/user_guide/models/index.md +++ b/docs/user_guide/models/index.md @@ -11,6 +11,7 @@ cytovi decipher destvi gimvi +harreman linearscvi methylanvi methylvi diff --git a/pyproject.toml b/pyproject.toml index 423018afba..eefd5d872d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ ] [project.optional-dependencies] +harreman = ["pooch"] tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"] editing = ["jupyter", "pre-commit"] dev = ["scvi-tools[editing,tests]"] diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 201f7a5ae6..2de26112c0 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -3,6 +3,7 @@ from scvi import settings from scvi.utils import error_on_missing_dependencies +from . import harreman from .cellassign import CellAssign from .contrastivevi import ContrastiveVI from .cytovi import CYTOVI @@ -43,6 +44,7 @@ "RESOLVI", "SCVIVA", "CYTOVI", + "harreman", ] diff --git a/src/scvi/external/harreman/__init__.py b/src/scvi/external/harreman/__init__.py new file mode 100644 index 0000000000..36b37b82c4 --- /dev/null +++ b/src/scvi/external/harreman/__init__.py @@ -0,0 +1,6 @@ +from . import datasets as ds +from . import hotspot as hs +from . import preprocessing as pp +from . import tools as tl + +__all__ = ["ds", "hs", "pp", "tl"] diff --git a/src/scvi/external/harreman/datasets/__init__.py b/src/scvi/external/harreman/datasets/__init__.py new file mode 100644 index 0000000000..ae68140a6c --- /dev/null +++ b/src/scvi/external/harreman/datasets/__init__.py @@ -0,0 +1 @@ +from .datasets import load_slide_seq_human_lung_dataset, load_visium_mouse_colon_dataset diff --git a/src/scvi/external/harreman/datasets/datasets.py b/src/scvi/external/harreman/datasets/datasets.py new file mode 100644 index 0000000000..f890114c62 --- /dev/null +++ b/src/scvi/external/harreman/datasets/datasets.py @@ -0,0 +1,77 @@ +import os +import tempfile + +import scanpy as sc + +temp_dir_obj = tempfile.TemporaryDirectory() + + +def load_visium_mouse_colon_dataset( + sample: str | None = None, +) -> "sc.AnnData": + """ + Load the mouse colon 10x Visium dataset. + + Returns + ------- + adata : AnnData + The loaded 10x Visium dataset. + """ + dataset_prefix = "Parigi_et_al_mouse_colon" + + samples_path_dict = { + "d0": "https://figshare.com/ndownloader/files/59325113", + "d14": "https://figshare.com/ndownloader/files/59325116", + } + + if sample: + if sample not in samples_path_dict.keys(): + raise ValueError(f'"sample" needs to be one of: {list(samples_path_dict.keys())}') + else: + adata_path = os.path.join(temp_dir_obj.name, f"{dataset_prefix}_{sample}.h5ad") + backup_url = samples_path_dict[sample] + else: + adata_path = os.path.join(temp_dir_obj.name, f"{dataset_prefix}_unrolled.h5ad") + backup_url = "https://figshare.com/ndownloader/files/59325119" + + adata = sc.read(adata_path, backup_url=backup_url) + + return adata + + +def load_slide_seq_human_lung_dataset( + sample: str | None = None, +) -> "sc.AnnData": + """ + Load the human lung Slide-seq dataset. + + Returns + ------- + adata : AnnData + The loaded Slide-seq dataset. + """ + dataset_prefix = "Liu_et_al_human_lung" + + samples_path_dict = { + "Puck_200727_08": "https://figshare.com/ndownloader/files/59325098", + "Puck_200727_09": "https://figshare.com/ndownloader/files/59325092", + "Puck_200727_10": "https://figshare.com/ndownloader/files/59325095", + "Puck_220408_13": "https://figshare.com/ndownloader/files/59325101", + "Puck_220408_14": "https://figshare.com/ndownloader/files/59325104", + "Puck_220408_15": "https://figshare.com/ndownloader/files/59325107", + "Puck_220408_20": "https://figshare.com/ndownloader/files/59325110", + } + + if sample: + if sample not in samples_path_dict.keys(): + raise ValueError(f'"sample" needs to be one of: {list(samples_path_dict.keys())}') + else: + adata_path = os.path.join(temp_dir_obj.name, f"{dataset_prefix}_{sample}.h5ad") + backup_url = samples_path_dict[sample] + else: + adata_path = os.path.join(temp_dir_obj.name, f"{dataset_prefix}.h5ad") + backup_url = "https://figshare.com/ndownloader/files/59325125" + + adata = sc.read(adata_path, backup_url=backup_url) + + return adata diff --git a/src/scvi/external/harreman/hotspot/__init__.py b/src/scvi/external/harreman/hotspot/__init__.py new file mode 100755 index 0000000000..a6cb3c1f8c --- /dev/null +++ b/src/scvi/external/harreman/hotspot/__init__.py @@ -0,0 +1,9 @@ +from .local_autocorrelation import compute_local_autocorrelation, load_metabolic_genes +from .local_correlation import compute_local_correlation +from .modules import ( + calculate_module_scores, + calculate_super_module_scores, + compute_top_scoring_modules, + create_modules, + integrate_vision_hotspot_results, +) diff --git a/src/scvi/external/harreman/hotspot/local_autocorrelation.py b/src/scvi/external/harreman/hotspot/local_autocorrelation.py new file mode 100755 index 0000000000..030b5f772a --- /dev/null +++ b/src/scvi/external/harreman/hotspot/local_autocorrelation.py @@ -0,0 +1,522 @@ +import time +import os +import pooch +from typing import Literal + +import numpy as np +import pandas as pd +import torch +from anndata import AnnData +from numba import jit, njit +from scipy.stats import norm +from statsmodels.stats.multitest import multipletests +from tqdm import tqdm + +from ..preprocessing.anndata import counts_from_anndata +from ..tools.knn import make_weights_non_redundant +from . import models + + +def load_metabolic_genes( + species: Literal["mouse"] | Literal["human"] | None = None, +): + """ + Load the list of metabolic genes for a given species. + + Parameters + ---------- + species : {"mouse", "human"}, optional (default: "mouse") + Species used to select the correct metabolic gene list . + + Returns + ------- + List of metabolic genes. + """ + metabolic_genes_path = pooch.retrieve( + url=f"https://scverse-public-data.s3.eu-central-1.amazonaws.com/scvi-tools/harreman/metabolic_genes/{species}_metabolic_genes.csv", + known_hash=None, + fname=f"{species}_metabolic_genes.csv", + path=pooch.os_cache("scvi_harreman"), + progressbar=False, + ) + + metabolic_genes = pd.read_csv(metabolic_genes_path, index_col=0)["0"].tolist() + + return metabolic_genes + + +def compute_local_autocorrelation( + adata: AnnData, + layer_key: Literal["use_raw"] | str | None = None, + database_varm_key: str | None = None, + model: str | None = None, + genes: list | None = None, + use_metabolic_genes: bool | None = False, + species: Literal["mouse"] | Literal["human"] | None = "mouse", + umi_counts_obs_key: str | None = None, + permutation_test: bool | None = False, + M: int | None = 1000, + seed: int | None = 42, + check_analytic_null: bool | None = False, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Computes gene-level spatial autocorrelation statistics using spatial weights and centered gene expression values. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). Requires `obsp["weights"]` for the spatial graph. + layer_key : str or "use_raw", optional + Key in `adata.layers` to use for expression data. Use "use_raw" to access `adata.raw`. + database_varm_key : str, optional + Key in `adata.varm` used for filtering genes that are part of the transporter or ligand-receptor database. + model : str, optional + Normalization model to use for centering gene expression. Options include "none", "normal", "bernoulli", or "danb". + genes : list, optional + List of gene names to include in the analysis. If `None`, all genes are used or selected via metabolic/pathway filters. + use_metabolic_genes : bool, optional (default: False) + If `True`, restricts analysis to metabolic genes as defined for the selected species. + species : {"mouse", "human"}, optional (default: "mouse") + Species used to select the correct metabolic gene list if `use_metabolic_genes=True`. + umi_counts_obs_key : str, optional + Key in `adata.obs` with total UMI counts per cell. If `None`, inferred from the expression matrix. + permutation_test : bool, optional (default: False) + Whether to compute an empirical p-value and null distribution by permuting the data. + M : int, optional (default: 1000) + Number of permutations to perform if `permutation_test` is True. + seed : int, optional (default: 42) + Random seed for permutation reproducibility. + check_analytic_null : bool, optional (default: False) + Whether to evaluate Z-scores under an analytic null distribution using permutation Z-scores. + device : torch.device, optional + PyTorch device to run computations on. Defaults to CUDA if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + The results are stored in `adata.uns["gene_autocorrelation_results"]` as a DataFrame. + """ + start = time.time() + if verbose: + print("Computing local autocorrelation...") + + adata.uns["layer_key"] = layer_key + adata.uns["model"] = model + adata.uns["species"] = species + + sample_specific = "sample_key" in adata.uns + + # Gene selection + if use_metabolic_genes and genes is None: + genes = pd.Index(load_metabolic_genes(species)).intersection(adata.var_names) + elif database_varm_key is not None and genes is None: + source = adata.raw if (layer_key == "use_raw") else adata + metab_matrix = source.varm[database_varm_key] + genes = metab_matrix.loc[(metab_matrix != 0).any(axis=1)].index + elif genes is None: + genes = adata.raw.var.index if layer_key == "use_raw" else adata.var_names + else: + genes = pd.Index(genes) + + # Load counts + counts = counts_from_anndata(adata[:, genes], layer_key, dense=True) + + # Gene filtering + if sample_specific: + sample_key = adata.uns["sample_key"] + sample_arr = adata.obs[sample_key].to_numpy() + mask = np.zeros(counts.shape[0], dtype=bool) + for sample in np.unique(sample_arr): + sample_idx = np.where(sample_arr == sample)[0] + mask |= np.all(counts[:, sample_idx] == 0, axis=1) + else: + mask = np.all(counts == 0, axis=1) + + counts = counts[~mask] + genes = genes[~mask] + + # UMI counts + num_umi = ( + counts.sum(axis=0) + if umi_counts_obs_key is None + else np.asarray(adata.obs[umi_counts_obs_key]) + ) + adata.uns["umi_counts"] = num_umi + + # Convert to tensors + num_umi = torch.tensor(adata.uns["umi_counts"], dtype=torch.float64, device=device) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Center values + counts = standardize_counts(adata, counts, model, num_umi, sample_specific) + + adata.var["local_autocorrelation"] = False + adata.var.loc[genes, "local_autocorrelation"] = True + + # Compute weights + weights = make_weights_non_redundant(adata.obsp["weights"]).tocoo() + Wtot2 = torch.tensor((weights.data**2).sum(), device=device) + weights = torch.sparse_coo_tensor( + torch.tensor(np.vstack((weights.row, weights.col)), dtype=torch.long, device=device), + torch.tensor(weights.data, dtype=torch.float64, device=device), + torch.Size(weights.shape), + device=device, + ) + + # Compute node degree + row_degrees = torch.sparse.sum(weights, dim=1).to_dense() + col_degrees = torch.sparse.sum(weights, dim=0).to_dense() + D = row_degrees + col_degrees + + # Autocorrelation + WXt = torch.sparse.mm(weights, counts.T) + G = (counts.T * WXt).sum(dim=0) + G_max = 0.5 * torch.sum((counts**2) * D[None, :], dim=1) + + # Results + results = compute_gene_autocorrelation_results( + counts=counts, + weights=weights, + G=G, + G_max=G_max, + Wtot2=Wtot2, + genes=genes, + D=D, + M=M, + permutation_test=permutation_test, + seed=seed, + check_analytic_null=check_analytic_null, + device=device, + ) + + # Save results + if isinstance(results, tuple): + results_df, zs_perm, pvals_perm = results + adata.uns["analytic_null_ac_zs_perm"] = zs_perm + adata.uns["analytic_null_ac_pvals_perm"] = pvals_perm + if verbose: + print( + "Analytic null results are stored in adata.uns with the following keys: ['analytic_null_ac_zs_perm', 'analytic_null_ac_pvals_perm']" + ) + else: + results_df = results + + results_df = results_df.sort_values("Z", ascending=False) + results_df.index.name = "Gene" + cols = ["C", "Z", "Z_Pval", "Z_FDR"] + if "Perm_Pval" in results_df.columns: + cols += ["Perm_Pval", "Perm_FDR"] + adata.uns["gene_autocorrelation_results"] = results_df[cols] + if verbose: + print( + "Local autocorrelation results are stored in adata.uns['gene_autocorrelation_results']" + ) + + print("Finished computing local autocorrelation in %.3f seconds" % (time.time() - start)) + + return + + +def compute_gene_autocorrelation_results( + counts, + weights, + G, + G_max, + Wtot2, + genes, + D, + M, + permutation_test, + seed, + check_analytic_null, + device, +): + # Compute core stats + stats = compute_autocor_Z_scores_torch(G, G_max, Wtot2) + + # Build DataFrame + results = pd.DataFrame({k: v.cpu().numpy() for k, v in stats.items()}, index=genes) + + # Z P-values and FDR + results["Z_Pval"] = norm.sf(results["Z"]) + results["Z_FDR"] = multipletests(results["Z_Pval"], method="fdr_bh")[1] + + # Permutation test + if permutation_test: + n_genes, n_cells = counts.shape + perm_array = torch.zeros((n_genes, M), dtype=torch.float16, device=device) + + if check_analytic_null: + ac_zs_perm_array = torch.zeros((n_genes, M), dtype=torch.float16, device=device) + ac_pvals_perm_array = torch.zeros((n_genes, M), dtype=torch.float16, device=device) + + torch.manual_seed(seed) + for i in tqdm(range(M), desc="Permutation test"): + idx = torch.randperm(n_cells, device=device) + X_perm = counts[:, idx] # (genes x cells) + WXt_perm = torch.sparse.mm(weights, X_perm.T) # (cells x genes) + G_perm = (X_perm.T * WXt_perm).sum(dim=0) # (genes,) + perm_array[:, i] = G_perm.half() + + if check_analytic_null: + # Compute G_max for permuted data + G_perm_max = 0.5 * torch.sum(X_perm**2 * D.unsqueeze(0), dim=1) + stats_perm = compute_autocor_Z_scores_torch(G_perm, G_perm_max, Wtot2) + ac_zs_perm_array[:, i] = stats_perm["Z"].half() + ac_pvals_perm_array[:, i] = torch.tensor( + norm.sf(stats_perm["Z"].cpu().numpy()), device=device + ).half() + + # Compute empirical permutation p-values + G_expanded = G.unsqueeze(1) # (genes, 1) + x = torch.sum(perm_array > G_expanded, dim=1) + pvals = ((x + 1) / (M + 1)).cpu().numpy() + results["Perm_Pval"] = pvals + results["Perm_FDR"] = multipletests(pvals, method="fdr_bh")[1] + + # Save optional nulls + if check_analytic_null: + return results, ac_zs_perm_array.cpu().numpy(), ac_pvals_perm_array.cpu().numpy() + + return results + + +@jit(nopython=True) +def local_cov_weights(x, weights_data, weights_coords): + out = 0 + + for i in range(len(x)): + mask_i = weights_coords[0] == i + indices_i = weights_coords[:, mask_i][1] + values_i = weights_data[mask_i] + for k in range(len(indices_i)): + j = indices_i[k] + j = int(j) + + w_ij = values_i[k] + + xi = x[i] + xj = x[j] + if xi == 0 or xj == 0 or w_ij == 0: + out += 0 + else: + out += xi * xj * w_ij + + return out + + +@jit(nopython=True) +def compute_local_cov_max(vals, node_degrees): + tot = 0.0 + + for i in range(node_degrees.size): + tot += node_degrees[i] * (vals[i] ** 2) + + return tot / 2 + + +def _compute_hs_inner(vals, weights_data, weights_coords, num_umi, model, Wtot2, D): + """Note, since this is an inner function, for parallelization to work well + none of the contents of the function can use MKL or OPENBLAS threads. + Or else we open too many. Because of this, some simple numpy operations + are re-implemented using numba instead as it's difficult to control + the number of threads in numpy after it's imported. + """ + if model == "bernoulli": + vals = (vals > 0).astype("double") + mu, var, x2 = models.bernoulli_model(vals, num_umi) + elif model == "danb": + mu, var, x2 = models.danb_model(vals, num_umi) + elif model == "normal": + mu, var, x2 = models.normal_model(vals, num_umi) + elif model == "none": + mu, var, x2 = models.none_model(vals, num_umi) + else: + raise Exception(f"Invalid Model: {model}") + + vals = center_values(vals, mu, var) + + G = local_cov_weights(vals, weights_data, weights_coords) + + EG, EG2 = 0, Wtot2 + + stdG = (EG2 - EG * EG) ** 0.5 + + Z = (G - EG) / stdG + + G_max = compute_local_cov_max(vals, D) + C = (G - EG) / G_max + + return [G, EG, stdG, Z, C] + + +@njit +def center_values(vals, mu, var): + out = np.zeros_like(vals) + + for i in range(len(vals)): + std = var[i] ** 0.5 + if std == 0: + out[i] = 0 + else: + out[i] = (vals[i] - mu[i]) / std + + return out + + +def center_values_total(vals, num_umi, model): + """ + Note, since this is an inner function, for parallelization to work well + none of the contents of the function can use MKL or OPENBLAS threads. + Or else we open too many. Because of this, some simple numpy operations + are re-implemented using numba instead as it's difficult to control + the number of threads in numpy after it's imported + """ + if model == "bernoulli": + vals = (vals > 0).astype("double") + mu, var, x2 = models.bernoulli_model(vals, num_umi) + elif model == "danb": + mu, var, x2 = models.danb_model(vals, num_umi) + elif model == "normal": + mu, var, x2 = models.normal_model(vals, num_umi) + elif model == "none": + mu, var, x2 = models.none_model(vals, num_umi) + else: + raise Exception(f"Invalid Model: {model}") + + centered_vals = center_values(vals, mu, var) + + return centered_vals + + +def center_counts_torch(counts, num_umi, model): + """ + counts: Tensor [genes, cells] + num_umi: Tensor [cells] + model: 'bernoulli', 'danb', 'normal', or 'none' + + Returns + ------- + Centered counts: Tensor [genes, cells] + """ + # Binarize if using Bernoulli + if model == "bernoulli": + counts = (counts > 0).double() + mu, var, _ = models.bernoulli_model_torch(counts, num_umi) + elif model == "danb": + mu, var, _ = models.danb_model_torch(counts, num_umi) + elif model == "normal": + mu, var, _ = models.normal_model_torch(counts, num_umi) + elif model == "none": + mu, var, _ = models.none_model_torch(counts, num_umi) + else: + raise ValueError(f"Unsupported model type: {model}") + + # Avoid division by zero + std = torch.sqrt(var) + std[std == 0] = 1.0 + + centered = (counts - mu) / std + centered[centered == 0] = 0 # Optional: to match old behavior + + return centered + + +def compute_autocor_Z_scores(G, G_max, Wtot2): + + EG, EG2 = 0, Wtot2 + + stdG = (EG2 - EG * EG) ** 0.5 + + Z = [(G[i] - EG) / stdG for i in range(len(G))] + + C = (G - EG) / G_max + + EG = [EG for i in range(len(G))] + stdG = [stdG for i in range(len(G))] + + return [G, G_max, EG, stdG, Z, C] + + +def compute_autocor_Z_scores_torch(G, G_max, Wtot2): + """ + G, G_max: torch tensors of shape (genes,) + Wtot2: float scalar (already computed) + Returns a dict with tensors: G, G_max, EG, stdG, Z, C + """ + EG = 0.0 + stdG = (Wtot2 - EG**2) ** 0.5 + + Z = (G - EG) / stdG # (genes,) + C = (G - EG) / G_max # (genes,) + + EG_tensor = torch.full_like(G, EG) + stdG_tensor = torch.full_like(G, stdG) + + return { + "G": G, + "G_max": G_max, + "EG": EG_tensor, + "stdG": stdG_tensor, + "Z": Z, + "C": C, + } + + +def standardize_counts(adata, counts, model, num_umi, sample_specific): + + if sample_specific: + sample_key = adata.uns["sample_key"] + for sample in adata.obs[sample_key].unique(): + subset = np.where(adata.obs[sample_key] == sample)[0] + counts[:, subset] = center_counts_torch(counts[:, subset], num_umi[subset], model) + else: + counts = center_counts_torch(counts, num_umi, model) + + return counts + + +def compute_communication_autocorrelation(adata, spatial_coords_obsm_key): + """Computes Geary's C for numerical data.""" + metab_scores_df = adata.obsm["metabolite_scores"] + gene_pair_scores_df = adata.obsm["gene_pair_scores"] + + # Compute autocorrelation on the metabolite scores + + metab_adata = AnnData(metab_scores_df) + metab_adata.obsm[spatial_coords_obsm_key] = adata.obsm[spatial_coords_obsm_key] + + metab_adata.obsm["neighbors_sort"] = adata.obsm["neighbors_sort"] + metab_adata.obsp["weights"] = adata.obsp["weights"] + + compute_local_autocorrelation( + metab_adata, + model="none", + jobs=1, + ) + + adata.uns["metabolite_autocorrelation_results"] = metab_adata.uns[ + "gene_autocorrelation_results" + ] + + # Compute autocorrelation on the gene pair scores + + gene_pair_adata = AnnData(gene_pair_scores_df) + gene_pair_adata.obsm[spatial_coords_obsm_key] = adata.obsm[spatial_coords_obsm_key] + + gene_pair_adata.obsm["neighbors_sort"] = adata.obsm["neighbors_sort"] + gene_pair_adata.obsp["weights"] = adata.obsp["weights"] + + compute_local_autocorrelation( + gene_pair_adata, + model="none", + jobs=1, + ) + + adata.uns["gene_pair_autocorrelation_results"] = gene_pair_adata.uns[ + "gene_autocorrelation_results" + ] diff --git a/src/scvi/external/harreman/hotspot/local_correlation.py b/src/scvi/external/harreman/hotspot/local_correlation.py new file mode 100755 index 0000000000..2e24929c26 --- /dev/null +++ b/src/scvi/external/harreman/hotspot/local_correlation.py @@ -0,0 +1,471 @@ +import time + +import numpy as np +import pandas as pd +import sparse +import torch +from anndata import AnnData +from numba import jit, njit +from scipy.stats import norm +from statsmodels.stats.multitest import multipletests +from tqdm import tqdm + +from ..preprocessing.anndata import counts_from_anndata +from ..tools.knn import make_weights_non_redundant +from . import models +from .local_autocorrelation import compute_local_cov_max, standardize_counts + + +def compute_local_correlation( + adata: AnnData, + genes: list | None = None, + permutation_test: bool | None = False, + M: int | None = 1000, + seed: int | None = 42, + check_analytic_null: bool | None = False, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Computes pairwise local correlation between selected genes using a spatial weight matrix. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). Required keys in `adata.uns`: + - 'gene_autocorrelation_results' (if `genes` is None) + - 'layer_key': layer in `adata` to extract counts from + - 'model': statistical model to use for centering (e.g., 'DANB', 'normal') + - 'umi_counts': per-cell UMI counts + - optionally 'sample_key': key in `adata.obs` to use for per-sample normalization + genes : list, optional + List of genes to include in the correlation analysis. If None, selects genes with FDR < 0.05 + from `adata.uns['gene_autocorrelation_results']`, ordered by Z-score. + permutation_test : bool, optional (default: False) + Whether to compute an empirical p-value and null distribution by permuting the data. + M : int, optional (default: 1000) + Number of permutations to perform if `permutation_test` is True. + seed : int, optional (default: 42) + Random seed for permutation reproducibility. + check_analytic_null : bool, optional (default: False) + Whether to compute an analytic null distribution for the local correlation scores. + device : torch.device, optional + PyTorch device to run computations on. Defaults to CUDA if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + Results are stored in the following keys in `adata.uns`: `lcs`, `lc_zs`, `lc_z_pvals`, and `lc_z_FDR`. + """ + start = time.time() + + if genes is None: + gene_autocorrelation_results = adata.uns["gene_autocorrelation_results"] + genes = ( + gene_autocorrelation_results.loc[gene_autocorrelation_results.Z_FDR < 0.05] + .sort_values("Z", ascending=False) + .index + ) + + if verbose: + print(f"Computing pair-wise local correlation on {len(genes)} features...") + + layer_key = adata.uns["layer_key"] + model = adata.uns["model"] + sample_specific = "sample_key" in adata.uns.keys() + + # Load counts + counts = counts_from_anndata(adata[:, genes], layer_key, dense=True) + + # UMI counts + num_umi = adata.uns["umi_counts"] + + # Convert to tensors + num_umi = torch.tensor(adata.uns["umi_counts"], dtype=torch.float64, device=device) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Center values + counts = standardize_counts(adata, counts, model, num_umi, sample_specific) + + # Compute weights + weights = make_weights_non_redundant(adata.obsp["weights"]).tocoo() + weights = torch.sparse_coo_tensor( + torch.tensor(np.vstack((weights.row, weights.col)), dtype=torch.long, device=device), + torch.tensor(weights.data, dtype=torch.float64, device=device), + torch.Size(weights.shape), + device=device, + ) + + # Compute node degree + row_degrees = torch.sparse.sum(weights, dim=1).to_dense() + col_degrees = torch.sparse.sum(weights, dim=0).to_dense() + D = row_degrees + col_degrees + + # Pairwise correlation + WXt = torch.sparse.mm(weights, counts.T) # (cells x genes) + WtXt = torch.sparse.mm(weights.transpose(0, 1), counts.T) # (cells x genes) + lcs = torch.matmul(counts, WXt) + torch.matmul(counts, WtXt) + + # Compute second moments of H + eg2s = (WXt + WtXt).pow(2).sum(dim=0) + + # Results + results = compute_pairwise_correlation_results( + counts=counts, + weights=weights, + lcs=lcs, + eg2s=eg2s, + genes=genes, + D=D, + M=M, + permutation_test=permutation_test, + seed=seed, + check_analytic_null=check_analytic_null, + device=device, + ) + + # Save results + for key, value in results.items(): + adata.uns[key] = value + + if verbose: + print( + f"Pair-wise local correlation results are stored in adata.uns with the following keys: {list(results.keys())}" + ) + + print( + "Finished computing pair-wise local correlation in %.3f seconds" + % (time.time() - start) + ) + + return + + +def compute_pairwise_correlation_results( + counts, weights, lcs, eg2s, genes, D, M, permutation_test, seed, check_analytic_null, device +): + + results = {} + + lc_zs = compute_cor_Z_scores_torch(lcs, eg2s) + + lc_z_pvals = norm.sf(lc_zs.cpu().numpy()) + lc_z_FDR = multipletests(lc_z_pvals.flatten(), method="fdr_bh")[1].reshape(lc_z_pvals.shape) + + gene_index = pd.Index(genes) + + if permutation_test: + n_genes, n_cells = counts.shape + perm_array = torch.empty((n_genes, n_genes, M), dtype=torch.float16, device=device) + if check_analytic_null: + lc_zs_perm_array = torch.empty_like(perm_array) + lc_pvals_perm_array = torch.empty_like(perm_array) + + torch.manual_seed(seed) + for i in tqdm(range(M), desc="Permutation test"): + idx = torch.randperm(n_cells, device=device) + + WXt_perm = torch.sparse.mm(weights, counts[:, idx].T) + WtXt_perm = torch.sparse.mm(weights.transpose(0, 1), counts[:, idx].T) + lcs_perm = torch.matmul(counts, WXt_perm) + torch.matmul(counts, WtXt_perm) + + perm_array[:, :, i] = lcs_perm.half() + + if check_analytic_null: + lc_zs_perm = compute_cor_Z_scores_torch(lcs_perm, eg2s) + lc_zs_perm_array[:, :, i] = lc_zs_perm.half() + lc_pvals_perm_array[:, :, i] = torch.tensor( + norm.sf(lc_zs_perm.cpu().numpy()), device=device + ).half() + + x = (perm_array > lcs.unsqueeze(-1)).sum(dim=2) + lc_perm_pvals = (x + 1).float() / (M + 1) + + lc_perm_pvals_ab = torch.tril(lc_perm_pvals, diagonal=-1) + lc_perm_pvals_ba = torch.tril(lc_perm_pvals.transpose(0, 1), diagonal=-1) + lc_perm_pvals_sym = torch.where( + lc_perm_pvals_ab > lc_perm_pvals_ba, lc_perm_pvals_ab, lc_perm_pvals_ba + ) + i_upper = torch.triu_indices(n_genes, n_genes, offset=1) + lc_perm_pvals_sym[i_upper[0], i_upper[1]] = lc_perm_pvals_sym[i_upper[1], i_upper[0]] + + results["lc_perm_pvals"] = pd.DataFrame( + lc_perm_pvals.cpu().numpy(), index=gene_index, columns=gene_index + ) + results["lc_perm_pvals_sym"] = pd.DataFrame( + lc_perm_pvals_sym.cpu().numpy(), index=gene_index, columns=gene_index + ) + + if check_analytic_null: + results["analytic_null_lc_zs_perm"] = lc_zs_perm_array.cpu().numpy() + results["analytic_null_lc_pvals_perm"] = lc_pvals_perm_array.cpu().numpy() + + gene_maxs = 0.5 * torch.sum((counts**2) * D[None, :], dim=1) # shape (n_genes,) + lc_maxs = (gene_maxs[:, None] + gene_maxs[None, :]) / 2 + lcs = lcs / lc_maxs + + results["lcs"] = pd.DataFrame(lcs.cpu().numpy(), index=gene_index, columns=gene_index) + results["lc_zs"] = pd.DataFrame(lc_zs.cpu().numpy(), index=gene_index, columns=gene_index) + results["lc_z_pvals"] = pd.DataFrame(lc_z_pvals, index=gene_index, columns=gene_index) + results["lc_z_FDR"] = pd.DataFrame(lc_z_FDR, index=gene_index, columns=gene_index) + + return results + + +@jit(nopython=True) +def conditional_eg2(x, neighbors, weights): + """ + Computes EG2 for the conditional correlation + """ + N = neighbors.shape[0] + K = neighbors.shape[1] + + t1x = np.zeros(N) + + for i in range(N): + K = len(neighbors[i][~np.isnan(neighbors[i])]) + for k in range(K): + j = neighbors[i, k] + j = int(j) + + wij = weights[i, j] + if wij == 0: + continue + + t1x[i] += wij * x[j] + t1x[j] += wij * x[i] + + out_eg2 = (t1x**2).sum() + + return out_eg2 + + +@jit(nopython=True) +def local_cov_pair(x, y, neighbors, weights): + """Test statistic for local pair-wise autocorrelation""" + out = 0 + + for i in range(len(x)): + xi = x[i] + yi = y[i] + if xi == 0 and yi == 0: + continue + K = len(neighbors[i][~np.isnan(neighbors[i])]) + for k in range(K): + j = neighbors[i, k] + j = int(j) + + w_ij = weights[i, j] + + xj = x[j] + yj = y[j] + + out += w_ij * (xi * yj + yi * xj) / 2 + + return out + + +@jit(nopython=True) +def local_cov_pair_fast(counts, weights): + """Test statistic for local pair-wise autocorrelation""" + counts_t = counts.transpose() + weights_t = weights.transpose() + + lc_1 = sparse.einsum("ik,kl,lj->ij", counts, weights, counts_t) + lc_2 = sparse.einsum("ik,kl,lj->ij", counts, weights_t, counts_t) + lc = lc_1 + lc_2 + + # lc = counts @ weights @ counts.T + counts @ weights.T @ counts.T + + return lc + + +def create_centered_counts_torch(counts, model, num_umi, device, eps=1e-10): + """ + Vectorized PyTorch version of centered counts transformation. + + Args: + counts: torch.Tensor of shape (G, C) + model: one of {'bernoulli', 'danb', 'normal', 'none'} + num_umi: torch.Tensor of shape (C,) = total UMIs per cell + device + + Returns + ------- + centered_counts: torch.Tensor of shape (G, C) + """ + G, C = counts.shape + + if model == "bernoulli": + vals = (counts > 0).double() + model_fn = models.bernoulli_model_linear_torch + elif model == "danb": + vals = counts.double() + model_fn = models.danb_model_torch + elif model == "normal": + vals = counts.double() + model_fn = models.normal_model_torch + elif model == "none": + # Center only (no variance normalization) + mu = vals.mean(dim=1, keepdim=True) + return counts - mu + else: + raise ValueError(f"Invalid model type: {model}") + + # Output tensors + mu_all = torch.zeros_like(counts, dtype=torch.double) + var_all = torch.ones_like(counts, dtype=torch.double) + + for g in range(G): + mu, var, _ = model_fn(counts[g], num_umi, device) + mu_all[g] = mu + var_all[g] = var + + # Prevent division by zero + var_all[var_all == 0] = 1.0 + + centered = (counts - mu_all) / var_all.sqrt() + centered[centered == 0] = 0.0 # Optional, usually unnecessary + + return centered + + +def create_centered_counts(counts, model, num_umi): + """ + Creates a matrix of centered/standardized counts given + the selected statistical model + """ + out = np.zeros_like(counts, dtype="double") + + for i in range(out.shape[0]): + vals_x = counts[i] + + out_x = create_centered_counts_row(vals_x, model, num_umi) + + out[i] = out_x + + return out + + +def create_centered_counts_row(vals_x, model, num_umi): + + if model == "bernoulli": + vals_x = (vals_x > 0).astype("double") + mu_x, var_x, x2_x = models.bernoulli_model(vals_x, num_umi) + elif model == "danb": + mu_x, var_x, x2_x = models.danb_model(vals_x, num_umi) + elif model == "normal": + mu_x, var_x, x2_x = models.normal_model(vals_x, num_umi) + elif model == "none": + mu_x, var_x, x2_x = models.none_model(vals_x, num_umi) + else: + raise Exception(f"Invalid Model: {model}") + + var_x[var_x == 0] = 1 + out_x = (vals_x - mu_x) / (var_x**0.5) + out_x[out_x == 0] = 0 + + return out_x + + +@jit(nopython=True) +def _compute_hs_pairs_inner_centered_cond_sym(rowpair, counts, neighbors, weights, eg2s): + """ + This version assumes that the counts have already been modeled + and centered + """ + row_i, row_j = rowpair + + vals_x = counts[row_i] + vals_y = counts[row_j] + + lc = local_cov_pair(vals_x, vals_y, neighbors, weights) * 2 + + # Compute xy + EG, EG2 = 0, eg2s[row_i] + + stdG = (EG2 - EG**2) ** 0.5 + + Zxy = (lc - EG) / stdG + + # Compute yx + EG, EG2 = 0, eg2s[row_j] + + stdG = (EG2 - EG**2) ** 0.5 + + Zyx = (lc - EG) / stdG + + if abs(Zxy) < abs(Zyx): + Z = Zxy + else: + Z = Zyx + + return (lc, Z) + + +def compute_cor_Z_scores(lc, eg2s): + + EG, EG2 = 0, eg2s + stdG = (EG2 - EG**2) ** 0.5 + + Z = (lc - EG) / stdG[:, np.newaxis] + + Z_ab = np.tril(Z, k=-1) + Z_ba = np.tril(Z.T, k=-1) + + Z = np.where(np.abs(Z_ab) < np.abs(Z_ba), Z_ab, Z_ba) + + i_upper = np.triu_indices(Z.shape[0], k=1) + Z[i_upper] = Z.T[i_upper] + + return Z + + +def compute_cor_Z_scores_torch(lc, eg2s): + + EG = 0.0 + stdG = (eg2s - EG**2) ** 0.5 + + Z = (lc - EG) / stdG + + Z_ab = torch.tril(Z, diagonal=-1) + Z_ba = torch.tril(Z.T, diagonal=-1) + + Z = torch.where(Z_ab.abs() < Z_ba.abs(), Z_ab, Z_ba) + Z = Z + Z.T + + return Z + + +@njit +def expand_pairs(pairs, vals, N): + + out = np.zeros((N, N)) + + for i in range(len(pairs)): + x = pairs[i, 0] + y = pairs[i, 1] + v = vals[i] + + out[x, y] = v + out[y, x] = v + + return out + + +def compute_max_correlation(node_degrees, counts): + """ + For a Genes x Cells count matrix, compute the maximal pair-wise correlation + between any two genes + """ + N_GENES = counts.shape[0] + + gene_maxs = np.zeros(N_GENES) + for i in range(N_GENES): + gene_maxs[i] = compute_local_cov_max(node_degrees, counts[i]) + + result = gene_maxs.reshape((-1, 1)) + gene_maxs.reshape((1, -1)) + result = result / 2 + return result diff --git a/src/scvi/external/harreman/hotspot/models.py b/src/scvi/external/harreman/hotspot/models.py new file mode 100755 index 0000000000..6dbf273b43 --- /dev/null +++ b/src/scvi/external/harreman/hotspot/models.py @@ -0,0 +1,494 @@ +from collections.abc import Callable + +import numpy as np +import pandas as pd +import torch +from numba import jit, njit + + +def danb_model(gene_counts, umi_counts): + + tj = gene_counts.sum() + tis = umi_counts + total = tis.sum() + + N = gene_counts.size + + min_size = 10 ** (-10) + + mu = tj * tis / total + vv = (gene_counts - mu).var() * (N / (N - 1)) + # vv = ((gene_counts - mu)**2).sum() + my_rowvar = vv + + size = ((tj**2) / total) * ((tis**2).sum() / total) / ((N - 1) * my_rowvar - tj) + # size = ((tj**2) * ((tis/total)**2).sum()) / ((N-1)*my_rowvar-tj) + + if size < 0: # Can't have negative dispersion + size = 1e9 + + if size < min_size and size >= 0: + size = min_size + + var = mu * (1 + mu / size) + x2 = var + mu**2 + + return mu, var, x2 + + +def danb_model_torch(counts: torch.Tensor, umi_counts: torch.Tensor, eps: float = 1e-10): + """ + Vectorized DANB model computation in PyTorch for a batch of genes. + + Args: + counts: Tensor of shape [genes, cells], gene expression counts. + umi_counts: Tensor of shape [cells], total UMI per cell. + eps: Small constant to avoid division by zero or log(0). + + Returns + ------- + mu: Mean per gene per cell [genes, cells] + var: Variance per gene per cell [genes, cells] + x2: Second moment per gene per cell [genes, cells] + """ + tj = counts.sum(dim=1, keepdim=True) # [genes, 1] + total = umi_counts.sum() # scalar + N = counts.shape[1] # number of cells + + mu = tj * umi_counts / total # [genes, cells] + diff = counts - mu # [genes, cells] + + # Unbiased sample variance (N / (N - 1)) + var_gene = (diff**2).mean(dim=1) * N / (N - 1) # [genes] + + numerator = ((tj**2) / total).squeeze() * (umi_counts**2).sum() / total # [genes] + denominator = (N - 1) * var_gene - tj.squeeze() # [genes] + size = numerator / (denominator + eps) # [genes] + + # Clamp size for numerical stability + size = torch.where(size < 0, torch.tensor(1e9, device=size.device), size) + size = torch.clamp(size, min=eps) + + size = size.unsqueeze(1) # [genes, 1] for broadcasting + var = mu * (1 + mu / size) # [genes, cells] + x2 = var + mu**2 # [genes, cells] + + return mu, var, x2 + + +def ct_danb_model(gene_counts, umi_counts, cell_types): + + mu_ct = np.zeros(len(cell_types)) + var_ct = np.zeros(len(cell_types)) + x2_ct = np.zeros(len(cell_types)) + + min_size = 10 ** (-10) + + for cell_type in np.unique(cell_types): + gene_counts_ct = gene_counts[cell_types == cell_type] + umi_counts_ct = umi_counts[cell_types == cell_type] + + tj = gene_counts_ct.sum() + tis = umi_counts_ct + total = tis.sum() + + N = gene_counts_ct.size + + mu = tj * tis / total + vv = (gene_counts_ct - mu).var() * (N / (N - 1)) if N > 1 else (gene_counts_ct - mu).var() + my_rowvar = vv + + size = ((tj**2) / total) * ((tis**2).sum() / total) / ((N - 1) * my_rowvar - tj) + + if size < 0: # Can't have negative dispersion + size = 1e9 + + if size < min_size and size >= 0: + size = min_size + + var = mu * (1 + mu / size) + x2 = var + mu**2 + + mu_ct[cell_types == cell_type] = mu + var_ct[cell_types == cell_type] = var + x2_ct[cell_types == cell_type] = x2 + + return mu_ct, var_ct, x2_ct + + +N_BIN_TARGET = 30 + + +@jit(nopython=True) +def find_gene_p(num_umi, D): + """ + Finds gene_p such that sum of expected detects + matches our data + + Performs a binary search on p in the space of log(p) + """ + low = 1e-12 + high = 1 + + if D == 0: + return 0 + + for ITER in range(40): + attempt = (high * low) ** 0.5 + tot = 0 + + for i in range(len(num_umi)): + tot = tot + 1 - (1 - attempt) ** num_umi[i] + + if abs(tot - D) / D < 1e-3: + break + + if tot > D: + high = attempt + else: + low = attempt + + return (high * low) ** 0.5 + + +def bernoulli_model_scaled(gene_detects, umi_counts): + + D = gene_detects.sum() + + gene_p = find_gene_p(umi_counts, D) + + detect_p = 1 - (1 - gene_p) ** umi_counts + + mu = detect_p + var = detect_p * (1 - detect_p) + x2 = detect_p + + return mu, var, x2 + + +def true_params_scaled(gene_p, umi_counts): + + detect_p = 1 - (1 - gene_p / 10000) ** umi_counts + + mu = detect_p + var = detect_p * (1 - detect_p) + x2 = detect_p + + return mu, var, x2 + + +def bernoulli_model_linear(gene_detects, umi_counts): + + # We modify the 0 UMI counts to 1e-10 to remove the NaN values from the qcut output. + umi_counts[umi_counts == 0] = 1e-10 + + umi_count_bins, bins = pd.qcut( + np.log10(umi_counts), N_BIN_TARGET, labels=False, retbins=True, duplicates="drop" + ) + bin_centers = np.array([bins[i] / 2 + bins[i + 1] / 2 for i in range(len(bins) - 1)]) + + N_BIN = len(bin_centers) + + bin_detects = bin_gene_detection(gene_detects, umi_count_bins, N_BIN) + + lbin_detects = logit(bin_detects) + + X = np.ones((N_BIN, 2)) + X[:, 1] = bin_centers + Y = lbin_detects + + b = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(Y) + detect_p = ilogit(b[0] + b[1] * np.log10(umi_counts)) + + mu = detect_p + var = detect_p * (1 - detect_p) + x2 = detect_p + + return mu, var, x2 + + +def bernoulli_model_linear_torch(gene_detects, umi_counts, n_bins=30, eps=1e-10): + """ + gene_detects: [genes, cells] binary tensor + umi_counts: [cells] tensor + Returns: mu, var, x2 = each of shape [genes, cells] + """ + device = gene_detects.device + log_umi = torch.log10(umi_counts.clamp(min=eps)) # [cells] + + # Use pd.qcut to get bin indices and edges + bin_indices_np, bin_edges_np = pd.qcut( + log_umi.cpu().numpy(), q=n_bins, labels=False, retbins=True, duplicates="drop" + ) + + dtype = gene_detects.dtype + bin_edges = torch.tensor(bin_edges_np, device=device, dtype=dtype) + bin_indices = torch.tensor(bin_indices_np, device=device, dtype=dtype) + + # Compute bin centers from bin_edges + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) # [n_bins] + + # Compute bin means per gene + detect_sum = torch.zeros(gene_detects.size(0), len(bin_centers), device=device, dtype=dtype) + bin_counts = torch.zeros(len(bin_centers), device=device, dtype=dtype) + + for i in range(len(bin_centers)): + mask = bin_indices == i + bin_counts[i] = mask.sum() + if mask.any(): + detect_sum[:, i] = gene_detects[:, mask].sum(dim=1) + + # Laplace smoothing + bin_detect_rate = (detect_sum + 1) / (bin_counts[None, :] + 2) + + # Fit logit model: y = b0 + b1 * bin_center (per gene) + logit_y = torch.log(bin_detect_rate / (1 - bin_detect_rate + eps) + eps) # [genes, bins] + X = torch.stack([torch.ones_like(bin_centers), bin_centers]) # [2, bins] + X = X.T # [bins, 2] + + # Solve (X^T X)^-1 X^T y for each gene + XT_X = X.T @ X # [2, 2] + XT_X_inv = torch.inverse(XT_X) # [2, 2] + XT = X.T # [2, bins] + b = XT_X_inv @ (XT @ logit_y.T) # [2, genes] + + b0, b1 = b[0], b[1] # [genes] + logit_pred = b0[:, None] + b1[:, None] * log_umi[None, :] # [genes, cells] + detect_p = torch.sigmoid(logit_pred) # [genes, cells] + + mu = detect_p + var = detect_p * (1 - detect_p) + x2 = detect_p + + return mu, var, x2 + + +def ct_bernoulli_model_linear(gene_detects, umi_counts, cell_types): + + mu_ct = np.zeros(len(cell_types)) + var_ct = np.zeros(len(cell_types)) + x2_ct = np.zeros(len(cell_types)) + + for cell_type in np.unique(cell_types): + gene_detects_ct = gene_detects[cell_types == cell_type] + umi_counts_ct = umi_counts[cell_types == cell_type] + + # We modify the 0 UMI counts to 1e-10 to remove the NaN values from the qcut output. + umi_counts_ct[umi_counts_ct == 0] = 1e-10 + + umi_count_bins, bins = pd.qcut( + np.log10(umi_counts_ct), N_BIN_TARGET, labels=False, retbins=True, duplicates="drop" + ) + bin_centers = np.array([bins[i] / 2 + bins[i + 1] / 2 for i in range(len(bins) - 1)]) + + N_BIN = len(bin_centers) + + bin_detects = bin_gene_detection(gene_detects_ct, umi_count_bins, N_BIN) + + lbin_detects = logit(bin_detects) + + X = np.ones((N_BIN, 2)) + X[:, 1] = bin_centers + Y = lbin_detects + + b = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(Y) + detect_p = ilogit(b[0] + b[1] * np.log10(umi_counts)) + + mu = detect_p + var = detect_p * (1 - detect_p) + x2 = detect_p + + mu_ct[cell_types == cell_type] = mu + var_ct[cell_types == cell_type] = var + x2_ct[cell_types == cell_type] = x2 + + return mu_ct, var_ct, x2_ct + + +bernoulli_model_torch = bernoulli_model_linear_torch +bernoulli_model = bernoulli_model_linear + + +@njit +def logit(p): + return np.log(p / (1 - p)) + + +@njit +def ilogit(q): + return np.exp(q) / (1 + np.exp(q)) + + +@njit +def bin_gene_detection(gene_detects, umi_count_bins, N_BIN): + bin_detects = np.zeros(N_BIN) + bin_totals = np.zeros(N_BIN) + + for i in range(len(gene_detects)): + x = gene_detects[i] + bin_i = umi_count_bins[i] + bin_detects[bin_i] += x + bin_totals[bin_i] += 1 + + # Need to account for 0% detects + # Add 1 to numerator and denominator + # Need to account for 100% detects + # Add 1 to denominator + + return (bin_detects + 1) / (bin_totals + 2) + + +def normal_model(gene_counts, umi_counts): + """ + Simplest Model - just assumes expression data is normal + UMI counts are regressed out + """ + X = np.vstack((np.ones(len(umi_counts)), umi_counts)).T + y = gene_counts.reshape((-1, 1)) + + if umi_counts.var() == 0: + mu = gene_counts.mean() + var = gene_counts.var() + mu = np.repeat(mu, len(umi_counts)) + var = np.repeat(var, len(umi_counts)) + x2 = mu**2 + var + + return mu, var, x2 + + B = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y) + + mu = X.dot(B) + + var = (y - mu).var() + var = np.repeat(var, len(umi_counts)) + + mu = mu.ravel() + + x2 = mu**2 + var + + return mu, var, x2 + + +def normal_model_torch(counts: torch.Tensor, umi_counts: torch.Tensor, eps: float = 1e-10): + """ + Vectorized Normal model in PyTorch for a batch of genes. + + Args: + counts: Tensor of shape [genes, cells], gene expression values. + umi_counts: Tensor of shape [cells], total UMI per cell. + eps: Small constant to avoid instability in matrix inversion. + + Returns + ------- + mu: Tensor of shape [genes, cells], predicted mean per gene per cell. + var: Tensor of shape [genes, cells], constant variance per gene. + x2: Tensor of shape [genes, cells], mu^2 + var. + """ + device = counts.device + genes, cells = counts.shape + + # Design matrix X: [cells, 2] + ones = torch.ones_like(umi_counts).unsqueeze(1) # [cells, 1] + umi = umi_counts.unsqueeze(1) # [cells, 1] + X = torch.cat([ones, umi], dim=1) # [cells, 2] + + XT = X.T # [2, cells] + XT_X = XT @ X # [2, 2] + try: + XT_X_inv = torch.inverse(XT_X + eps * torch.eye(2, device=device)) # [2, 2] + except RuntimeError: + raise ValueError("Design matrix is singular. Consider regularizing or filtering.") + + # Center y (counts) to [cells, genes] for regression, then transpose + Y = counts.T # [cells, genes] + B = XT_X_inv @ (XT @ Y) # [2, genes] + + mu = (X @ B).T # [genes, cells] + var = ((counts - mu) ** 2).mean(dim=1, keepdim=True) # [genes, 1] + var = var.expand_as(mu) # [genes, cells] + x2 = mu**2 + var # [genes, cells] + + return mu, var, x2 + + +def none_model(gene_counts, umi_counts): + + N = gene_counts.size + + mu = np.zeros(N) + var = np.ones(N) + x2 = np.ones(N) + + return mu, var, x2 + + +def none_model_torch(counts: torch.Tensor, umi_counts: torch.Tensor): + """ + 'None' model in PyTorch: returns zero mean and unit variance for all values. + + Args: + counts: Tensor of shape [genes, cells], ignored here. + umi_counts: Tensor of shape [cells], ignored here. + + Returns + ------- + mu: Tensor of zeros [genes, cells] + var: Tensor of ones [genes, cells] + x2: Tensor of ones [genes, cells] + """ + shape = counts.shape + device = counts.device + + mu = torch.zeros(shape, device=device) + var = torch.ones(shape, device=device) + x2 = torch.ones(shape, device=device) + + return mu, var, x2 + + +def apply_model_per_cell_type( + model_fn: Callable, + counts: torch.Tensor, + umi_counts: torch.Tensor, + cell_types: list | torch.Tensor, + **kwargs, +): + """ + Applies a model function to each cell type separately. + + Args: + model_fn: function of form (counts, umi_counts, **kwargs) -> (mu, var, x2) + counts: [genes, cells] tensor + umi_counts: [cells] tensor + cell_types: list or tensor of cell type labels, length = cells + kwargs: other model-specific arguments + + Returns + ------- + mu, var, x2: [genes, cells] tensors, concatenated across all cell types + """ + device = counts.device + + unique_types = cell_types.unique() + genes, cells = counts.shape + + mu_all = torch.empty((genes, cells), dtype=torch.float64, device=device) + var_all = torch.empty((genes, cells), dtype=torch.float64, device=device) + x2_all = torch.empty((genes, cells), dtype=torch.float64, device=device) + + cell_index = np.arange(cells) + + for ct in unique_types: + idx_array = cell_index[cell_types.values == ct] + idx = torch.tensor(idx_array, device=device) + + counts_ct = counts[:, idx] + umi_ct = umi_counts[idx] + + mu, var, x2 = model_fn(counts_ct, umi_ct, **kwargs) + + mu_all[:, idx] = mu + var_all[:, idx] = var + x2_all[:, idx] = x2 + + return mu_all, var_all, x2_all diff --git a/src/scvi/external/harreman/hotspot/modules.py b/src/scvi/external/harreman/hotspot/modules.py new file mode 100755 index 0000000000..c650023f02 --- /dev/null +++ b/src/scvi/external/harreman/hotspot/modules.py @@ -0,0 +1,813 @@ +import time +from typing import Literal + +import numpy as np +import pandas as pd +import torch +from anndata import AnnData +from scipy.cluster.hierarchy import linkage +from scipy.spatial.distance import squareform +from scipy.stats import hypergeom, norm, pearsonr, spearmanr, zscore +from sklearn.decomposition import PCA +from statsmodels.stats.multitest import multipletests +from tqdm import tqdm + +from ..preprocessing.anndata import counts_from_anndata +from ..tools.knn import make_weights_non_redundant +from .local_autocorrelation import center_counts_torch + + +def calculate_module_scores( + adata: AnnData, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Calculate module scores for gene modules across cells. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). Required fields in `adata.uns`: + - 'layer_key': name of the layer from which to extract the expression matrix + - 'model': statistical model used to normalize expression (e.g., 'DANB', 'normal') + - 'umi_counts': total UMI counts per cell + - 'gene_modules_dict': dictionary mapping module IDs (as strings) to lists of gene names + device : torch.device, optional + Device to use for computation (e.g., CUDA or CPU). Defaults to GPU if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + The following results are stored in the `AnnData` object: + - `adata.obsm['module_scores']`: (cells x modules) DataFrame with per-cell module activity scores + - `adata.varm['gene_loadings']`: (genes x modules) DataFrame with gene loadings for each module + - `adata.uns['gene_modules']`: dictionary mapping module names to gene lists + """ + start = time.time() + + layer_key = adata.uns["layer_key"] + model = adata.uns["model"] + + use_raw = layer_key == "use_raw" + modules = adata.uns["gene_modules_dict"].copy() + + umi_counts = adata.uns["umi_counts"] + + modules_to_compute = sorted([x for x in modules.keys() if x != "-1"]) + mod_list = [int(mod) for mod in modules_to_compute] + mod_list.sort() + modules_to_compute = [str(mod) for mod in mod_list] + + if verbose: + print(f"Computing scores for {len(modules_to_compute)} modules...") + + module_scores = {} + gene_loadings = pd.DataFrame(index=adata.var_names) + gene_modules = {} + for module in tqdm(modules_to_compute): + module_genes = modules[module] + + scores, loadings = compute_scores( + adata[:, module_genes], + layer_key, + model, + umi_counts, + device, + ) + + module_name = f"Module {module}" if "Module" not in module else module + module_scores[module_name] = scores + gene_loadings[module_name] = pd.Series(loadings, index=module_genes) + gene_modules[module_name] = module_genes + + module_scores = pd.DataFrame(module_scores) + + module_scores.index = adata.obs_names if not use_raw else adata.raw.obs.index + + adata.obsm["module_scores"] = module_scores + adata.varm["gene_loadings"] = gene_loadings + adata.uns["gene_modules"] = gene_modules + + if verbose: + print("Finished computing module scores in %.3f seconds" % (time.time() - start)) + + return + + +def compute_scores(adata, layer_key, model, num_umi, device, _lambda=0.9): + """ + counts_sub: row-subset of counts matrix with genes in the module + """ + # Get the weights matrix + weights = make_weights_non_redundant(adata.obsp["weights"]).tocoo() + weights = torch.sparse_coo_tensor( + torch.tensor(np.vstack((weights.row, weights.col)), dtype=torch.long, device=device), + torch.tensor(weights.data, dtype=torch.float64, device=device), + torch.Size(weights.shape), + device=device, + ) + + # Get gene expression counts for the module (dense) + counts_sub = counts_from_anndata(adata, layer_key, dense=True) + + # Convert to tensors + num_umi = torch.tensor(num_umi, dtype=torch.float64, device=device) + counts_sub = torch.tensor(counts_sub, dtype=torch.float64, device=device) + + # Center values + sample_specific = "sample_key" in adata.uns.keys() + if sample_specific: + sample_key = adata.uns["sample_key"] + for sample in adata.obs[sample_key].unique(): + subset = np.where(adata.obs[sample_key] == sample)[0] + counts_sub[:, subset] = center_counts_torch( + counts_sub[:, subset], num_umi[subset], model + ) + else: + counts_sub = center_counts_torch(counts_sub, num_umi, model) + + # Smooth the counts using weights + out = torch.matmul( + weights + weights.transpose(0, 1), counts_sub.T + ) # (cells x cells) @ (cells x genes)^T = (cells x genes)^T + weights_sum = ( + torch.sparse.sum(weights, dim=0).to_dense() + torch.sparse.sum(weights, dim=1).to_dense() + ) # shape (cells,) + weights_sum[weights_sum == 0] = 1.0 + out = out / weights_sum[:, None] # normalize + cc_smooth = _lambda * out.T + (1 - _lambda) * counts_sub # (genes x cells) + + # Perform PCA on cells (transpose to cells x genes) + pca_data = cc_smooth.T.cpu().numpy() + pca = PCA(n_components=1) + scores = pca.fit_transform(pca_data) + loadings = pca.components_.T + + # Flip sign if needed + if pca.components_.mean() < 0: + scores *= -1 + loadings *= -1 + scores = scores[:, 0] + loadings = loadings[:, 0] + + return scores, loadings + + +def sort_linkage(Z, node_index, node_values): + """ + Sorts linkage by 'node_values' in place + """ + N = Z.shape[0] + 1 # number of leaves + + if node_index < 0: + return + + left_child = int(Z[node_index, 0] - N) + right_child = int(Z[node_index, 1] - N) + + swap = False + + if left_child < 0 and right_child < 0: + swap = False + elif left_child < 0 and right_child >= 0: + swap = True + elif left_child >= 0 and right_child < 0: + swap = False + else: + if node_values[left_child] > node_values[right_child]: + swap = True + else: + swap = False + + if swap: + Z[node_index, 0] = right_child + N + Z[node_index, 1] = left_child + N + + sort_linkage(Z, left_child, node_values) + sort_linkage(Z, right_child, node_values) + + +def calc_mean_dists(Z, node_index, out_mean_dists): + """ + Calculates the mean density of joins + for sub-trees underneath each node + """ + N = Z.shape[0] + 1 # number of leaves + + left_child = int(Z[node_index, 0] - N) + right_child = int(Z[node_index, 1] - N) + + if left_child < 0: + left_average = 0 + left_merges = 0 + else: + left_average, left_merges = calc_mean_dists(Z, left_child, out_mean_dists) + + if right_child < 0: + right_average = 0 + right_merges = 0 + else: + right_average, right_merges = calc_mean_dists(Z, right_child, out_mean_dists) + + this_height = Z[node_index, 2] + this_merges = left_merges + right_merges + 1 + this_average = ( + left_average * left_merges + right_average * right_merges + this_height + ) / this_merges + + out_mean_dists[node_index] = this_average + + return this_average, this_merges + + +def prop_label(Z, node_index, label, labels, out_clusters): + """ + Propagates node labels downward if they are not -1 + Used to find the correct cluster label at the leaves + """ + N = Z.shape[0] + 1 # number of leaves + + if label == -1: + label = labels[node_index] + + left_child = int(Z[node_index, 0] - N) + right_child = int(Z[node_index, 1] - N) + + if left_child < 0: + out_clusters[left_child + N] = label + else: + prop_label(Z, left_child, label, labels, out_clusters) + + if right_child < 0: + out_clusters[right_child + N] = label + else: + prop_label(Z, right_child, label, labels, out_clusters) + + +def prop_label2(Z, node_index, label, labels, out_clusters): + """ + Propagates node labels downward + Helper method used in assign_modules + """ + N = Z.shape[0] + 1 # number of leaves + + parent_label = label + this_label = labels[node_index] + + if this_label == -1: + new_label = parent_label + else: + new_label = this_label + + left_child = int(Z[node_index, 0] - N) + right_child = int(Z[node_index, 1] - N) + + if left_child < 0: + out_clusters[left_child + N] = new_label + else: + prop_label2(Z, left_child, new_label, labels, out_clusters) + + if right_child < 0: + out_clusters[right_child + N] = new_label + else: + prop_label2(Z, right_child, new_label, labels, out_clusters) + + +def assign_modules(Z, leaf_labels, offset, MIN_THRESHOLD=10, Z_THRESHOLD=3): + clust_i = 0 + + labels = np.ones(Z.shape[0]) * -1 + N = Z.shape[0] + 1 + + mean_dists = np.zeros(Z.shape[0]) + calc_mean_dists(Z, Z.shape[0] - 1, mean_dists) + + for i in range(Z.shape[0]): + ca = int(Z[i, 0]) + cb = int(Z[i, 1]) + + if ca - N < 0: # leaf node + n_members_a = 1 + clust_a = -1 + else: + n_members_a = Z[ca - N, 3] + clust_a = labels[ca - N] + + if cb - N < 0: # leaf node + n_members_b = 1 + clust_b = -1 + else: + n_members_b = Z[cb - N, 3] + clust_b = labels[cb - N] + + if Z[i, 2] > offset - Z_THRESHOLD: + new_clust_assign = -1 + elif n_members_a >= MIN_THRESHOLD and n_members_b >= MIN_THRESHOLD: + # don't join them + # assign the one with the larger mean distance + dist_a = mean_dists[ca - N] + dist_b = mean_dists[cb - N] + if dist_a >= dist_b: + new_clust_assign = clust_a + else: + new_clust_assign = clust_b + elif n_members_a >= MIN_THRESHOLD: + new_clust_assign = clust_a + elif n_members_b >= MIN_THRESHOLD: + new_clust_assign = clust_b + elif (n_members_b + n_members_a) >= MIN_THRESHOLD: + # A new cluster is born! + new_clust_assign = clust_i + clust_i += 1 + else: + new_clust_assign = -1 # Still too small + + labels[i] = new_clust_assign + + out_clusters = np.ones(N) * -2 + prop_label2(Z, Z.shape[0] - 1, labels[-1], labels, out_clusters) + + # remap out_clusters + unique_clusters = list(np.sort(np.unique(out_clusters))) + + if -1 in unique_clusters: + unique_clusters.remove(-1) + + clust_map = {x: i + 1 for i, x in enumerate(unique_clusters)} + clust_map[-1] = -1 + + out_clusters = [clust_map[x] for x in out_clusters] + out_clusters = pd.Series(out_clusters, index=leaf_labels) + + return out_clusters + + +def assign_modules_core(Z, leaf_labels, offset, MIN_THRESHOLD=10, Z_THRESHOLD=3): + clust_i = 0 + + labels = np.ones(Z.shape[0]) * -1 + N = Z.shape[0] + 1 + + for i in range(Z.shape[0]): + ca = int(Z[i, 0]) + cb = int(Z[i, 1]) + + if ca - N < 0: # leaf node + n_members_a = 1 + clust_a = -1 + else: + n_members_a = Z[ca - N, 3] + clust_a = labels[ca - N] + + if cb - N < 0: # leaf node + n_members_b = 1 + clust_b = -1 + else: + n_members_b = Z[cb - N, 3] + clust_b = labels[cb - N] + + if n_members_a >= MIN_THRESHOLD and n_members_b >= MIN_THRESHOLD: + # don't join them + new_clust_assign = -1 + elif Z[i, 2] > offset - Z_THRESHOLD: + new_clust_assign = -1 + elif n_members_a >= MIN_THRESHOLD: + new_clust_assign = clust_a + elif n_members_b >= MIN_THRESHOLD: + new_clust_assign = clust_b + elif (n_members_b + n_members_a) >= MIN_THRESHOLD: + # A new cluster is born! + new_clust_assign = clust_i + clust_i += 1 + else: + new_clust_assign = -1 # Still too small + + labels[i] = new_clust_assign + + out_clusters = np.ones(N) * -2 + prop_label(Z, Z.shape[0] - 1, labels[-1], labels, out_clusters) + + # remap out_clusters + unique_clusters = list(np.sort(np.unique(out_clusters))) + + if -1 in unique_clusters: + unique_clusters.remove(-1) + + clust_map = {x: i + 1 for i, x in enumerate(unique_clusters)} + clust_map[-1] = -1 + + out_clusters = [clust_map[x] for x in out_clusters] + out_clusters = pd.Series(out_clusters, index=leaf_labels) + + return out_clusters + + +def create_modules( + adata: AnnData, + min_gene_threshold: int | None = 20, + fdr_threshold: float | None = 0.05, + z_threshold: float | None = None, + core_only: bool = False, + verbose: bool | None = False, +): + """ + Perform hierarchical clustering on gene-gene local correlation Z-scores to assign gene modules. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData) containing the local correlation Z-scores in `adata.uns['lc_zs']`. + min_gene_threshold : int, optional (default: 20) + Minimum number of genes required to define a module. + fdr_threshold : float, optional (default: 0.05) + FDR threshold used to determine the minimum Z-score significance if `z_threshold` is not provided. + z_threshold : float, optional + If provided, uses this Z-score as the cutoff for module inclusion instead of computing it from FDR. + core_only : bool, optional (default: False) + If True, assigns only tightly correlated (core) genes to modules and leaves others unassigned. + + Returns + ------- + None + The function modifies the `AnnData` object in place by adding the following to `adata.uns`: + - `modules`: pandas Series mapping each gene to a module ID (integer, as string) + - `gene_modules_dict`: dictionary mapping module IDs (as strings) to lists of gene names + - `linkage`: linkage matrix from hierarchical clustering (for visualization or tree operations) + """ + start = time.time() + if verbose: + print("Creating modules...") + + # Determine Z_Threshold from FDR threshold + + Z_scores = adata.uns["lc_zs"] + + if z_threshold is None: + allZ = squareform( # just in case slightly not symmetric + Z_scores.values / 2 + Z_scores.values.T / 2 + ) + allZ = np.sort(allZ) + allP = norm.sf(allZ) + allP_c = multipletests(allP, method="fdr_bh")[1] + ii = np.nonzero(allP_c < fdr_threshold)[0] + if ii.size > 0: + z_threshold = allZ[ii[0]] + else: + z_threshold = allZ[-1] + 1 + + # Compute the linkage matrix + dd = Z_scores.copy().values + np.fill_diagonal(dd, 0) + condensed = squareform(dd) * -1 + offset = condensed.min() * -1 + condensed += offset + Z = linkage(condensed, method="average") + + # Linkage -> Modules + if core_only: + out_clusters = assign_modules_core( + Z, + offset=offset, + MIN_THRESHOLD=min_gene_threshold, + leaf_labels=Z_scores.index, + Z_THRESHOLD=z_threshold, + ) + else: + out_clusters = assign_modules( + Z, + offset=offset, + MIN_THRESHOLD=min_gene_threshold, + leaf_labels=Z_scores.index, + Z_THRESHOLD=z_threshold, + ) + + # Sort the leaves of the linkage matrix (for plotting) + mean_dists = np.zeros(Z.shape[0]) + calc_mean_dists(Z, Z.shape[0] - 1, mean_dists) + linkage_out = Z.copy() + sort_linkage(linkage_out, Z.shape[0] - 1, mean_dists) + + out_clusters.name = "Module" + + gene_modules_dict = {} + for mod in out_clusters.unique(): + gene_modules_dict[str(mod)] = out_clusters[out_clusters == mod].index.tolist() + + adata.uns["modules"] = out_clusters + adata.uns["gene_modules_dict"] = gene_modules_dict + adata.uns["linkage"] = linkage_out + + if verbose: + print("Finished creating modules in %.3f seconds" % (time.time() - start)) + + return + + +def compute_sig_mod_enrichment(adata, norm_data_key, signature_varm_key, use_super_modules): + + gene_modules_key = "gene_modules_sm" if use_super_modules else "gene_modules" + + use_raw = norm_data_key == "use_raw" + genes = adata.raw.var.index if use_raw else adata.var_names + + sig_matrix = ( + adata.varm[signature_varm_key] if not use_raw else adata.raw.varm[signature_varm_key] + ) + gene_modules = adata.uns[gene_modules_key] + + signatures = {} + + for signature in sig_matrix.columns: + if all(x in sig_matrix[signature].unique().tolist() for x in [-1, 1]): + sig_genes_up = sig_matrix[sig_matrix[signature] == 1].index.tolist() + sig_genes_down = sig_matrix[sig_matrix[signature] == -1].index.tolist() + + sig_name_up = signature + "_UP" + sig_name_down = signature + "_DOWN" + + signatures[sig_name_up] = sig_genes_up + signatures[sig_name_down] = sig_genes_down + else: + sig_genes = sig_matrix[sig_matrix[signature] != 0].index.tolist() + signatures[signature] = sig_genes + + pvals_df = pd.DataFrame( + np.nan, index=list(signatures.keys()), columns=list(gene_modules.keys()) + ) + stats_df = pd.DataFrame( + np.nan, index=list(signatures.keys()), columns=list(gene_modules.keys()) + ) + + sig_mod_df = pd.DataFrame(index=genes) + + universe = adata.var_names[adata.var["local_autocorrelation"] == True].tolist() + + # We make sure that the genes present in the signature are just the ones included in the universe + signatures = { + sig: [gene for gene in genes if gene in universe] for sig, genes in signatures.items() + } + + for signature in signatures.keys(): + sig_genes = signatures[signature] + + for module in gene_modules.keys(): + mod_genes = gene_modules[module] + sig_mod_genes = list(set(sig_genes) & set(mod_genes)) + + M = len(universe) + n = len(sig_genes) + N = len(mod_genes) + x = len(sig_mod_genes) + + pval = hypergeom.sf(x - 1, M, n, N) + + if pval < 0.05: + sig_mod_name = signature + "_OVERLAP_" + module + sig_mod_df[sig_mod_name] = 0 + sig_mod_df.loc[sig_mod_genes, sig_mod_name] = 1.0 + + e_overlap = n * N / M + stat = np.log2(x / e_overlap) if e_overlap != 0 else 0 + + pvals_df.loc[signature, module] = pval + stats_df.loc[signature, module] = stat + + FDR_values = multipletests(pvals_df.unstack().values, method="fdr_bh")[1] + FDR_df = pd.Series(FDR_values, index=pvals_df.stack().index).unstack() + + adata.varm["signatures_overlap"] = sig_mod_df + + return pvals_df, stats_df, FDR_df + + +def compute_sig_mod_correlation(adata, method, use_super_modules): + + module_scores_key = "super_module_scores" if use_super_modules else "module_scores" + + signatures = adata.obsm["vision_signatures"].columns.tolist() + modules = adata.obsm[module_scores_key].columns.tolist() + + cor_pval_df = pd.DataFrame(index=modules) + cor_coef_df = pd.DataFrame(index=modules) + + for signature in signatures: + correlation_values = [] + pvals = [] + + for module in modules: + signature_df = adata.obsm["vision_signatures"][signature] + module_df = adata.obsm[module_scores_key][module] + + if method == "pearson": + correlation_value, pval = pearsonr(signature_df, module_df) + elif method == "spearman": + correlation_value, pval = spearmanr(signature_df, module_df) + + correlation_values.append(correlation_value) + pvals.append(pval) + + cor_coef_df[signature] = correlation_values + cor_pval_df[signature] = pvals + + cor_FDR_values = multipletests(cor_pval_df.unstack().values, method="fdr_bh")[1] + cor_FDR_df = pd.Series(cor_FDR_values, index=cor_pval_df.stack().index).unstack() + + return cor_coef_df, cor_pval_df, cor_FDR_df + + +def integrate_vision_hotspot_results( + adata: AnnData, + cor_method: Literal["pearson"] | Literal["spearman"] | None = "pearson", + use_super_modules: bool | None = False, +): + + gene_modules_key = "gene_modules_sm" if use_super_modules else "gene_modules" + + if ("vision_signatures" in adata.obsm) and (len(adata.uns[gene_modules_key].keys()) > 0): + start = time.time() + print("Integrating VISION and Hotspot results...") + + norm_data_key = adata.uns["norm_data_key"] + signature_varm_key = adata.uns["signature_varm_key"] + + pvals_df, stats_df, FDR_df = compute_sig_mod_enrichment( + adata, norm_data_key, signature_varm_key, use_super_modules + ) + adata.uns["sig_mod_enrichment_stats"] = stats_df + adata.uns["sig_mod_enrichment_pvals"] = pvals_df + adata.uns["sig_mod_enrichment_FDR"] = FDR_df + + if cor_method not in ["pearson", "spearman"]: + raise ValueError( + f'Invalid method: {cor_method}. Choose either "pearson" or "spearman".' + ) + + adata.uns["cor_method"] = cor_method + + cor_coef_df, cor_pval_df, cor_FDR_df = compute_sig_mod_correlation( + adata, cor_method, use_super_modules + ) + adata.uns["sig_mod_correlation_coefs"] = cor_coef_df + adata.uns["sig_mod_correlation_pvals"] = cor_pval_df + adata.uns["sig_mod_correlation_FDR"] = cor_FDR_df + + from scvi.external.harreman.vision.signature import compute_signatures_anndata + + adata.obsm["signature_modules_overlap"] = compute_signatures_anndata( + adata, + norm_data_key, + signature_varm_key="signatures_overlap", + signature_names_uns_key=None, + ) + + print( + "Finished integrating VISION and Hotspot results in %.3f seconds" + % (time.time() - start) + ) + + else: + raise ValueError( + "Please make sure VISION has been run and Hotspot has identified at least one module." + ) + + return + + +def compute_top_scoring_modules( + adata: AnnData, + sd: float | None = 1, + use_super_modules: bool | None = False, +): + """ + Identify the top-scoring module (or super-module) for each cell. + + Parameters + ---------- + adata : AnnData + Must contain a matrix of module or super-module scores in: + - ``obsm['module_scores']`` or + - ``obsm['super_module_scores']`` + sd : float, default 1 + Standard deviation threshold to determine strong module activation. + Only modules with ``zscore > sd`` are considered strictly activated. + use_super_modules : bool, default False + If True, select among super-modules instead of standard modules. + + Returns + ------- + pandas.Series + A Series indexed by cell, containing the name of the top-scoring + module/super-module for each cell. + """ + MODULE_KEY = "super_module_scores" if use_super_modules else "module_scores" + + df = pd.DataFrame( + zscore(adata.obsm[MODULE_KEY], axis=0), + index=adata.obsm[MODULE_KEY].index, + columns=adata.obsm[MODULE_KEY].columns, + ) + + top_scoring_modules = pd.Series(index=df.index) + for mod_id, row in df.iterrows(): + above_threshold_low = row > 0 + above_threshold = row > sd + if above_threshold.sum() == 1: + top_scoring_modules[mod_id] = above_threshold.idxmax() + else: + highest_module = ( + row[above_threshold].idxmax() + if above_threshold.sum() > 1 + else row.idxmax() + if above_threshold_low.sum() > 0 + else np.nan + ) + top_scoring_modules[mod_id] = highest_module + + return top_scoring_modules + + +def calculate_super_module_scores( + adata: AnnData, + super_module_dict: dict = None, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Calculate super-module scores for gene super-modules across cells. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). Required fields in `adata.uns`: + - 'layer_key': name of the layer from which to extract the expression matrix + - 'model': statistical model used to normalize expression (e.g., 'DANB', 'normal') + - 'umi_counts': total UMI counts per cell + - 'gene_modules': dictionary mapping module IDs (as strings) to lists of gene names + super_module_dict: dict + Dictionary containing super-module IDs (integers) as keys and a list of associated modules (as integers) as values + device : torch.device, optional + Device to use for computation (e.g., CUDA or CPU). Defaults to GPU if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + The following results are stored in the `AnnData` object: + - `adata.obsm['super_module_scores']`: (cells x super-modules) DataFrame with per-cell super-module activity scores + - `adata.varm['gene_loadings_sm']`: (genes x super-modules) DataFrame with gene loadings for each super-module + - `adata.uns['gene_modules_sm']`: dictionary mapping super-module names to gene lists + """ + start = time.time() + + gene_modules = adata.uns["gene_modules"] + + reverse_mapping = {value: key for key, values in super_module_dict.items() for value in values} + adata.uns["super_modules"] = adata.uns["modules"].replace(reverse_mapping) + + super_module_dict = {key: values for key, values in super_module_dict.items() if key != -1} + + layer_key = adata.uns["layer_key"] + model = adata.uns["model"] + + use_raw = layer_key == "use_raw" + + umi_counts = adata.uns["umi_counts"] + + if verbose: + print(f"Computing scores for {len(super_module_dict.keys())} super-modules...") + + super_module_scores = {} + gene_loadings_sm = pd.DataFrame(index=adata.var_names) + gene_modules_sm = {} + for sm, modules in tqdm(super_module_dict.items()): + super_module = f"Module {sm}" + modules = [f"Module {str(mod)}" for mod in modules] + super_module_genes = [item for key in modules for item in gene_modules.get(key, [])] + + scores, loadings = compute_scores( + adata[:, super_module_genes], + layer_key, + model, + umi_counts, + device, + ) + + super_module_scores[super_module] = scores + gene_loadings_sm[super_module] = pd.Series(loadings, index=super_module_genes) + gene_modules_sm[super_module] = super_module_genes + + super_module_scores = pd.DataFrame(super_module_scores) + super_module_scores.index = adata.obs_names if not use_raw else adata.raw.obs.index + + adata.varm["gene_loadings_sm"] = gene_loadings_sm + + adata.obsm["super_module_scores"] = super_module_scores + adata.uns["gene_modules_sm"] = gene_modules_sm + + print("Finished computing super-module scores in %.3f seconds" % (time.time() - start)) + + return diff --git a/src/scvi/external/harreman/preprocessing/__init__.py b/src/scvi/external/harreman/preprocessing/__init__.py new file mode 100755 index 0000000000..d8da2e61f0 --- /dev/null +++ b/src/scvi/external/harreman/preprocessing/__init__.py @@ -0,0 +1,2 @@ +from .anndata import read_h5ad, setup_deconv_adata, write_h5ad +from .database import extract_interaction_db diff --git a/src/scvi/external/harreman/preprocessing/anndata.py b/src/scvi/external/harreman/preprocessing/anndata.py new file mode 100755 index 0000000000..72c71b88f2 --- /dev/null +++ b/src/scvi/external/harreman/preprocessing/anndata.py @@ -0,0 +1,592 @@ +import time + +import anndata +import numpy as np +import pandas as pd +from scipy.sparse import issparse + + +def counts_from_anndata(adata, layer_key=None, dense=False): + # 1. Extract matrix + if layer_key is None: + counts = adata.X + elif layer_key == "use_raw": + counts = adata.raw.X + else: + counts = adata.layers[layer_key] + + # 2. Transpose efficiently + if issparse(counts): + counts = counts.transpose().tocsr(copy=False) # keep CSR format for efficient row slicing + else: + counts = counts.T # transpose numpy array directly + + # 3. Convert to dense if requested + if dense: + if issparse(counts): + counts = counts.toarray() + else: + counts = np.asarray(counts) + + return counts + + +def setup_anndata( + input_adata: anndata.AnnData, + cell_types: list, + compute_neighbors_on_key: str, + cell_type_key: str, + database_varm_key: str, + sample_key: str | None, + spot_diameter: int, +) -> anndata.AnnData: + + barcode_key = "barcodes" + database = input_adata.varm[database_varm_key] + obs_names = input_adata.obs_names + var_names = input_adata.var_names + obsm_neighbors = input_adata.obsm[compute_neighbors_on_key] + + adatas = [] + for ct in cell_types: + if ct not in input_adata.layers: + continue + X = input_adata.layers[ct] + obs = {barcode_key: obs_names.values, cell_type_key: [ct] * len(obs_names)} + if sample_key is not None: + obs[sample_key] = input_adata.obs[sample_key].values + + adata = anndata.AnnData( + X=X, + obs=obs, + var=pd.DataFrame(index=var_names), + obsm={compute_neighbors_on_key: obsm_neighbors}, + ) + adata.obs_names = [f"{name}_{ct}" for name in obs_names] + adatas.append(adata) + + out_adata = anndata.concat(adatas, axis=0) + + out_adata.uns.update( + { + "database_varm_key": database_varm_key, + "spot_diameter": spot_diameter, + "barcode_key": barcode_key, + } + ) + out_adata.varm[database_varm_key] = database + + # Remove empty rows efficiently + nonzero_mask = ( + out_adata.X.sum(axis=1).A1 > 0 + if hasattr(out_adata.X, "A1") + else out_adata.X.sum(axis=1) > 0 + ) + out_adata._inplace_subset_obs(nonzero_mask) + + return out_adata + + +def setup_deconv_adata( + adata: anndata.AnnData, + compute_neighbors_on_key: str | None = None, + sample_key: str | None = None, + cell_type_list: list | None = None, + cell_type_key: str | None = None, + spot_diameter: int | None = None, + verbose: bool | None = False, +): + """Set up deconvolution AnnData. + + Parameters + ---------- + adata + AnnData object. + compute_neighbors_on_key + Key in `adata.obsm` to use for computing neighbors. If `None`, use neighbors stored in `adata`. If no neighbors have been previously computed an error will be raised. + sample_key + Sample information in case the data contains different samples or samples from different conditions. Input is key in `adata.obs`. + cell_type_list + Cell type or cluster information for the cell-cell communication analysis. Input is a list of keys in `adata.layers`. + cell_type_key + Cell type or cluster information for the cell-cell communication analysis. Input is key in `adata.obs`. + spot_diameter + Spot diameter of the spatial technology the dataset belongs to. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + """ + start = time.time() + if verbose: + print("Setting up deconvolution data...") + + uns = adata.uns + deconv_adata = setup_anndata( + input_adata=adata, + cell_types=cell_type_list, + compute_neighbors_on_key=compute_neighbors_on_key, + cell_type_key=cell_type_key, + database_varm_key=uns["database_varm_key"], + sample_key=sample_key, + spot_diameter=spot_diameter, + ) + + deconv_adata.uns.update( + { + "cell_type_key": cell_type_key, + "layer_key": None, + "deconv_data": True, + "database": uns["database"], + } + ) + + if uns["database"] in {"LR", "both"}: + deconv_adata.uns.update( + { + "ligand": uns["ligand"], + "receptor": uns["receptor"], + "LR_database": uns["LR_database"], + } + ) + + if uns["database"] in {"transporter", "both"}: + deconv_adata.uns.update( + { + "importer": uns["importer"], + "exporter": uns["exporter"], + "import_export": uns["import_export"], + "num_metabolites": uns["num_metabolites"], + "metabolite_database": uns["metabolite_database"], + } + ) + + if verbose: + print("Finished setting up deconvolution data in %.3f seconds" % (time.time() - start)) + + return adata, deconv_adata + + +def modify_uns_hotspot(adata): + if "modules" in adata.uns.keys(): + adata.var["modules"] = adata.uns["modules"] + del adata.uns["modules"] + + if "super_modules" in adata.uns.keys(): + adata.var["super_modules"] = adata.uns["super_modules"] + del adata.uns["super_modules"] + + if "lc_zs" in adata.uns.keys(): + genes = [ + " - ".join(gene) if isinstance(gene, tuple) else gene + for gene in adata.uns["lc_zs"].columns + ] + adata.uns["lc_zs"].index = genes + adata.uns["lc_zs"].columns = genes + + return + + +def modify_uns_harreman(adata): + uns_keys = ["ligand", "receptor", "LR_database", "import_export"] + for uns_key in uns_keys: + if uns_key in adata.uns.keys(): + adata.uns[uns_key] = adata.uns[uns_key].fillna("NA") + + if "LR_database" in adata.uns.keys(): + adata.uns["LR_database"].columns = [ + col.replace(".", "_") for col in adata.uns["LR_database"].columns + ] + adata.uns["LR_database"]["ligand_transmembrane"] = adata.uns["LR_database"][ + "ligand_transmembrane" + ].astype(str) + adata.uns["LR_database"]["receptor_transmembrane"] = adata.uns["LR_database"][ + "receptor_transmembrane" + ].astype(str) + + if "gene_pairs" in adata.uns.keys(): + gene_pairs_tmp = [ + (x, " - ".join(y) if isinstance(y, (list, tuple)) else y) + for x, y in adata.uns["gene_pairs"] + ] + gene_pairs_tmp = [ + (" - ".join(x) if isinstance(x, (list, tuple)) else x, y) for x, y in gene_pairs_tmp + ] + adata.uns["gene_pairs"] = ["_".join(gp) for gp in gene_pairs_tmp] + + if "gene_pairs_sig" in adata.uns.keys(): + gene_pairs_sig_tmp = [ + (x, " - ".join(y) if isinstance(y, (list, tuple)) else y) + for x, y in adata.uns["gene_pairs_sig"] + ] + gene_pairs_sig_tmp = [ + (" - ".join(x) if isinstance(x, (list, tuple)) else x, y) + for x, y in gene_pairs_sig_tmp + ] + adata.uns["gene_pairs_sig"] = ["_".join(gp) for gp in gene_pairs_sig_tmp] + + if "gene_pairs_ind" in adata.uns.keys(): + gene_pairs_ind_tmp = [ + (x, " - ".join(map(str, y)) if isinstance(y, (list, tuple)) else str(y)) + for x, y in adata.uns["gene_pairs_ind"] + ] + gene_pairs_ind_tmp = [ + (" - ".join(map(str, x)) if isinstance(x, (list, tuple)) else str(x), y) + for x, y in gene_pairs_ind_tmp + ] + adata.uns["gene_pairs_ind"] = ["_".join(gp) for gp in gene_pairs_ind_tmp] + + if "gene_pairs_sig_ind" in adata.uns.keys(): + gene_pairs_sig_ind_tmp = [ + (x, " - ".join(map(str, y)) if isinstance(y, (list, tuple)) else str(y)) + for x, y in adata.uns["gene_pairs_sig_ind"] + ] + gene_pairs_sig_ind_tmp = [ + (" - ".join(map(str, x)) if isinstance(x, (list, tuple)) else str(x), y) + for x, y in gene_pairs_sig_ind_tmp + ] + adata.uns["gene_pairs_sig_ind"] = ["_".join(gp) for gp in gene_pairs_sig_ind_tmp] + + if "gene_pairs_per_metabolite" in adata.uns.keys(): + adata.uns["gene_pairs_per_metabolite"] = { + key: { + "gene_pair": [ + (" - ".join(gp_1) if isinstance(gp_1, (list, tuple)) else gp_1, gp_2) + for gp_1, gp_2 in subdict["gene_pair"] + ], + "gene_type": subdict["gene_type"], + } + for key, subdict in adata.uns["gene_pairs_per_metabolite"].items() + } + adata.uns["gene_pairs_per_metabolite"] = { + key: { + "gene_pair": [ + (gp_1, " - ".join(gp_2) if isinstance(gp_2, (list, tuple)) else gp_2) + for gp_1, gp_2 in subdict["gene_pair"] + ], + "gene_type": subdict["gene_type"], + } + for key, subdict in adata.uns["gene_pairs_per_metabolite"].items() + } + + if "gene_pairs_per_ct_pair" in adata.uns.keys(): + adata.uns["gene_pairs_per_ct_pair"] = { + key: [ + (x, " - ".join(y) if isinstance(y, (list, tuple)) else y) for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_per_ct_pair"].items() + } + adata.uns["gene_pairs_per_ct_pair"] = { + key: [ + (" - ".join(x) if isinstance(x, (list, tuple)) else x, y) for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_per_ct_pair"].items() + } + adata.uns["gene_pairs_per_ct_pair"] = { + " - ".join(key): value for key, value in adata.uns["gene_pairs_per_ct_pair"].items() + } + + if "gene_pairs_per_ct_pair_ind" in adata.uns.keys(): + adata.uns["gene_pairs_per_ct_pair_ind"] = { + " - ".join(key): value + for key, value in adata.uns["gene_pairs_per_ct_pair_ind"].items() + } + + if "gene_pairs_per_ct_pair_sig_ind" in adata.uns.keys(): + adata.uns["gene_pairs_per_ct_pair_sig_ind"] = { + " - ".join(key): value + for key, value in adata.uns["gene_pairs_per_ct_pair_sig_ind"].items() + } + + if "gene_pairs_ind_per_ct_pair" in adata.uns.keys(): + adata.uns["gene_pairs_ind_per_ct_pair"] = { + key: [ + (x, " - ".join(map(str, y)) if isinstance(y, (list, tuple)) else y) + for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_ind_per_ct_pair"].items() + } + adata.uns["gene_pairs_ind_per_ct_pair"] = { + key: [ + (" - ".join(map(str, x)) if isinstance(x, (list, tuple)) else x, y) + for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_ind_per_ct_pair"].items() + } + adata.uns["gene_pairs_ind_per_ct_pair"] = { + " - ".join(key): value + for key, value in adata.uns["gene_pairs_ind_per_ct_pair"].items() + } + + if "gene_pairs_ind_per_ct_pair_sig" in adata.uns.keys(): + adata.uns["gene_pairs_ind_per_ct_pair_sig"] = { + key: [ + (x, " - ".join(map(str, y)) if isinstance(y, (list, tuple)) else y) + for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_ind_per_ct_pair_sig"].items() + } + adata.uns["gene_pairs_ind_per_ct_pair_sig"] = { + key: [ + (" - ".join(map(str, x)) if isinstance(x, (list, tuple)) else x, y) + for x, y in tuples_list + ] + for key, tuples_list in adata.uns["gene_pairs_ind_per_ct_pair_sig"].items() + } + adata.uns["gene_pairs_ind_per_ct_pair_sig"] = { + " - ".join(key): value + for key, value in adata.uns["gene_pairs_ind_per_ct_pair_sig"].items() + } + + if "ccc_results" in adata.uns.keys(): + adata.uns["ccc_results"]["cell_com_df_gp"] = adata.uns["ccc_results"][ + "cell_com_df_gp" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + adata.uns["ccc_results"]["cell_com_df_m"] = adata.uns["ccc_results"][ + "cell_com_df_m" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + if "cell_com_df_gp_sig" in adata.uns["ccc_results"].keys(): + adata.uns["ccc_results"]["cell_com_df_gp_sig"] = adata.uns["ccc_results"][ + "cell_com_df_gp_sig" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + adata.uns["ccc_results"]["cell_com_df_m_sig"] = adata.uns["ccc_results"][ + "cell_com_df_m_sig" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + + if "ct_ccc_results" in adata.uns.keys(): + adata.uns["ct_ccc_results"]["cell_com_df_gp"] = adata.uns["ct_ccc_results"][ + "cell_com_df_gp" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + adata.uns["ct_ccc_results"]["cell_com_df_m"] = adata.uns["ct_ccc_results"][ + "cell_com_df_m" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + if "cell_com_df_gp_sig" in adata.uns["ct_ccc_results"].keys(): + adata.uns["ct_ccc_results"]["cell_com_df_gp_sig"] = adata.uns["ct_ccc_results"][ + "cell_com_df_gp_sig" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + adata.uns["ct_ccc_results"]["cell_com_df_m_sig"] = adata.uns["ct_ccc_results"][ + "cell_com_df_m_sig" + ].applymap(lambda x: " - ".join(x) if isinstance(x, (list, tuple)) else x) + + return + + +def write_h5ad( + adata: anndata.AnnData, + filename: str | None = None, +): + """ + Save an AnnData object to disk with Harreman-compatible preprocessing. + + Parameters + ---------- + adata : AnnData + The AnnData object to save. + filename : str, optional + Path to the output `.h5ad` file. If omitted, an error is raised. + + Notes + ----- + The wrapper ensures that custom Harreman/Hotspot fields remain fully restorable + when loading the file with `read_h5ad()`. + """ + if filename is None: + raise ValueError("Please provide the path to save the file.") + elif not filename.endswith("h5ad"): + filename = filename + ".h5ad" + + if "distances" in adata.obsp.keys(): + adata.obsp["distances"] = adata.obsp["distances"].tocsr() + + modify_uns_hotspot(adata) + modify_uns_harreman(adata) + adata.write(filename) + + +def recover_uns_hotspot(adata): + if "modules" not in adata.uns and "modules" in adata.var.columns: + adata.uns["modules"] = adata.var["modules"].dropna().astype(int).copy() + del adata.var["modules"] + + if "super_modules" not in adata.uns and "super_modules" in adata.var.columns: + adata.uns["super_modules"] = adata.var["super_modules"].dropna().astype(int).copy() + del adata.var["super_modules"] + + if "lc_zs" in adata.uns: + adata.uns["lc_zs"].index = [ + tuple(g.split(" - ")) if " - " in g else g for g in adata.uns["lc_zs"].index + ] + adata.uns["lc_zs"].columns = [ + tuple(g.split(" - ")) if " - " in g else g for g in adata.uns["lc_zs"].columns + ] + + +def recover_uns_harreman(adata): + uns_keys = ["ligand", "receptor", "LR_database", "import_export"] + for uns_key in uns_keys: + if uns_key in adata.uns.keys(): + adata.uns[uns_key] = adata.uns[uns_key].replace("NA", np.nan) + + if "LR_database" in adata.uns: + original_columns = [ + "interaction_name", + "pathway_name", + "agonist", + "antagonist", + "co_A_receptor", + "co_I_receptor", + "evidence", + "annotation", + "interaction_name_2", + "is_neurotransmitter", + "ligand.symbol", + "ligand.family", + "ligand.location", + "ligand.keyword", + "ligand.secreted_type", + "ligand.transmembrane", + "receptor.symbol", + "receptor.family", + "receptor.location", + "receptor.keyword", + "receptor.surfaceome_main", + "receptor.surfaceome_sub", + "receptor.adhesome", + "receptor.secreted_type", + "receptor.transmembrane", + "version", + ] + mod_columns = [col.replace(".", "_") for col in original_columns] + adata.uns["LR_database"][mod_columns].columns = original_columns + + if "gene_pairs" in adata.uns: + gene_pairs_tmp = [tuple(gp.split("_")) for gp in adata.uns["gene_pairs"]] + gene_pairs_tmp = [ + (x, list(y.split(" - ")) if " - " in y else y) for x, y in gene_pairs_tmp + ] + adata.uns["gene_pairs"] = [ + (list(x.split(" - ")) if " - " in x else x, y) for x, y in gene_pairs_tmp + ] + + if "gene_pairs_sig" in adata.uns: + gene_pairs_sig_tmp = [tuple(gp.split("_")) for gp in adata.uns["gene_pairs_sig"]] + gene_pairs_sig_tmp = [ + (x, list(y.split(" - ")) if " - " in y else y) for x, y in gene_pairs_sig_tmp + ] + adata.uns["gene_pairs_sig"] = [ + (list(x.split(" - ")) if " - " in x else x, y) for x, y in gene_pairs_sig_tmp + ] + + if "gene_pairs_ind" in adata.uns: + gene_pairs_ind_tmp = [tuple(gp.split("_")) for gp in adata.uns["gene_pairs_ind"]] + gene_pairs_ind_tmp = [ + (x, list(int(val) for val in y.split(" - ")) if " - " in y else int(y)) + for x, y in gene_pairs_ind_tmp + ] + adata.uns["gene_pairs_ind"] = [ + (list(int(val) for val in x.split(" - ")) if " - " in x else int(x), y) + for x, y in gene_pairs_ind_tmp + ] + + if "gene_pairs_sig_ind" in adata.uns: + gene_pairs_sig_ind_tmp = [tuple(gp.split("_")) for gp in adata.uns["gene_pairs_sig_ind"]] + gene_pairs_sig_ind_tmp = [ + (x, list(int(val) for val in y.split(" - ")) if " - " in y else int(y)) + for x, y in gene_pairs_sig_ind_tmp + ] + adata.uns["gene_pairs_sig_ind"] = [ + (list(int(val) for val in x.split(" - ")) if " - " in x else int(x), y) + for x, y in gene_pairs_sig_ind_tmp + ] + + def recover_tuple_or_list(g): + return g.split(" - ") if " - " in g else g + + if "gene_pairs_per_metabolite" in adata.uns: + for key, subdict in adata.uns["gene_pairs_per_metabolite"].items(): + subdict["gene_pair"] = [ + (recover_tuple_or_list(gp1), recover_tuple_or_list(gp2)) + for gp1, gp2 in subdict["gene_pair"] + ] + + if "gene_pairs_per_ct_pair" in adata.uns: + adata.uns["gene_pairs_per_ct_pair"] = { + tuple(k.split(" - ")): [ + (recover_tuple_or_list(pair[0]), recover_tuple_or_list(pair[1])) for pair in v + ] + for k, v in adata.uns["gene_pairs_per_ct_pair"].items() + } + + if "gene_pairs_per_ct_pair_ind" in adata.uns: + adata.uns["gene_pairs_per_ct_pair_ind"] = { + tuple(k.split(" - ")): v for k, v in adata.uns["gene_pairs_per_ct_pair_ind"].items() + } + + if "gene_pairs_per_ct_pair_sig_ind" in adata.uns: + adata.uns["gene_pairs_per_ct_pair_sig_ind"] = { + tuple(k.split(" - ")): v + for k, v in adata.uns["gene_pairs_per_ct_pair_sig_ind"].items() + } + + if "gene_pairs_ind_per_ct_pair" in adata.uns: + adata.uns["gene_pairs_ind_per_ct_pair"] = { + tuple(k.split(" - ")): [ + (recover_tuple_or_list(str(pair[0])), recover_tuple_or_list(str(pair[1]))) + for pair in v + ] + for k, v in adata.uns["gene_pairs_ind_per_ct_pair"].items() + } + + if "gene_pairs_ind_per_ct_pair_sig" in adata.uns: + adata.uns["gene_pairs_ind_per_ct_pair_sig"] = { + tuple(k.split(" - ")): [ + (recover_tuple_or_list(str(pair[0])), recover_tuple_or_list(str(pair[1]))) + for pair in v + ] + for k, v in adata.uns["gene_pairs_ind_per_ct_pair_sig"].items() + } + + def restore_list(x): + if isinstance(x, str) and " - " in x: + return x.split(" - ") + return x + + if "ccc_results" in adata.uns: + for key in ["cell_com_df_gp", "cell_com_df_m", "cell_com_df_gp_sig", "cell_com_df_m_sig"]: + if key in adata.uns["ccc_results"]: + df = adata.uns["ccc_results"][key] + adata.uns["ccc_results"][key] = df.applymap(restore_list) + + if "ct_ccc_results" in adata.uns: + for key in ["cell_com_df_gp", "cell_com_df_m", "cell_com_df_gp_sig", "cell_com_df_m_sig"]: + if key in adata.uns["ct_ccc_results"]: + df = adata.uns["ct_ccc_results"][key] + adata.uns["ct_ccc_results"][key] = df.applymap(restore_list) + + return + + +def read_h5ad( + filename: str, +): + """ + Load an AnnData object from disk and restore Harreman-specific metadata. + + Parameters + ---------- + filename : str + Path to the `.h5ad` file to load. + + Returns + ------- + AnnData + The fully restored AnnData object with Harreman metadata recovered. + """ + adata = anndata.read_h5ad(filename) + + if "genes" in adata.uns: + adata.uns["genes"] = list(adata.uns["genes"]) + + recover_uns_hotspot(adata) + recover_uns_harreman(adata) + + return adata diff --git a/src/scvi/external/harreman/preprocessing/database.py b/src/scvi/external/harreman/preprocessing/database.py new file mode 100755 index 0000000000..29258bd305 --- /dev/null +++ b/src/scvi/external/harreman/preprocessing/database.py @@ -0,0 +1,209 @@ +import time +import pooch +from itertools import zip_longest +from re import compile, match +from typing import Literal + +import pandas as pd +from anndata import AnnData + +IMPORT_METAB_KEY = "IMPORT" +EXPORT_METAB_KEY = "EXPORT" +BOTH_METAB_KEY = "IMPORT_EXPORT" + + +def extract_interaction_db( + adata: AnnData, + use_raw: bool = False, + species: Literal["mouse", "human"] | None = None, + database: Literal["transporter", "LR", "both"] | None = None, + extracellular_only: bool | None = True, + verbose: bool | None = False, +) -> None: + """Extract the metabolite transporter or ligand-receptor (LR) database from .csv files. + + Parameters + ---------- + adata + AnnData object to compute database for. + use_raw + Whether to use adata.raw.X for database computation. + species + Species identity to select the LR database from CellChatDB. + database + Whether to use the transporter database, the LR database, or both. + extracellular_only + Whether to restrict the communication inference to extracellular metabolites. + verbose + Whether to print progress and status messages. + + Returns + ------- + Genes by metabolites (or LRs) dataframe. Index is aligned to genes from adata. + + """ + start = time.time() + if verbose: + print("Extracting interaction database...") + + if species not in {"mouse", "human"}: + raise ValueError(f"Unsupported species: {species}. Choose 'mouse' or 'human'.") + if database not in {"transporter", "LR", "both"}: + raise ValueError("Choose one of: 'transporter', 'LR', or 'both'.") + + adata.uns["species"] = species + index = adata.raw.var.index if use_raw else adata.var_names + df_list = [] + + if database in ["LR", "both"]: + extract_lr_pairs(adata, species) + lr_data = build_LR_matrix( + index, adata.uns["LR_database"], adata.uns["ligand"], adata.uns["receptor"] + ) + df_list.append(lr_data) + + if database in ["transporter", "both"]: + metab_dict = extract_transporter_info(adata, species, extracellular_only) + metab_data = build_transporter_matrix(index, metab_dict) + df_list.append(metab_data) + + database_df = pd.concat(df_list, axis=1).fillna(0) + + adata.uns["database_varm_key"] = "database" + adata.uns["database"] = database + adata.varm["database"] = database_df + + if verbose: + print("Finished extracting interaction database in %.3f seconds" % (time.time() - start)) + + return + + +def build_LR_matrix(index, database, ligands, receptors): + matrix = pd.DataFrame(0, index=index, columns=database.index) + matrix.index = matrix.index.str.lower() + for col in matrix.columns: + for key, df, sign in [("ligand", ligands, 1), ("receptor", receptors, -1)]: + genes = df.loc[col].dropna().astype(str).str.lower() + genes = genes[genes.isin(matrix.index)] + matrix.loc[genes, col] = sign + matrix.index = index + return matrix.loc[:, matrix.any()] + + +def build_transporter_matrix(index, metab_dict): + matrix = pd.DataFrame(0, index=index, columns=metab_dict.keys()) + matrix.index = matrix.index.str.lower() + for metab, gene_dirs in metab_dict.items(): + for direction, sign in [ + (IMPORT_METAB_KEY, -1), + (EXPORT_METAB_KEY, 1), + (BOTH_METAB_KEY, 2), + ]: + genes = pd.Index(gene_dirs.get(direction, [])).str.lower() + genes = genes.intersection(matrix.index) + matrix.loc[genes, metab] = sign + matrix.index = index + return matrix.loc[:, matrix.any()] + + +def extract_transporter_info( + adata: AnnData, + species: Literal["mouse", "human"], + extracellular_only: bool = True, + export_suffix: str = "(_exp|_export)", + import_suffix: str = "(_imp|_import)", + verbose: bool = False, +) -> dict[str, dict[str, list[str]]]: + """Read csv file to extract the metabolite database.""" + S3_BASE = "https://scverse-public-data.s3.eu-central-1.amazonaws.com/scvi-tools/harreman/HarremanDB" + cache = pooch.os_cache("scvi_harreman") + + filenames = { + "extracellular": pooch.retrieve(url=f"{S3_BASE}/HarremanDB_{species}_extracellular.csv", known_hash=None, fname=f"HarremanDB_{species}_extracellular.csv", path=cache, progressbar=False), + "all": pooch.retrieve(url=f"{S3_BASE}/HarremanDB_{species}.csv", known_hash=None, fname=f"HarremanDB_{species}.csv", path=cache, progressbar=False), + "heterodimer": pooch.retrieve(url=f"{S3_BASE}/Heterodimer_info_{species}.csv", known_hash=None, fname=f"Heterodimer_info_{species}.csv", path=cache, progressbar=False), + } + df = pd.read_csv(filenames["extracellular" if extracellular_only else "all"], index_col=0) + heterodimer_info = pd.read_csv(filenames["heterodimer"], index_col=0) + df["Metabolite"] = df["Metabolite"].str.replace("/", "_", regex=False) + + pattern_import = compile(r"(\S+)" + import_suffix + "$") + pattern_export = compile(r"(\S+)" + export_suffix + "$") + + metab_dict = {} + for _, row in df.iterrows(): + metabolite, genes_str = row["Metabolite"], row["Gene"] + if not genes_str: + continue + genes = [g.strip() for g in genes_str.split("/") if g.strip()] + genes = sorted(set(genes)) + m = metabolite.lower() + match_import = match(pattern_import, m) + match_export = match(pattern_export, m) + if match_import: + name, direction = match_import.group(1), IMPORT_METAB_KEY + elif match_export: + name, direction = match_export.group(1), EXPORT_METAB_KEY + else: + name, direction = metabolite, BOTH_METAB_KEY + metab_dict.setdefault( + name, {IMPORT_METAB_KEY: [], EXPORT_METAB_KEY: [], BOTH_METAB_KEY: []} + ) + genes_in_var = pd.Series(genes).isin(adata.var_names) + metab_dict[name][direction] = pd.Series(genes)[genes_in_var].tolist() + + adata.uns["metabolite_database"] = df + adata.uns["heterodimer_info"] = heterodimer_info + adata.uns["num_metabolites"] = df.shape[0] + adata.uns["importer"] = build_metabolite_df(metab_dict, IMPORT_METAB_KEY) + adata.uns["exporter"] = build_metabolite_df(metab_dict, EXPORT_METAB_KEY) + adata.uns["import_export"] = build_metabolite_df(metab_dict, BOTH_METAB_KEY) + + return metab_dict + + +def build_metabolite_df(metab_dict, key): + df = pd.DataFrame.from_dict({k: v[key] for k, v in metab_dict.items()}, orient="index").T + df.columns = [f"{key}{i}" for i in range(df.shape[1])] + return df + + +def extract_lr_pairs(adata, species): + """Extracting LR pairs from CellChatDB.""" + S3_BASE = "https://scverse-public-data.s3.eu-central-1.amazonaws.com/scvi-tools/harreman/CellChatDB" + cache = pooch.os_cache("scvi_harreman") + + interaction_path = pooch.retrieve(url=f"{S3_BASE}/interaction_input_CellChatDB_v2_{species}.csv", known_hash=None, fname=f"interaction_input_CellChatDB_v2_{species}.csv", path=cache, progressbar=False) + complex_path = pooch.retrieve(url=f"{S3_BASE}/complex_input_CellChatDB_v2_{species}.csv", known_hash=None, fname=f"complex_input_CellChatDB_v2_{species}.csv", path=cache, progressbar=False) + + interaction = pd.read_csv(interaction_path, index_col=0).sort_values("annotation") + complex = pd.read_csv(complex_path, index_col=0) + + ligands, receptors = interaction.pop("ligand").values, interaction.pop("receptor").values + + for i in range(len(ligands)): + for n in [ligands, receptors]: + l = n[i] + if l in complex.index: + n[i] = ( + complex.loc[l] + .dropna() + .values[pd.Series(complex.loc[l].dropna().values).isin(adata.var_names)] + ) + else: + n[i] = pd.Series(l).values[pd.Series(l).isin(adata.var_names)] + + lig_df = pd.DataFrame.from_records(zip_longest(*pd.Series(ligands).values)).transpose() + rec_df = pd.DataFrame.from_records(zip_longest(*pd.Series(receptors).values)).transpose() + + lig_df.columns = [f"Ligand{i}" for i in range(lig_df.shape[1])] + rec_df.columns = [f"Receptor{i}" for i in range(rec_df.shape[1])] + + lig_df.index = rec_df.index = interaction.index + + adata.uns["ligand"] = lig_df + adata.uns["receptor"] = rec_df + adata.uns["LR_database"] = interaction + + return diff --git a/src/scvi/external/harreman/tools/__init__.py b/src/scvi/external/harreman/tools/__init__.py new file mode 100755 index 0000000000..5ec411c238 --- /dev/null +++ b/src/scvi/external/harreman/tools/__init__.py @@ -0,0 +1,11 @@ +from .cell_communication import ( + apply_gene_filtering, + compute_cell_communication, + compute_ct_cell_communication, + compute_ct_interacting_cell_scores, + compute_gene_pairs, + compute_interacting_cell_scores, + compute_interaction_module_correlation, + select_significant_interactions, +) +from .knn import compute_knn_graph diff --git a/src/scvi/external/harreman/tools/cell_communication.py b/src/scvi/external/harreman/tools/cell_communication.py new file mode 100755 index 0000000000..6aa5f19948 --- /dev/null +++ b/src/scvi/external/harreman/tools/cell_communication.py @@ -0,0 +1,4453 @@ +import ast +import itertools +import time +import warnings +from collections import defaultdict +from functools import partial +from typing import Literal + +import numpy as np +import pandas as pd +import sparse +import torch +from anndata import AnnData +from numba import jit, njit +from scipy.stats import mannwhitneyu, norm, pearsonr, spearmanr +from statsmodels.stats.multitest import multipletests +from tqdm import tqdm + +from ..hotspot import models + + +def _lazy_import_hotspot(): + """Resolve circular imports lazily.""" + global compute_local_autocorrelation, standardize_counts + from ..hotspot.local_autocorrelation import ( + compute_local_autocorrelation as _cla, + ) + from ..hotspot.local_autocorrelation import ( + standardize_counts as _sc, + ) + + compute_local_autocorrelation = _cla + standardize_counts = _sc + + +compute_local_autocorrelation = None +standardize_counts = None + +from ..preprocessing.anndata import counts_from_anndata +from ..tools.knn import make_weights_non_redundant + + +def apply_gene_filtering( + adata: AnnData, + layer_key: Literal["use_raw"] | str | None = None, + cell_type_key: str | None = None, + model: str | None = None, + feature_elimination: bool | None = False, + threshold: float | None = 0.2, + autocorrelation_filt: bool | None = False, + expression_filt: bool | None = False, + de_filt: bool | None = False, + umi_counts_obs_key: str | None = None, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Applies multi-step gene filtering to an AnnData object. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). + layer_key : str, optional + Key to use from `adata.layers` or `"use_raw"` to use `adata.raw.X`. + cell_type_key : str, optional + Key in `adata.obs` containing cell type annotations. + model : str, optional + Model name for autocorrelation computation. + feature_elimination : bool, optional (default: False) + If True, filters genes based on sparsity across all cells. + threshold : float, optional (default: 0.2) + Minimum fraction of cells in which the gene must be expressed. + autocorrelation_filt : str, optional (default: False) + If True, filters genes based on spatial autocorrelation significance. + expression_filt : str, optional (default: False) + If True, filters genes based on expression in each cell type. + de_filt : str, optional (default: False) + If True, filters genes based on differential expression between each cell type and the rest. + umi_counts_obs_key : str, optional + Key in `adata.obs` with total UMI counts per cell. If `None`, inferred from the expression matrix. + device : torch.device, optional + Device to use for computation (e.g., CUDA or CPU). Defaults to GPU if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + """ + start = time.time() + if verbose: + print("Applying gene filtering...") + + adata.uns["autocorrelation_filt"] = autocorrelation_filt + adata.uns["expression_filt"] = expression_filt + adata.uns["de_filt"] = de_filt + + db_key = adata.uns["database_varm_key"] + + if feature_elimination: + perform_feature_elimination(adata, layer_key, db_key, threshold) + + _lazy_import_hotspot() + if autocorrelation_filt: + compute_local_autocorrelation( + adata=adata, + layer_key=layer_key, + database_varm_key=db_key, + model=model, + umi_counts_obs_key=umi_counts_obs_key, + device=device, + verbose=verbose, + ) + + if expression_filt or de_filt: + if cell_type_key is None: + cell_type_key = adata.uns.get("cell_type_key") + if cell_type_key is None: + raise ValueError('The "cell_type_key" argument needs to be provided.') + + filtered_genes, filtered_genes_ct = filter_genes( + adata, layer_key, db_key, cell_type_key, expression_filt, de_filt, autocorrelation_filt + ) + adata.uns["filtered_genes"] = filtered_genes + adata.uns["filtered_genes_ct"] = filtered_genes_ct + + if verbose: + print("Finished applying gene filtering in %.3f seconds" % (time.time() - start)) + + return + + +def perform_feature_elimination(adata, layer_key, database_varm_key, threshold): + """ + Filters out genes that are too sparse across all cells. + + Parameters + ---------- + adata + Annotated data object (AnnData). + layer_key + Which layer to use (or "use_raw"). + database_varm_key + Key in `adata.varm` pointing to relevant features to filter. + threshold + Minimum fraction of cells in which the gene must be expressed. + """ + use_raw = layer_key == "use_raw" + + metab_matrix = adata.raw.varm[database_varm_key] if use_raw else adata.varm[database_varm_key] + genes = metab_matrix.loc[(metab_matrix != 0).any(axis=1)].index + + counts = counts_from_anndata(adata[:, genes], layer_key, dense=True) + + valid_genes = genes[filter_expr_matrix(counts, threshold=threshold)] + + adata.varm[database_varm_key][~adata.var_names.isin(valid_genes)] = 0 + + return + + +def filter_genes( + adata, + layer_key, + database_varm_key, + cell_type_key, + expression_filt, + de_filt, + autocorrelation_filt, +): + """ + Applies expression and/or DE filtering per cell type. + + Parameters + ---------- + adata + Annotated data object (AnnData). + layer_key + Which layer to use (or "use_raw"). + database_varm_key + Key in `adata.varm` pointing to gene database features. + cell_type_key + Key in `adata.obs` with cell type labels. + expression_filt + Whether to filter based on expression sparsity in each cell type. + de_filt + Whether to filter based on differential expression. + autocorrelation_filt + Whether to restrict to spatially autocorrelated genes. + + Returns + ------- + filtered_genes + List of genes retained across any cell type. + filtered_genes_ct + Dict mapping cell types to their filtered genes. + """ + if autocorrelation_filt: + autocor_results = adata.uns["gene_autocorrelation_results"] + sig_genes = autocor_results.query("Z_FDR < 0.05").index + if len(sig_genes) == 0: + raise ValueError("There are no significantly autocorrelated genes.") + + else: + use_raw = layer_key == "use_raw" + db = adata.raw.varm[database_varm_key] if use_raw else adata.varm[database_varm_key] + sig_genes = db.loc[(db != 0).any(axis=1)].index + + counts = counts_from_anndata(adata[:, sig_genes], layer_key, dense=True) + + cell_types = adata.obs[cell_type_key].values + unique_cts = np.unique(cell_types) + filtered_genes_ct = {} + gene_idx = {g: i for i, g in enumerate(sig_genes)} + + # Precompute masks + masks = {ct: np.where(cell_types == ct)[0] for ct in unique_cts} + not_masks = {ct: np.where(cell_types != ct)[0] for ct in unique_cts} + + if expression_filt: + expr_mask = {ct: filter_expr_matrix(counts[:, masks[ct]], 0.2) for ct in unique_cts} + if de_filt: + de_stats = { + ct: de_threshold(counts[:, masks[ct]], counts[:, not_masks[ct]]) for ct in unique_cts + } + + filtered_genes = set() + for ct in unique_cts: + genes_ct = sig_genes + gene_mask = np.ones(len(sig_genes), dtype=bool) + + if expression_filt: + gene_mask &= expr_mask[ct] + + if de_filt: + stat, pval, cd = de_stats[ct] + fdr = multipletests(pval, method="fdr_bh")[1] + gene_mask &= (fdr < 0.05) & (cd > 0) + + selected = sig_genes[gene_mask] + filtered_genes_ct[ct] = selected.tolist() + filtered_genes.update(selected) + + return sorted(filtered_genes), filtered_genes_ct + + +def filter_expr_matrix(matrix, threshold): + + return (matrix > 0).sum(axis=1) / matrix.shape[1] >= threshold + + +@njit(parallel=True) +def cohens_d(x, y): + + out = np.empty(x.shape[0]) + + for i in range(x.shape[0]): + nx, ny = len(x[i]), len(y[i]) + vx, vy = np.var(x[i], ddof=1), np.var(y[i], ddof=1) + pooled = np.sqrt(((nx - 1) * vx + (ny - 1) * vy) / (nx + ny - 2)) + out[i] = (np.mean(x[i]) - np.mean(y[i])) / pooled if pooled > 0 else 0 + + return out + + +def de_threshold(counts_ct, counts_no_ct): + + stat = np.array( + [ + mannwhitneyu(counts_ct[i], counts_no_ct[i], alternative="greater").statistic + for i in range(counts_ct.shape[0]) + ] + ) + pval = np.array( + [ + mannwhitneyu(counts_ct[i], counts_no_ct[i], alternative="greater").pvalue + for i in range(counts_ct.shape[0]) + ] + ) + cd = cohens_d(counts_ct, counts_no_ct) + + return stat, pval, cd + + +def compute_gene_pairs( + adata: AnnData, + layer_key: Literal["use_raw"] | str | None = None, + cell_type_key: str | None = None, + cell_type_pairs: list | None = None, + ct_specific: bool | None = True, + fix_ct: Literal["all"] | str | None = None, + verbose: bool | None = False, +): + """ + Identifies biologically plausible gene pairs involved in ligand-receptor (LR) signaling or + metabolite transport based on annotated interaction databases and filtered expression data. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). Must include: + - `varm["database"]`: DataFrame indicating gene involvement in interactions. + - `uns["database"]`: 'LR', 'transporter', or 'both'. + - `uns["ligand"]`, `uns["receptor"]` for LR pairs if applicable. + - `uns["metabolite_database"]` and/or `uns["LR_database"]` for pair categorization. + - `obsp["weights"]`: spatial proximity weights. + layer_key : str or "use_raw", optional + Specifies the layer or raw data to use for expression filtering. + cell_type_key : str, optional + Key in `adata.obs` indicating cell type annotation. + cell_type_pairs : list of tuple, optional + List of tuples specifying cell type pairs to consider. If not provided, all combinations are used. + ct_specific : bool, optional (default: True) + If True, restrict gene pair computation to combinations relevant to the given cell type annotations. + fix_ct : str, optional + Whether to restrict the cell type pairs to a particular cell type. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + Results are stored in the following keys in `adata.uns`: `lcs`, `lc_zs`, `lc_z_pvals`, and `lc_z_FDR`. + """ + start = time.time() + if verbose: + print("Computing gene pairs...") + + from_value_to_type = { + "LR": {-1.0: "REC", 1.0: "LIG"}, + "transporter": {-1.0: "IMP", 1.0: "EXP", 2.0: "IMP-EXP"}, + } + + # Setup + layer_key = layer_key or adata.uns.get("layer_key") + use_raw = layer_key == "use_raw" + genes = adata.raw.var.index if use_raw else adata.var_names + adata.uns["fix_ct"] = fix_ct + + if ct_specific: + cell_type_key = cell_type_key or adata.uns.get("cell_type_key") + if cell_type_key is None: + raise ValueError('Please provide the "cell_type_key" argument.') + adata.uns["cell_type_key"] = cell_type_key + cell_types = adata.obs[cell_type_key] if not use_raw else adata.raw.obs[cell_type_key] + cell_types = cell_types.values.astype(str) + + database = adata.varm["database"] + + heterodimer_info = adata.uns.get("heterodimer_info") + if heterodimer_info is not None: + heterodimer_info = heterodimer_info.copy() + heterodimer_info["Genes"] = heterodimer_info["Genes"].apply(ast.literal_eval) + + # Filters + autocorrelation_filt = adata.uns.get("autocorrelation_filt", False) + expression_filt = adata.uns.get("expression_filt", False) + de_filt = adata.uns.get("de_filt", False) + + if expression_filt or de_filt: + filtered_genes = adata.uns["filtered_genes"] + filtered_genes_ct = adata.uns["filtered_genes_ct"] + elif autocorrelation_filt: + autocor_results = adata.uns["gene_autocorrelation_results"] + filtered_genes = autocor_results[autocor_results.Z_FDR < 0.05].index.tolist() + else: + filtered_genes = list(genes) + + if not filtered_genes: + raise ValueError("No genes have passed the filters.") + + filtered_genes_set = set(filtered_genes) + all_genes_set = set(genes) + non_sig_genes = list(all_genes_set - filtered_genes_set) + + # Filter out uninformative metabolites + database.loc[non_sig_genes] = 0 + cols_keep = [ + col + for col in database.columns + if ( + (np.unique(database[col]) != 0).sum() > 1 + or database[col][database[col] != 0].unique().tolist() == [2] + ) + ] + database = database[cols_keep].copy() + adata.varm["database"] = database + + if ct_specific and "filtered_genes_ct" not in adata.uns: + filtered_genes_ct = dict.fromkeys(np.unique(cell_types), filtered_genes) + else: + filtered_genes_ct = adata.uns.get("filtered_genes_ct", {}) + + weights = adata.obsp["weights"] + if ct_specific: + if cell_type_pairs is None: + cell_type_list = list(filtered_genes_ct) + if fix_ct: + cell_type_pairs = list(itertools.product(cell_type_list, repeat=2)) + if fix_ct != "all": + cell_type_pairs = [pair for pair in cell_type_pairs if pair[0] == fix_ct] + else: + cell_type_pairs = list(itertools.combinations_with_replacement(cell_type_list, 2)) + cell_type_pairs_df = pd.Series(cell_type_pairs) + valid_mask = cell_type_pairs_df.apply( + get_interacting_cell_type_pairs, args=(weights, cell_types) + ) + cell_type_pairs = cell_type_pairs_df[valid_mask].tolist() + + # Setup for result aggregation + gene_pairs_per_metabolite = {} + gene_pairs = [] + ct_pairs = [] + gene_pairs_per_ct_pair = {} if ct_specific else None + + if adata.uns["database"] == "both": + metabolites_set = set(adata.uns["metabolite_database"].Metabolite) + LR_pairs_set = set(adata.uns["LR_database"].index) + + for metabolite in database.columns: + metab_genes = database.index[database[metabolite] != 0].tolist() + if not metab_genes: + continue + + gene_pairs_per_metabolite[metabolite] = {"gene_pair": [], "gene_type": []} + + if adata.uns["database"] == "both": + if metabolite in metabolites_set: + int_type = "transporter" + elif metabolite in LR_pairs_set: + int_type = "LR" + else: + raise ValueError( + 'The "metabolite" variable needs to be either a metabolite or a LR pair.' + ) + else: + int_type = adata.uns["database"] + + # Build gene pairs + if int_type == "transporter": + if ( + heterodimer_info is not None + and metabolite in heterodimer_info["Metabolite"].values + ): + for genes_list in heterodimer_info[heterodimer_info["Metabolite"] == metabolite][ + "Genes" + ]: + if all(g in metab_genes for g in genes_list): + metab_genes = [g for g in metab_genes if g not in genes_list] + [ + tuple(genes_list) + ] + + combos = ( + set(itertools.combinations_with_replacement(metab_genes, 2)) + | set(itertools.permutations(metab_genes, 2)) + if ct_specific + else set(itertools.combinations_with_replacement(metab_genes, 2)) + ) + all_pairs = [ + (list(x) if isinstance(x, tuple) else x, list(y) if isinstance(y, tuple) else y) + for x, y in combos + ] + + else: # LR + ligand = adata.uns["ligand"].loc[metabolite].dropna().tolist() + ligand = ligand[0] if len(ligand) == 1 else ligand + receptor = adata.uns["receptor"].loc[metabolite].dropna().tolist() + receptor = receptor[0] if len(receptor) == 1 else receptor + # if len(ligand) == 0 or len(receptor) == 0: + if not ligand or not receptor: + continue + all_pairs = ( + [(ligand, receptor), (receptor, ligand)] if ct_specific else [(ligand, receptor)] + ) + + # Evaluate gene pairs + for var1, var2 in all_pairs: + + def extract_val(var): + if isinstance(var, str): + return database.at[var, metabolite] + else: + vals = list( + { + database.at[v, metabolite] + for v in var + if database.at[v, metabolite] != 0 + } + ) + return vals[0] if len(vals) == 1 else vals + + val1 = extract_val(var1) + val2 = extract_val(var2) + + if not val1 or not val2: + continue + if val1 == val2 and val1 in (1.0, -1.0): + continue + + type1 = from_value_to_type[int_type].get(val1) + type2 = from_value_to_type[int_type].get(val2) + + gene_pairs_per_metabolite[metabolite]["gene_pair"].append((var1, var2)) + gene_pairs_per_metabolite[metabolite]["gene_type"].append((type1, type2)) + + if (var1, var2) not in gene_pairs: + gene_pairs.append((var1, var2)) + if ct_specific: + for ct1, ct2 in cell_type_pairs: + in_ct1 = ( + var1 in filtered_genes_ct[ct1] + if isinstance(var1, str) + else any(v in filtered_genes_ct[ct1] for v in var1) + ) + in_ct2 = ( + var2 in filtered_genes_ct[ct2] + if isinstance(var2, str) + else any(v in filtered_genes_ct[ct2] for v in var2) + ) + if in_ct1 and in_ct2: + if (ct1, ct2) not in ct_pairs: + ct_pairs.append((ct1, ct2)) + gene_pairs_per_ct_pair.setdefault((ct1, ct2), []).append((var1, var2)) + + # Save results + adata.uns.setdefault("gene_pairs", gene_pairs) + if ct_specific: + adata.uns.setdefault("cell_type_pairs", ct_pairs) + adata.uns.setdefault("gene_pairs_per_ct_pair", gene_pairs_per_ct_pair) + adata.uns.setdefault("gene_pairs_per_metabolite", gene_pairs_per_metabolite) + + if verbose: + print("Finished computing gene pairs in %.3f seconds" % (time.time() - start)) + + return + + +def compute_cell_communication( + adata: AnnData, + layer_key_p_test: Literal["use_raw"] | str | None = None, + layer_key_np_test: Literal["use_raw"] | str | None = None, + model: str = None, + center_counts_for_np_test: bool | None = False, + subset_gene_pairs: str | None = None, + M: int | None = 1000, + seed: int | None = 42, + test: Literal["parametric"] | Literal["non-parametric"] | Literal["both"] | None = "both", + mean: Literal["algebraic"] | Literal["geometric"] | None = "algebraic", + check_analytic_null: bool | None = False, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Computes spatially-informed cell-type-agnostic cell-cell communication (CCC) scores and + significance across all gene pairs using both parametric and non-parametric statistical tests. + + Parameters + ---------- + adata : AnnData + Annotated data object. Required fields include: + - `uns["gene_pairs"]`: list of gene pairs to evaluate. + - `uns["gene_pairs_per_metabolite"]`: dictionary mapping metabolites to gene pairs. + - `obsp["weights"]`: sparse matrix encoding spatial cell-cell proximity. + - (Optional) `uns["LR_database"]`: interaction metadata for pathway annotation. + - (Optional) `uns["sample_key"]`: if modeling includes sample-specific factors. + layer_key_p_test : str or "use_raw", optional + Data layer to use for the parametric test. If `"use_raw"`, uses `adata.raw`. + layer_key_np_test : str or "use_raw", optional + Data layer to use for the non-parametric test. If `"use_raw"`, uses `adata.raw`. + model : str, optional + Normalization model to use for centering gene expression. Options include "none", "normal", "bernoulli", or "danb". + center_counts_for_np_test : bool, optional (default: False) + Whether to center expression counts using the specified model before non-parametric testing. + subset_gene_pairs : list, optional + If provided, restricts the analysis to this subset of gene pairs. + M : int, optional (default: 1000) + Number of permutations to use if `permutation_test` is True. + seed : int, optional (default: 42) + Random seed for permutation reproducibility. + test : {'parametric', 'non-parametric', 'both'}, optional (default: 'both') + Specifies which statistical test(s) to run. + mean : {'algebraic', 'geometric'}, optional (default: 'algebraic') + Averaging method for multi-gene interactions. + check_analytic_null : bool, optional (default: False) + Whether to evaluate Z-scores under an analytic null distribution using permutation Z-scores. + device : torch.device, optional + PyTorch device to run computations on. Defaults to CUDA if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + Results are stored in the following `adata.uns` fields: + - `uns["ccc_results"]["p"]`: Parametric test results (gene pair and metabolite scores, Z, p-values, FDR). + - `uns["ccc_results"]["np"]`: Non-parametric test results (communication scores, empirical p-values, FDR). + - `uns["lc_zs"]`: Symmetric matrix of ligand-receptor Z-scores across genes. + - `uns["gene_pair_dict"]`: Dictionary mapping metabolites to index positions of gene pairs. + - `uns["D"]`: Vector of total node degrees per cell (spatial connectivity). + - `uns["genes"]`: Ordered list of involved genes. + - `uns["gene_pairs_ind"]`: Index-referenced version of `uns["gene_pairs"]`. + """ + start = time.time() + if verbose: + print("Starting cell-cell communication analysis...") + + adata.uns["ccc_results"] = {} + + if test not in ["both", "parametric", "non-parametric"]: + raise ValueError( + 'The "test" variable should be one of ["both", "parametric", "non-parametric"].' + ) + + if mean not in ["algebraic", "geometric"]: + raise ValueError('The "mean" variable should be one of ["algebraic", "geometric"].') + + adata.uns["layer_key_p_test"] = layer_key_p_test + adata.uns["layer_key_np_test"] = layer_key_np_test + adata.uns["model"] = model + adata.uns["center_counts_for_np_test"] = center_counts_for_np_test + adata.uns["mean"] = mean + + run_cell_communication_analysis( + adata, + layer_key_p_test, + layer_key_np_test, + model, + center_counts_for_np_test, + subset_gene_pairs, + M, + seed, + test, + mean, + check_analytic_null, + device, + verbose, + ) + + if verbose: + print("Obtaining communication results...") + get_cell_communication_results( + adata, + adata.uns["genes"], + layer_key_p_test, + layer_key_np_test, + model, + adata.uns["D"], + test, + device, + ) + + if verbose: + print( + "Finished computing cell-cell communication analysis in %.3f seconds" + % (time.time() - start) + ) + + return + + +def run_cell_communication_analysis( + adata, + layer_key_p_test, + layer_key_np_test, + model, + center_counts_for_np_test, + subset_gene_pairs, + M, + seed, + test, + mean, + check_analytic_null, + device, + verbose, +): + + use_raw = (layer_key_p_test == "use_raw") & (layer_key_np_test == "use_raw") + + cells = ( + adata.raw.obs.index.values.astype(str) if use_raw else adata.obs_names.values.astype(str) + ) + + sample_specific = "sample_key" in adata.uns + + gene_pairs = adata.uns["gene_pairs"] if subset_gene_pairs is None else subset_gene_pairs + genes = list(np.unique(list(flatten(adata.uns["gene_pairs"])))) + adata.uns["genes"] = genes + adata.uns["cells"] = cells + + # Map gene_pairs to index + gene_pairs_ind = [] + for pair in gene_pairs: + idx1 = ( + [genes.index(g) for g in pair[0]] + if isinstance(pair[0], list) + else genes.index(pair[0]) + ) + idx2 = ( + [genes.index(g) for g in pair[1]] + if isinstance(pair[1], list) + else genes.index(pair[1]) + ) + gene_pairs_ind.append((idx1, idx2)) + adata.uns["gene_pairs_ind"] = gene_pairs_ind + + # Compute weights + weights = make_weights_non_redundant(adata.obsp["weights"]).tocoo() + weights = torch.sparse_coo_tensor( + torch.tensor(np.vstack((weights.row, weights.col)), dtype=torch.long, device=device), + torch.tensor(weights.data, dtype=torch.float64, device=device), + torch.Size(weights.shape), + device=device, + ) + + # Compute node degree + row_degrees = torch.sparse.sum(weights, dim=1).to_dense() + col_degrees = torch.sparse.sum(weights, dim=0).to_dense() + D = row_degrees + col_degrees + + adata.uns["D"] = D.cpu().numpy() + + gene_pairs_per_metabolite = adata.uns["gene_pairs_per_metabolite"] + + metabolite_gene_pair_df = pd.DataFrame.from_dict( + gene_pairs_per_metabolite, orient="index" + ).reset_index() + metabolite_gene_pair_df = metabolite_gene_pair_df.rename(columns={"index": "metabolite"}) + + metabolite_gene_pair_df["gene_pair"] = metabolite_gene_pair_df["gene_pair"].apply( + lambda arr: [(sub_array[0], sub_array[1]) for sub_array in arr] + ) + metabolite_gene_pair_df["gene_type"] = metabolite_gene_pair_df["gene_type"].apply( + lambda arr: [(sub_array[0], sub_array[1]) for sub_array in arr] + ) + + metabolite_gene_pair_df = pd.concat( + [ + metabolite_gene_pair_df["metabolite"], + metabolite_gene_pair_df.explode("gene_pair")["gene_pair"], + metabolite_gene_pair_df.explode("gene_type")["gene_type"], + ], + axis=1, + ) + metabolite_gene_pair_df = metabolite_gene_pair_df.reset_index(drop=True) + + if "LR_database" in adata.uns.keys(): + LR_database = adata.uns["LR_database"] + df_merged = pd.merge( + metabolite_gene_pair_df, + LR_database, + left_on="metabolite", + right_on="interaction_name", + how="left", + ) + LR_df = df_merged.dropna(subset=["pathway_name"]) + metabolite_gene_pair_df["metabolite"][ + metabolite_gene_pair_df.metabolite.isin(LR_df.metabolite) + ] = LR_df["pathway_name"] + + gene_pair_dict = {} + for metabolite, group in metabolite_gene_pair_df.groupby("metabolite"): + idxs = ( + group["gene_pair"] + .apply(lambda gp: gene_pairs.index(gp) if gp in gene_pairs else None) + .dropna() + .tolist() + ) + idxs = [int(ind) for ind in idxs if ind is not None] + if idxs: + gene_pair_dict[metabolite] = idxs + + adata.uns["gene_pair_dict"] = gene_pair_dict + + if test in ["parametric", "both"]: + if verbose: + print("Running the parametric test...") + + adata.uns["ccc_results"]["p"] = {"gp": {}, "m": {}} + + Wtot2 = torch.tensor((weights.data**2).sum(), device=device) + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_ind: + if isinstance(idx1, list): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, list): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + # Standardize counts + _lazy_import_hotspot() + counts_1 = standardize_counts(adata, counts_1, model, num_umi, sample_specific) + _lazy_import_hotspot() + counts_2 = standardize_counts(adata, counts_2, model, num_umi, sample_specific) + + # Compute CCC scores + WX2t = torch.sparse.mm(weights, counts_2.T) + WtX2t = torch.sparse.mm(weights.transpose(0, 1), counts_2.T) + cs_gp = (counts_1.T * WX2t).sum(0) + (counts_1.T * WtX2t).sum(0) + same_gene_mask = torch.tensor([g1 == g2 for g1, g2 in gene_pairs], device=device) + cs_gp[same_gene_mask] = cs_gp[same_gene_mask] / 2 + adata.uns["ccc_results"]["p"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + + # Compute metabolite-level scores + cs_m = compute_metabolite_cs(cs_gp, gene_pair_dict, interacting_cell_scores=False) + adata.uns["ccc_results"]["p"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + # Compute second moments + WX1t = torch.sparse.mm(weights, counts_1.T) + WtX1t = torch.sparse.mm(weights.transpose(0, 1), counts_1.T) + eg2_a = (WX1t + WtX1t).pow(2).sum(dim=0) + eg2_b = (WX2t + WtX2t).pow(2).sum(dim=0) + eg2s_gp = (eg2_a, eg2_b) + + # Z-score computation + Z_gp, Z_m = compute_p_results(cs_gp, cs_m, gene_pairs_ind, Wtot2, eg2s_gp, gene_pair_dict) + # Convert tensors to numpy for statsmodels and pandas + Z_gp_np = Z_gp.detach().cpu().numpy() + Z_m_np = Z_m.detach().cpu().numpy() + # Compute p-values and FDRs + Z_pvals_gp = norm.sf(Z_gp_np) + Z_pvals_m = norm.sf(Z_m_np) + FDR_gp = multipletests(Z_pvals_gp, method="fdr_bh")[1] + FDR_m = multipletests(Z_pvals_m, method="fdr_bh")[1] + + # Store in AnnData + adata.uns["ccc_results"]["p"]["gp"]["Z"] = Z_gp_np + adata.uns["ccc_results"]["p"]["gp"]["Z_pval"] = Z_pvals_gp + adata.uns["ccc_results"]["p"]["gp"]["Z_FDR"] = FDR_gp + adata.uns["ccc_results"]["p"]["m"]["Z"] = Z_m_np + adata.uns["ccc_results"]["p"]["m"]["Z_pval"] = Z_pvals_m + adata.uns["ccc_results"]["p"]["m"]["Z_FDR"] = FDR_m + + # Symmetric LC Z-score matrix + genes_ = [ + tuple(i) if isinstance(i, list) else i + for i in pd.Series([g for pair in gene_pairs for g in pair]).drop_duplicates() + ] + gene_pairs_ = [ + (tuple(a) if isinstance(a, list) else a, tuple(b) if isinstance(b, list) else b) + for a, b in gene_pairs + ] + lc_zs = pd.DataFrame(np.zeros((len(genes_), len(genes_))), index=genes_, columns=genes_) + for i, (g1, g2) in enumerate(gene_pairs_): + lc_zs.iloc[genes_.index(g1), genes_.index(g2)] = Z_gp_np[i] + # Force diagonal to 0 and symmetrize + np.fill_diagonal(lc_zs.values, 0) + adata.uns["lc_zs"] = (lc_zs + lc_zs.T) / 2 + + if verbose: + print("Parametric test finished.") + + if test in ["non-parametric", "both"]: + if verbose: + print("Running the non-parametric test...") + + adata.uns["ccc_results"]["np"] = {"gp": {}, "m": {}} + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_ind: + if isinstance(idx1, list): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, list): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + if center_counts_for_np_test: + num_umi = counts.sum(dim=0) + _lazy_import_hotspot() + counts_1 = standardize_counts(adata, counts_1, model, num_umi, sample_specific) + _lazy_import_hotspot() + counts_2 = standardize_counts(adata, counts_2, model, num_umi, sample_specific) + + n_cells = counts_1.shape[1] + same_gene_mask = torch.tensor([g1 == g2 for g1, g2 in gene_pairs], device=device) + + if center_counts_for_np_test and test == "both": + adata.uns["ccc_results"]["np"]["gp"]["cs"] = np.array( + adata.uns["ccc_results"]["p"]["gp"]["cs"] + ) + adata.uns["ccc_results"]["np"]["m"]["cs"] = np.array( + adata.uns["ccc_results"]["p"]["m"]["cs"] + ) + else: + WX2t = torch.sparse.mm(weights, counts_2.T) + WtX2t = torch.sparse.mm(weights.transpose(0, 1), counts_2.T) + cs_gp = (counts_1.T * WX2t).sum(0) + (counts_1.T * WtX2t).sum(0) + cs_gp[same_gene_mask] = cs_gp[same_gene_mask] / 2 + adata.uns["ccc_results"]["np"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + cs_m = compute_metabolite_cs(cs_gp, gene_pair_dict, interacting_cell_scores=False) + adata.uns["ccc_results"]["np"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + perm_cs_gp_a = torch.zeros((counts_1.shape[0], M), dtype=torch.float64, device=device) + perm_cs_gp_b = torch.zeros_like(perm_cs_gp_a) + perm_cs_m_a = torch.zeros((len(gene_pair_dict), M), dtype=torch.float64, device=device) + perm_cs_m_b = torch.zeros_like(perm_cs_m_a) + + if check_analytic_null: + gp_zs_perm_array = torch.zeros_like(perm_cs_gp_a) + gp_pvals_perm_array = torch.zeros_like(perm_cs_gp_a) + m_zs_perm_array = torch.zeros_like(perm_cs_m_a) + m_pvals_perm_array = torch.zeros_like(perm_cs_m_a) + + torch.manual_seed(seed) + for i in tqdm(range(M), desc="Permutation test"): + idx = torch.randperm(n_cells, device=device) + + c1_perm_a = counts_1.clone() + c2_perm_a = counts_2[:, idx] + c1_perm_a[same_gene_mask] = counts_1[same_gene_mask, :][:, idx] + + WX2t_a = torch.sparse.mm(weights, c2_perm_a.T) + WtX2t_a = torch.sparse.mm(weights.transpose(0, 1), c2_perm_a.T) + cs_a = (c1_perm_a.T * WX2t_a).sum(0) + (c1_perm_a.T * WtX2t_a).sum(0) + cs_a[same_gene_mask] = cs_a[same_gene_mask] / 2 + perm_cs_gp_a[:, i] = cs_a + + cs_m_a = compute_metabolite_cs(cs_a, gene_pair_dict, interacting_cell_scores=False) + perm_cs_m_a[:, i] = cs_m_a + + c2_perm_b = counts_2.clone() + c1_perm_b = counts_1[:, idx] + c2_perm_b[same_gene_mask] = counts_2[same_gene_mask, :][:, idx] + + WX2t_b = torch.sparse.mm(weights, c2_perm_b.T) + WtX2t_b = torch.sparse.mm(weights.transpose(0, 1), c2_perm_b.T) + cs_b = (c1_perm_b.T * WX2t_b).sum(0) + (c1_perm_b.T * WtX2t_b).sum(0) + cs_b[same_gene_mask] = cs_b[same_gene_mask] / 2 + perm_cs_gp_b[:, i] = cs_b + + cs_m_b = compute_metabolite_cs(cs_b, gene_pair_dict, interacting_cell_scores=False) + perm_cs_m_b[:, i] = cs_m_b + + if check_analytic_null: + Z_gp_perm, Z_m_perm = compute_p_results( + (cs_a, cs_b), (cs_m_a, cs_m_b), gene_pairs_ind, Wtot2, eg2s_gp, gene_pair_dict + ) + gp_zs_perm_array[:, i] = Z_gp_perm + gp_pvals_perm_array[:, i] = torch.tensor( + norm.sf(Z_gp_perm.cpu().numpy()), device=device + ) + m_zs_perm_array[:, i] = Z_m_perm + m_pvals_perm_array[:, i] = torch.tensor( + norm.sf(Z_m_perm.cpu().numpy()), device=device + ) + + adata.uns["ccc_results"]["np"]["gp"]["perm_cs_a"] = perm_cs_gp_a.detach().cpu().numpy() + adata.uns["ccc_results"]["np"]["gp"]["perm_cs_b"] = perm_cs_gp_b.detach().cpu().numpy() + adata.uns["ccc_results"]["np"]["m"]["perm_cs_a"] = perm_cs_m_a.detach().cpu().numpy() + adata.uns["ccc_results"]["np"]["m"]["perm_cs_b"] = perm_cs_m_b.detach().cpu().numpy() + + x_gp_a = (perm_cs_gp_a > cs_gp[:, None]).sum(dim=1) + x_gp_b = (perm_cs_gp_b > cs_gp[:, None]).sum(dim=1) + x_m_a = (perm_cs_m_a > cs_m[:, None]).sum(dim=1) + x_m_b = (perm_cs_m_b > cs_m[:, None]).sum(dim=1) + + pvals_gp_a = (x_gp_a + 1).float() / (M + 1) + pvals_gp_b = (x_gp_b + 1).float() / (M + 1) + pvals_m_a = (x_m_a + 1).float() / (M + 1) + pvals_m_b = (x_m_b + 1).float() / (M + 1) + + pvals_gp = torch.where(pvals_gp_a > pvals_gp_b, pvals_gp_a, pvals_gp_b) + pvals_m = torch.where(pvals_m_a > pvals_m_b, pvals_m_a, pvals_m_b) + + adata.uns["ccc_results"]["np"]["gp"]["pval"] = pvals_gp.cpu().numpy() + adata.uns["ccc_results"]["np"]["gp"]["FDR"] = multipletests( + pvals_gp.cpu().numpy(), method="fdr_bh" + )[1] + adata.uns["ccc_results"]["np"]["m"]["pval"] = pvals_m.cpu().numpy() + adata.uns["ccc_results"]["np"]["m"]["FDR"] = multipletests( + pvals_m.cpu().numpy(), method="fdr_bh" + )[1] + + if check_analytic_null: + adata.uns["ccc_results"]["np"]["analytic_null"] = { + "gp_zs_perm": gp_zs_perm_array.detach().cpu().numpy(), + "gp_pvals_perm": gp_pvals_perm_array.detach().cpu().numpy(), + "m_zs_perm": m_zs_perm_array.detach().cpu().numpy(), + "m_pvals_perm": m_pvals_perm_array.detach().cpu().numpy(), + } + + if verbose: + print("Non-parametric test finished.") + + return + + +def compute_ct_cell_communication( + adata: AnnData, + layer_key_p_test: Literal["use_raw"] | str | None = None, + layer_key_np_test: Literal["use_raw"] | str | None = None, + model: str = None, + cell_type_key: str | None = None, + center_counts_for_np_test: bool | None = False, + subset_gene_pairs: list | None = None, + subset_metabolites: list | None = None, + fix_gp: bool | None = False, + M: int | None = 1000, + seed: int | None = 42, + test: Literal["parametric"] | Literal["non-parametric"] | Literal["both"] | None = "both", + mean: Literal["algebraic"] | Literal["geometric"] | None = "algebraic", + check_analytic_null: bool | None = False, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Computes cell type-aware cell-cell communication (CCC) scores by stratifying communication + by interacting cell type pairs. Supports parametric and non-parametric statistical inference. + + Parameters + ---------- + adata : AnnData + Annotated data object. Required fields include: + - `uns["gene_pairs"]`: gene pairs involved in communication. + - `uns["gene_pairs_per_metabolite"]`: maps metabolites to gene pairs. + - `uns["gene_pairs_per_ct_pair"]`: gene pairs per cell type pair. + - `obsp["weights"]`: sparse cell-cell proximity matrix. + - `obs[cell_type_key]`: categorical cell type annotations. + - `uns["cell_type_pairs"]`: list of interacting cell type pairs. + - (Optional) `uns["LR_database"]`: for metabolite/pathway annotation. + layer_key_p_test : str or "use_raw", optional + Data layer to use for parametric test. + layer_key_np_test : str or "use_raw", optional + Data layer to use for non-parametric test. + model : str, optional + Normalization model to use for centering gene expression. Options include "none", "normal", "bernoulli", or "danb". + cell_type_key : str, optional + Key in `adata.obs` corresponding to cell type annotations. Required if not stored in `uns`. + center_counts_for_np_test : bool, optional (default: False) + Whether to center expression counts using the specified model before non-parametric testing. + subset_gene_pairs : list, optional + Subset of gene pairs to consider. If None, uses all pairs. + subset_metabolites : list, optional + Subset of metabolites to include in the analysis. + fix_gp : bool, optional (default: False) + If True, keeps gene pair identity fixed during permutation testing, randomizing cell types only. + M : int, optional (default: 1000) + Number of permutations to use if `permutation_test` is True. + seed : int, optional (default: 42) + Random seed for permutation reproducibility. + test : {'parametric', 'non-parametric', 'both'}, optional (default: 'both') + Specifies which statistical test(s) to run. + mean : {'algebraic', 'geometric'}, optional (default: 'algebraic') + Averaging method for multi-gene modules. + check_analytic_null : bool, optional (default: False) + Whether to compute Z-scores and p-values under the null distribution for the permutation test. + device : torch.device, optional + PyTorch device to run computations on. Defaults to CUDA if available. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + + Returns + ------- + None + Results are stored in the following `adata.uns` fields: + - `ct_ccc_results["p"]`: parametric test results (scores, Z, p-values, FDRs) per gene pair and metabolite per cell type pair. + - `ct_ccc_results["np"]`: non-parametric test results (communication scores, empirical p-values, FDRs). + - `gene_pair_dict`: dictionary mapping metabolites to relevant gene pairs. + - `gene_pairs_ind`, `gene_pairs_ind_per_ct_pair`: index-referenced gene pair representations. + - `D`: spatial node degree for each cell per cell type pair. + - `cells`, `genes`: ordered list of cells and genes used in analysis. + - (optional) `ct_ccc_results["np"]["analytic_null"]`: null distributions from permutation test Z-scores and p-values. + """ + start = time.time() + if verbose: + print("Starting cell type-aware cell-cell communication analysis...") + + adata.uns["ct_ccc_results"] = {} + + if test not in ["both", "parametric", "non-parametric"]: + raise ValueError( + 'The "test" variable should be one of ["both", "parametric", "non-parametric"].' + ) + + if mean not in ["algebraic", "geometric"]: + raise ValueError('The "mean" variable should be one of ["algebraic", "geometric"].') + + if "cell_type_key" in adata.uns and cell_type_key is None: + cell_type_key = adata.uns["cell_type_key"] + elif "cell_type_key" not in adata.uns and cell_type_key is None: + raise ValueError('Please provide the "cell_type_key" argument.') + + adata.uns["layer_key_p_test"] = layer_key_p_test + adata.uns["layer_key_np_test"] = layer_key_np_test + adata.uns["model"] = model + adata.uns["cell_type_key"] = cell_type_key + adata.uns["center_counts_for_np_test"] = center_counts_for_np_test + adata.uns["mean"] = mean + + run_ct_cell_communication_analysis( + adata, + layer_key_p_test, + layer_key_np_test, + model, + cell_type_key, + center_counts_for_np_test, + subset_gene_pairs, + subset_metabolites, + fix_gp, + M, + seed, + test, + mean, + check_analytic_null, + device, + verbose, + ) + + if verbose: + print("Obtaining cell type-aware communication results...") + get_ct_cell_communication_results( + adata, + adata.uns["genes"], + adata.uns["cells"], + layer_key_p_test, + layer_key_np_test, + model, + adata.obs[cell_type_key], + adata.uns["cell_type_pairs"], + adata.uns["D"], + test, + device, + ) + + if verbose: + print( + "Finished computing cell type-aware cell-cell communication analysis in %.3f seconds" + % (time.time() - start) + ) + + return + + +def run_ct_cell_communication_analysis( + adata, + layer_key_p_test, + layer_key_np_test, + model, + cell_type_key, + center_counts_for_np_test, + subset_gene_pairs, + subset_metabolites, + fix_gp, + M, + seed, + test, + mean, + check_analytic_null, + device, + verbose, +): + + use_raw = (layer_key_p_test == "use_raw") & (layer_key_np_test == "use_raw") + obs = adata.raw.obs if use_raw else adata.obs + cells = ( + adata.raw.obs.index.values.astype(str) if use_raw else adata.obs_names.values.astype(str) + ) + + sample_specific = "sample_key" in adata.uns + + fix_ct = True if adata.uns["fix_ct"] else False + + gene_pairs = adata.uns["gene_pairs"] if subset_gene_pairs is None else subset_gene_pairs + genes = list(np.unique(list(flatten(adata.uns["gene_pairs"])))) + adata.uns["genes"] = genes + + cell_types = obs[cell_type_key] + cell_type_pairs = adata.uns.get("cell_type_pairs") + gene_pairs_per_ct_pair = adata.uns.get("gene_pairs_per_ct_pair", {}) + + weights = adata.obsp["weights"] + + used_ct_pairs = list(set(ct for cell_type_pair in cell_type_pairs for ct in cell_type_pair)) + all_cell_types = set(cell_types.unique()) + used_ct_pairs_set = set(used_ct_pairs) + if used_ct_pairs_set < all_cell_types: + keep_mask = cell_types[cells].isin(used_ct_pairs).values + keep_indices = np.where(keep_mask)[0] + weights = weights[keep_indices][:, keep_indices] + cells = cells[keep_indices] + cell_types = cell_types.loc[cells] + + adata.uns["cells"] = cells + + weights_ct_pairs = create_weights_ct_pairs( + weights.tocoo(), cell_types, cell_type_pairs, device + ) + + row_degrees = torch.sparse.sum(weights_ct_pairs, dim=2).to_dense() + col_degrees = torch.sparse.sum(weights_ct_pairs, dim=1).to_dense() + D = row_degrees + col_degrees + if used_ct_pairs_set < all_cell_types: + D_full = torch.zeros( + (len(cell_type_pairs), adata.shape[0]), + device=weights_ct_pairs.device, + dtype=weights_ct_pairs.dtype, + ) + D_full[:, keep_indices] = D + adata.uns["D"] = D_full.cpu().numpy() + else: + adata.uns["D"] = D.cpu().numpy() + + # Map gene_pairs to index + gene_pairs_ind = [] + for pair in gene_pairs: + idx1 = ( + [genes.index(g) for g in pair[0]] + if isinstance(pair[0], list) + else genes.index(pair[0]) + ) + idx2 = ( + [genes.index(g) for g in pair[1]] + if isinstance(pair[1], list) + else genes.index(pair[1]) + ) + gene_pairs_ind.append((idx1, idx2)) + adata.uns["gene_pairs_ind"] = gene_pairs_ind + + # Cell-type pair-specific indices + gene_pairs_ind_per_ct_pair = defaultdict(list) + gene_pairs_per_ct_pair_ind = defaultdict(list) + for ct_pair, gpairs in gene_pairs_per_ct_pair.items(): + for pair in gpairs: + if pair not in gene_pairs: + continue + idx = gene_pairs.index(pair) + gene_pairs_ind_per_ct_pair[ct_pair].append(gene_pairs_ind[idx]) + gene_pairs_per_ct_pair_ind[ct_pair].append(idx) + + adata.uns["gene_pairs_ind_per_ct_pair"] = dict(gene_pairs_ind_per_ct_pair) + adata.uns["gene_pairs_per_ct_pair_ind"] = dict(gene_pairs_per_ct_pair_ind) + + def make_hashable(pair): + return tuple(tuple(x) if isinstance(x, list) else x for x in pair) + + gene_pairs_ind_set = {make_hashable(pair) for pair in gene_pairs_ind} + ct_specific_gene_pairs = [ + i + for i, pairs in enumerate(gene_pairs_ind_per_ct_pair.values()) + if {make_hashable(pair) for pair in pairs} < gene_pairs_ind_set + ] + + # Metabolite-gene pair preparation + gp_metab = adata.uns["gene_pairs_per_metabolite"] + metabolite_gene_pair_df = ( + pd.DataFrame.from_dict(gp_metab, orient="index") + .rename_axis("metabolite") + .explode(["gene_pair", "gene_type"]) + .reset_index() + ) + + if "LR_database" in adata.uns: + merged = metabolite_gene_pair_df.merge( + adata.uns["LR_database"], left_on="metabolite", right_on="interaction_name", how="left" + ) + LR_df = merged.dropna(subset=["pathway_name"]) + metabolite_gene_pair_df.loc[ + metabolite_gene_pair_df.metabolite.isin(LR_df.metabolite), "metabolite" + ] = LR_df["pathway_name"].values + + if subset_metabolites: + metabolite_gene_pair_df = metabolite_gene_pair_df[ + metabolite_gene_pair_df.metabolite.isin(subset_metabolites) + ] + + gene_pair_dict = {} + for metabolite, group in metabolite_gene_pair_df.groupby("metabolite"): + idxs = ( + group["gene_pair"] + .apply(lambda gp: gene_pairs.index(gp) if gp in gene_pairs else None) + .dropna() + .tolist() + ) + idxs = [int(ind) for ind in idxs if ind is not None] + if idxs: + gene_pair_dict[metabolite] = idxs + adata.uns["gene_pair_dict"] = gene_pair_dict + + if test in ["parametric", "both"]: + if verbose: + print("Running the parametric test...") + + adata.uns["ct_ccc_results"]["p"] = {"gp": {}, "m": {}} + + weights_sq_data = weights_ct_pairs.values() ** 2 + weights_sq = torch.sparse_coo_tensor( + weights_ct_pairs.indices(), + weights_sq_data, + weights_ct_pairs.shape, + device=weights_ct_pairs.device, + ) + Wtot2 = torch.sparse.sum(weights_sq, dim=(1, 2)).to_dense() + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_ind: + if isinstance(idx1, list): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, list): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + counts_1 = standardize_ct_counts( + adata, counts_1, model, num_umi, sample_specific, cell_types + ) + counts_2 = standardize_ct_counts( + adata, counts_2, model, num_umi, sample_specific, cell_types + ) + + # Compute CCC scores + cs_gp = torch.zeros((len(cell_type_pairs), counts_1.shape[0]), device=counts_1.device) + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + WX2t = torch.sparse.mm(W, counts_2.T) + cs_gp[ct_pair] = (counts_1.T * WX2t).sum(0) + adata.uns["ct_ccc_results"]["p"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + + cs_m = compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + interacting_cell_scores=False, + ) + adata.uns["ct_ccc_results"]["p"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + EG2_gp = torch.zeros_like(cs_gp) if fix_ct or fix_gp else Wtot2 + if fix_ct: + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + W_sq_data = W.values() ** 2 + W_sq = torch.sparse_coo_tensor(W.indices(), W_sq_data, W.shape, device=W.device) + X1_sq = counts_1**2 + EG2_gp[ct_pair] = torch.sparse.mm(W_sq, X1_sq.T).sum(0) + elif fix_gp: + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + W_sq_data = W.values() ** 2 + W_sq = torch.sparse_coo_tensor(W.indices(), W_sq_data, W.shape, device=W.device) + X1_sq = counts_1**2 + X2_sq = counts_2**2 + EG2_gp[ct_pair] = (X1_sq.T * torch.sparse.mm(W_sq, X2_sq.T)).sum(0) + + Z_gp, Z_m = compute_ct_p_results( + cs_gp, + cs_m, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + EG2_gp, + cell_type_key, + gene_pair_dict, + ) + + # Convert tensors to numpy for statsmodels and pandas + Z_gp_np = Z_gp.detach().cpu().numpy() + Z_m_np = Z_m.detach().cpu().numpy() + # Compute p-values and FDRs + Z_pvals_gp = norm.sf(Z_gp_np) + Z_pvals_m = norm.sf(Z_m_np) + FDR_gp = multipletests(Z_pvals_gp.flatten(), method="fdr_bh")[1].reshape(Z_pvals_gp.shape) + FDR_m = multipletests(Z_pvals_m.flatten(), method="fdr_bh")[1].reshape(Z_pvals_m.shape) + + # Store in AnnData + adata.uns["ct_ccc_results"]["p"]["gp"]["Z"] = Z_gp_np + adata.uns["ct_ccc_results"]["p"]["gp"]["Z_pval"] = Z_pvals_gp + adata.uns["ct_ccc_results"]["p"]["gp"]["Z_FDR"] = FDR_gp + adata.uns["ct_ccc_results"]["p"]["m"]["Z"] = Z_m_np + adata.uns["ct_ccc_results"]["p"]["m"]["Z_pval"] = Z_pvals_m + adata.uns["ct_ccc_results"]["p"]["m"]["Z_FDR"] = FDR_m + + if test in ["non-parametric", "both"]: + if verbose: + print("Running the non-parametric test...") + + adata.uns["ct_ccc_results"]["np"] = {"gp": {}, "m": {}} + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_ind: + if isinstance(idx1, list): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, list): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + if center_counts_for_np_test: + num_umi = counts.sum(dim=0) + counts_1 = standardize_ct_counts( + adata, counts_1, model, num_umi, sample_specific, cell_types + ) + counts_2 = standardize_ct_counts( + adata, counts_2, model, num_umi, sample_specific, cell_types + ) + + if center_counts_for_np_test and test == "both": + adata.uns["ct_ccc_results"]["np"]["gp"]["cs"] = np.array( + adata.uns["ct_ccc_results"]["p"]["gp"]["cs"] + ) + adata.uns["ct_ccc_results"]["np"]["m"]["cs"] = np.array( + adata.uns["ct_ccc_results"]["p"]["m"]["cs"] + ) + else: + cs_gp = torch.zeros((len(cell_type_pairs), counts_1.shape[0]), device=counts_1.device) + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + WX2t = torch.sparse.mm(W, counts_2.T) + cs_gp[ct_pair] = (counts_1.T * WX2t).sum(0) + adata.uns["ct_ccc_results"]["np"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + cs_m = compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + interacting_cell_scores=False, + ) + adata.uns["ct_ccc_results"]["np"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + perm_cs_gp = torch.zeros( + (len(cell_type_pairs), counts_1.shape[0], M), dtype=torch.float64, device=device + ) + perm_cs_m = torch.zeros( + (len(cell_type_pairs), len(gene_pair_dict), M), dtype=torch.float64, device=device + ) + + if check_analytic_null: + gp_zs_perm_array = torch.zeros_like(perm_cs_gp) + gp_pvals_perm_array = torch.zeros_like(perm_cs_gp) + m_zs_perm_array = torch.zeros_like(perm_cs_m) + m_pvals_perm_array = torch.zeros_like(perm_cs_m) + + if fix_gp: + c1_perm = counts_1 + c2_perm = counts_2 + + torch.manual_seed(seed) + for i in tqdm(range(M), desc="Permutation test"): + if fix_gp: + indices = torch.randperm(len(cell_types)).numpy() + shuffled_cell_types = cell_types.iloc[indices].reset_index(drop=True) + weights_ct_pairs = create_weights_ct_pairs( + weights.tocoo(), shuffled_cell_types, cell_type_pairs, device + ) + else: + cell_type_labels = torch.tensor( + cell_types.astype("category").cat.codes.values, device=counts_1.device + ) + idx = torch.empty_like(cell_type_labels, dtype=torch.int64) + + for ct in torch.unique(cell_type_labels): + ct_mask = cell_type_labels == ct + ct_indices = torch.nonzero(ct_mask, as_tuple=True)[0] + permuted_indices = ct_indices[torch.randperm(len(ct_indices))] + idx[ct_indices] = permuted_indices + + c1_perm = counts_1 if fix_ct else counts_1[:, idx.long()] + c2_perm = counts_2[:, idx.long()] + + cs_gp = torch.zeros((len(cell_type_pairs), c1_perm.shape[0]), device=c1_perm.device) + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + WX2t = torch.sparse.mm(W, c2_perm.T) + cs_gp[ct_pair] = (c1_perm.T * WX2t).sum(0) + perm_cs_gp[:, :, i] = cs_gp + + cs_m = compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + interacting_cell_scores=False, + ) + perm_cs_m[:, :, i] = cs_m + + if check_analytic_null: + Z_gp_perm, Z_m_perm = compute_ct_p_results( + cs_gp, + cs_m, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + EG2_gp, + cell_type_key, + gene_pair_dict, + ) + gp_zs_perm_array[:, :, i] = Z_gp_perm + gp_pvals_perm_array[:, :, i] = torch.tensor( + norm.sf(Z_gp_perm.cpu().numpy()), device=device + ) + m_zs_perm_array[:, :, i] = Z_m_perm + m_pvals_perm_array[:, :, i] = torch.tensor( + norm.sf(Z_m_perm.cpu().numpy()), device=device + ) + + adata.uns["ct_ccc_results"]["np"]["gp"]["perm_cs"] = perm_cs_gp.detach().cpu().numpy() + adata.uns["ct_ccc_results"]["np"]["m"]["perm_cs"] = perm_cs_m.detach().cpu().numpy() + + x_gp = np.sum( + adata.uns["ct_ccc_results"]["np"]["gp"]["perm_cs"] + > adata.uns["ct_ccc_results"]["np"]["gp"]["cs"][:, :, np.newaxis], + axis=2, + ) + x_m = np.sum( + adata.uns["ct_ccc_results"]["np"]["m"]["perm_cs"] + > adata.uns["ct_ccc_results"]["np"]["m"]["cs"][:, :, np.newaxis], + axis=2, + ) + + pvals_gp = (x_gp + 1) / (M + 1) + pvals_m = (x_m + 1) / (M + 1) + + adata.uns["ct_ccc_results"]["np"]["gp"]["pval"] = pvals_gp + adata.uns["ct_ccc_results"]["np"]["gp"]["FDR"] = multipletests( + pvals_gp.flatten(), method="fdr_bh" + )[1].reshape(pvals_gp.shape) + adata.uns["ct_ccc_results"]["np"]["m"]["pval"] = pvals_m + adata.uns["ct_ccc_results"]["np"]["m"]["FDR"] = multipletests( + pvals_m.flatten(), method="fdr_bh" + )[1].reshape(pvals_m.shape) + + if check_analytic_null: + adata.uns["ct_ccc_results"]["np"]["analytic_null"] = { + "gp_zs_perm": gp_zs_perm_array.detach().cpu().numpy(), + "gp_pvals_perm": gp_pvals_perm_array.detach().cpu().numpy(), + "m_zs_perm": m_zs_perm_array.detach().cpu().numpy(), + "m_pvals_perm": m_pvals_perm_array.detach().cpu().numpy(), + } + + adata.uns["cell_types"] = cell_types.tolist() if cell_type_key else None + + if verbose: + print("Non-parametric test finished.") + + return + + +def standardize_ct_counts(adata, counts, model, num_umi, sample_specific, cell_types): + + if sample_specific: + sample_key = adata.uns["sample_key"] + for sample in adata.obs[sample_key].unique(): + subset = np.where(adata.obs[sample_key] == sample)[0] + counts[:, subset] = center_ct_counts_torch( + counts[:, subset], num_umi[subset], model, cell_types[subset] + ) + else: + counts = center_ct_counts_torch(counts, num_umi, model, cell_types) + + return counts + + +def flatten(nested_list): + for item in nested_list: + if isinstance(item, (list, tuple)): + yield from flatten(item) + else: + yield item + + +def create_weights_ct_pairs(weights, cell_types, cell_type_pairs, device): + + indices = torch.tensor([weights.row, weights.col], dtype=torch.long, device=device) + values = torch.tensor(weights.data, dtype=torch.float64, device=device) + shape = weights.shape + + cell_type_cats = cell_types.astype("category") + cell_type_codes = torch.tensor( + cell_type_cats.cat.codes.values, dtype=torch.long, device=device + ) + ct_name_to_code = {name: code for code, name in enumerate(cell_type_cats.cat.categories)} + + row_idx, col_idx = indices + sender_types = cell_type_codes[row_idx] + receiver_types = cell_type_codes[col_idx] + + weights_list = [] + coord_list = [] + + for i, (ct1, ct2) in enumerate(cell_type_pairs): + code1 = ct_name_to_code[ct1] + code2 = ct_name_to_code[ct2] + + pair_mask = (sender_types == code1) & (receiver_types == code2) + if pair_mask.sum() == 0: + continue + + pair_values = values[pair_mask] + pair_coords = torch.stack( + [ + torch.full((pair_values.shape[0],), i, dtype=torch.long, device=device), + row_idx[pair_mask], + col_idx[pair_mask], + ], + dim=0, + ) + + weights_list.append(pair_values) + coord_list.append(pair_coords) + + all_values = torch.cat(weights_list) + all_coords = torch.cat(coord_list, dim=1) + weights_ct_pairs = torch.sparse_coo_tensor( + all_coords, all_values, (len(cell_type_pairs), shape[0], shape[1]), device=device + ) + weights_ct_pairs = weights_ct_pairs.coalesce() + + return weights_ct_pairs + + +def select_significant_interactions( + adata: AnnData, + ct_aware: bool | None = False, + test: Literal["parametric"] | Literal["non-parametric"] | None = "parametric", + use_FDR: bool | None = True, + threshold: float | None = 0.05, +): + """ + Select significant gene pairs or metabolite-mediated interactions based on + FDR/p-value thresholds and (optionally) cell-type–aware tests. + + Parameters + ---------- + adata : AnnData + AnnData object containing: + - ``uns['ccc_results']`` or ``uns['ct_ccc_results']``, each of which includes: + * ``cell_com_df_gp``: DataFrame with statistics for gene pairs + * ``cell_com_df_m``: DataFrame with statistics for metabolites + ct_aware : bool, default False + If True, use cell-type–aware CCC results (``uns['ct_ccc_results']``). + If False, use cell-type-agnostic CCC results (``uns['ccc_results']``). + test : {"parametric", "non-parametric"}, default "parametric" + Determines which statistical columns to use: + - Parametric: ``Z_FDR`` / ``Z_pval``, ``C_p`` + - Non-parametric: ``FDR_np`` / ``pval_np``, ``C_np`` + use_FDR : bool, default True + If True, threshold significance using FDR values. + If False, use raw p-values. + threshold : float, default 0.05 + Significance cutoff applied to the selected statistic (FDR or p-value). + """ + ccc_key = "ct_ccc_results" if ct_aware else "ccc_results" + sig_key = "FDR" if use_FDR else "pval" + + if test == "parametric": + FDR_values_gp = adata.uns[ccc_key]["cell_com_df_gp"][f"Z_{sig_key}"].values + C_values_gp = adata.uns[ccc_key]["cell_com_df_gp"]["C_p"].values + FDR_values_m = adata.uns[ccc_key]["cell_com_df_m"][f"Z_{sig_key}"].values + C_values_m = adata.uns[ccc_key]["cell_com_df_m"]["C_p"].values + elif test == "non-parametric": + FDR_values_gp = adata.uns[ccc_key]["cell_com_df_gp"][f"{sig_key}_np"].values + C_values_gp = adata.uns[ccc_key]["cell_com_df_gp"]["C_np"].values + FDR_values_m = adata.uns[ccc_key]["cell_com_df_m"][f"{sig_key}_np"].values + C_values_m = adata.uns[ccc_key]["cell_com_df_m"]["C_np"].values + else: + raise ValueError('The "test" variable should be one of ["parametric", "non-parametric"].') + + # Gene pair + adata.uns[ccc_key]["cell_com_df_gp"]["selected"] = ( + (FDR_values_gp < threshold) & (C_values_gp > 0) + if test == "non-parametric" + else (FDR_values_gp < threshold) + ) + cell_com_df_gp = adata.uns[ccc_key]["cell_com_df_gp"] + adata.uns[ccc_key]["cell_com_df_gp_sig"] = cell_com_df_gp[ + cell_com_df_gp.selected == True + ].copy() + + # Metabolite + adata.uns[ccc_key]["cell_com_df_m"]["selected"] = ( + (FDR_values_m < threshold) & (C_values_m > 0) + if test == "non-parametric" + else (FDR_values_m < threshold) + ) + cell_com_df_m = adata.uns[ccc_key]["cell_com_df_m"] + adata.uns[ccc_key]["cell_com_df_m_sig"] = cell_com_df_m[cell_com_df_m.selected == True].copy() + + return + + +def compute_interacting_cell_scores( + adata: str | AnnData, + center_counts_for_np_test: bool | None = False, + test: Literal["parametric"] | Literal["non-parametric"] | Literal["both"] | None = "both", + restrict_significance: Literal["gene pairs"] + | Literal["metabolites"] + | Literal["both"] + | None = "both", + compute_significance: Literal["parametric"] + | Literal["non-parametric"] + | Literal["both"] + | None = "both", + M: int | None = 1000, + seed: int | None = 42, + check_analytic_null: bool | None = False, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Compute interacting cell scores for gene pairs and metabolites. + + Parameters + ---------- + adata : AnnData or str + AnnData object containing: + - ``uns['model']`` and ``uns['mean']`` (expression normalization model) + - ``uns['gene_pairs']``, ``uns['gene_pairs_per_metabolite']`` + - ``obsp['weights']``: sparse spatial weight matrix + - ``uns['ccc_results']`` (for significance filtering) + center_counts_for_np_test : bool, optional + If True, center/normalize counts prior to the non-parametric test. + test : {"parametric", "non-parametric", "both"} + Which interacting cell score tests to compute. + restrict_significance : {"gene pairs", "metabolites", "both"} + Use only significant gene pairs/metabolites from CCC results. + compute_significance : {"parametric", "non-parametric", "both"} + Whether to compute significance (p-values, FDR) in each test. + M : int, default 1000 + Number of permutations for the non-parametric test. + seed : int, default 42 + Random seed for permutation reproducibility. + check_analytic_null : bool, default False + If True, evaluate the analytic null distribution during permutations. + device : torch.device + CPU or GPU device for tensor operations. + verbose : bool, default False + Print status updates. + + Returns + ------- + None + Results are stored in ``adata.uns['interacting_cell_results']``. + """ + start = time.time() + if verbose: + print("Computing gene pair and metabolite scores...") + + adata.uns["interacting_cell_results"] = {} + + model = adata.uns["model"] + mean = adata.uns["mean"] + + if test not in ["both", "parametric", "non-parametric"]: + raise ValueError( + 'The "test" variable should be one of ["both", "parametric", "non-parametric"].' + ) + + if restrict_significance is not None and restrict_significance not in [ + "both", + "gene pairs", + "metabolites", + ]: + raise ValueError( + 'The "restrict_significance" variable should be one of ["both", "gene pairs", "metabolites"].' + ) + + if compute_significance is not None and compute_significance not in [ + "both", + "parametric", + "non-parametric", + ]: + raise ValueError( + 'The "compute_significance" variable should be one of ["both", "parametric", "non-parametric"].' + ) + + sample_specific = "sample_key" in adata.uns + + layer_key_p_test = adata.uns.get("layer_key_p_test", None) + layer_key_np_test = adata.uns.get("layer_key_np_test", None) + use_raw = (layer_key_p_test == "use_raw") and (layer_key_np_test == "use_raw") + + gene_pairs = adata.uns.get("gene_pairs", None) + gene_pairs_per_metabolite = adata.uns["gene_pairs_per_metabolite"] + + def to_tuple(x): + # Recursively convert lists to tuples + if isinstance(x, list): + return tuple(to_tuple(i) for i in x) + return x + + metabolite_gene_pair_df = pd.DataFrame.from_dict( + gene_pairs_per_metabolite, orient="index" + ).reset_index() + metabolite_gene_pair_df = metabolite_gene_pair_df.rename(columns={"index": "metabolite"}) + metabolite_gene_pair_df["gene_pair"] = metabolite_gene_pair_df["gene_pair"].apply( + lambda arr: [(to_tuple(gp[0]), to_tuple(gp[1])) for gp in arr] + ) + metabolite_gene_pair_df["gene_type"] = metabolite_gene_pair_df["gene_type"].apply( + lambda arr: [(to_tuple(gt[0]), to_tuple(gt[1])) for gt in arr] + ) + metabolite_gene_pair_df = pd.concat( + [ + metabolite_gene_pair_df["metabolite"], + metabolite_gene_pair_df.explode("gene_pair")["gene_pair"], + metabolite_gene_pair_df.explode("gene_type")["gene_type"], + ], + axis=1, + ).reset_index(drop=True) + + if "LR_database" in adata.uns: + LR_database = adata.uns["LR_database"] + df_merged = pd.merge( + metabolite_gene_pair_df, + LR_database, + left_on="metabolite", + right_on="interaction_name", + how="left", + ) + LR_df = df_merged.dropna(subset=["pathway_name"]) + metabolite_gene_pair_df["metabolite"][ + metabolite_gene_pair_df.metabolite.isin(LR_df.metabolite) + ] = LR_df["pathway_name"] + + if restrict_significance in ["both", "gene pairs"]: + cell_com_gp_df = adata.uns["ccc_results"]["cell_com_df_gp_sig"].copy() + cell_com_gp_df[["Gene 1", "Gene 2"]] = cell_com_gp_df[["Gene 1", "Gene 2"]].applymap( + lambda x: tuple(x) if isinstance(x, list) else x + ) + + gene_pairs_set = set([tuple(x) for x in cell_com_gp_df[["Gene 1", "Gene 2"]].values]) + metabolite_gene_pair_df = metabolite_gene_pair_df[ + metabolite_gene_pair_df["gene_pair"].isin(gene_pairs_set) + ] + + if restrict_significance in ["both", "metabolites"]: + cell_com_m_df = adata.uns["ccc_results"]["cell_com_df_m_sig"].copy() + metabolite_set = set(cell_com_m_df["Metabolite"].values) + metabolite_gene_pair_df = metabolite_gene_pair_df[ + metabolite_gene_pair_df["metabolite"].isin(metabolite_set) + ] + + genes = adata.uns["genes"] + gene_pairs_sig = [] + if gene_pairs: + for g1, g2 in gene_pairs: + g1 = tuple(g1) if isinstance(g1, list) else g1 + g2 = tuple(g2) if isinstance(g2, list) else g2 + if not metabolite_gene_pair_df[metabolite_gene_pair_df["gene_pair"] == (g1, g2)].empty: + gene_pairs_sig.append((g1, g2)) + + adata.uns["gene_pairs_sig"] = gene_pairs_sig + + gene_pairs_sig_ind = [] + for g1, g2 in gene_pairs_sig: + idx1 = tuple([genes.index(g) for g in g1]) if isinstance(g1, tuple) else genes.index(g1) + idx2 = tuple([genes.index(g) for g in g2]) if isinstance(g2, tuple) else genes.index(g2) + gene_pairs_sig_ind.append((idx1, idx2)) + + adata.uns["gene_pairs_sig_ind"] = gene_pairs_sig_ind + + if "barcode_key" in adata.uns: + barcode_key = adata.uns["barcode_key"] + cells = pd.Series(adata.obs[barcode_key].tolist()) + else: + cells = adata.obs_names if not use_raw else adata.raw.obs_names + + # Compute weights + weights = make_weights_non_redundant(adata.obsp["weights"]).tocoo() + weights = torch.sparse_coo_tensor( + torch.tensor(np.vstack((weights.row, weights.col)), dtype=torch.long, device=device), + torch.tensor(weights.data, dtype=torch.float64, device=device), + torch.Size(weights.shape), + device=device, + ) + + gene_pair_dict = {} + for metabolite, group in metabolite_gene_pair_df.groupby("metabolite"): + idxs = ( + group["gene_pair"] + .apply(lambda gp: gene_pairs_sig.index(gp) if gp in gene_pairs_sig else None) + .dropna() + .tolist() + ) + idxs = [int(ind) for ind in idxs if ind is not None] + if idxs: + gene_pair_dict[metabolite] = idxs + metabolites = list(gene_pair_dict.keys()) + + adata.uns["metabolites"] = metabolites + + gene_pairs_sig_names = [ + "_".join("_".join(g) if isinstance(g, tuple) else g for g in gp) for gp in gene_pairs_sig + ] + + adata.uns["gene_pairs_sig_names"] = gene_pairs_sig_names + + if test in ["parametric", "both"]: + if verbose: + print("Running the parametric test...") + + adata.uns["interacting_cell_results"]["p"] = {"gp": {}, "m": {}} + + Wtot2 = torch.tensor((weights.data**2).sum(), device=device) + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_sig_ind: + if isinstance(idx1, tuple): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, tuple): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + _lazy_import_hotspot() + counts_1 = standardize_counts(adata, counts_1, model, num_umi, sample_specific) + _lazy_import_hotspot() + counts_2 = standardize_counts(adata, counts_2, model, num_umi, sample_specific) + + # Compute CCC scores + WX2t = torch.sparse.mm(weights, counts_2.T) + WtX2t = torch.sparse.mm(weights.transpose(0, 1), counts_2.T) + cs_gp = (counts_1.T * WX2t) + (counts_1.T * WtX2t) + same_gene_mask = torch.tensor([g1 == g2 for g1, g2 in gene_pairs_sig], device=device) + cs_gp[:, same_gene_mask] = cs_gp[:, same_gene_mask] / 2 + adata.uns["interacting_cell_results"]["p"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + + # Compute metabolite-level scores + cs_m = compute_metabolite_cs(cs_gp, gene_pair_dict, interacting_cell_scores=True) + adata.uns["interacting_cell_results"]["p"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + if compute_significance in ["parametric", "both"]: + # Compute second moments + WX1t = torch.sparse.mm(weights, counts_1.T) + WtX1t = torch.sparse.mm(weights.transpose(0, 1), counts_1.T) + eg2_a = (WX1t + WtX1t).pow(2) + eg2_b = (WX2t + WtX2t).pow(2) + eg2s_gp = (eg2_a, eg2_b) + + Z_gp, Z_m = compute_p_int_cell_results_no_ct( + cs_gp, cs_m, gene_pairs_sig_ind, Wtot2, eg2s_gp, gene_pair_dict + ) + + Z_gp_np = Z_gp.detach().cpu().numpy() + Z_m_np = Z_m.detach().cpu().numpy() + # Compute p-values and FDRs + Z_pvals_gp = norm.sf(Z_gp_np) + Z_pvals_m = norm.sf(Z_m_np) + FDR_gp = multipletests(Z_pvals_gp.flatten(), method="fdr_bh")[1].reshape( + Z_pvals_gp.shape + ) + FDR_m = multipletests(Z_pvals_m.flatten(), method="fdr_bh")[1].reshape(Z_pvals_m.shape) + + adata.uns["interacting_cell_results"]["p"]["gp"]["Z"] = Z_gp_np + adata.uns["interacting_cell_results"]["p"]["gp"]["Z_pval"] = Z_pvals_gp + adata.uns["interacting_cell_results"]["p"]["gp"]["Z_FDR"] = FDR_gp + adata.uns["interacting_cell_results"]["p"]["m"]["Z"] = Z_m_np + adata.uns["interacting_cell_results"]["p"]["m"]["Z_pval"] = Z_pvals_m + adata.uns["interacting_cell_results"]["p"]["m"]["Z_FDR"] = FDR_m + + # P-value + mask_gp = adata.uns["interacting_cell_results"]["p"]["gp"]["Z_pval"] < 0.05 + mask_m = adata.uns["interacting_cell_results"]["p"]["m"]["Z_pval"] < 0.05 + + cs_gp_sig = adata.uns["interacting_cell_results"]["p"]["gp"]["cs"].copy() + cs_m_sig = adata.uns["interacting_cell_results"]["p"]["m"]["cs"].copy() + + cs_gp_sig[~mask_gp] = np.nan + cs_m_sig[~mask_m] = np.nan + adata.uns["interacting_cell_results"]["p"]["gp"]["cs_sig_pval"] = cs_gp_sig + adata.uns["interacting_cell_results"]["p"]["m"]["cs_sig_pval"] = cs_m_sig + + # FDR + mask_gp = adata.uns["interacting_cell_results"]["p"]["gp"]["Z_FDR"] < 0.05 + mask_m = adata.uns["interacting_cell_results"]["p"]["m"]["Z_FDR"] < 0.05 + + cs_gp_sig = adata.uns["interacting_cell_results"]["p"]["gp"]["cs"].copy() + cs_m_sig = adata.uns["interacting_cell_results"]["p"]["m"]["cs"].copy() + + cs_gp_sig[~mask_gp] = np.nan + cs_m_sig[~mask_m] = np.nan + adata.uns["interacting_cell_results"]["p"]["gp"]["cs_sig_FDR"] = cs_gp_sig + adata.uns["interacting_cell_results"]["p"]["m"]["cs_sig_FDR"] = cs_m_sig + + if verbose: + print("Parametric test finished.") + + if test in ["non-parametric", "both"]: + if verbose: + print("Running the non-parametric test...") + + adata.uns["interacting_cell_results"]["np"] = {"gp": {}, "m": {}} + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_sig_ind: + if isinstance(idx1, tuple): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, tuple): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + if center_counts_for_np_test: + num_umi = counts.sum(dim=0) + _lazy_import_hotspot() + counts_1 = standardize_counts(adata, counts_1, model, num_umi, sample_specific) + _lazy_import_hotspot() + counts_2 = standardize_counts(adata, counts_2, model, num_umi, sample_specific) + + n_cells = counts_1.shape[1] + same_gene_mask = torch.tensor([g1 == g2 for g1, g2 in gene_pairs_sig], device=device) + + if center_counts_for_np_test and test == "both": + adata.uns["interacting_cell_results"]["np"]["gp"]["cs"] = np.array( + adata.uns["interacting_cell_results"]["p"]["gp"]["cs"] + ) + adata.uns["interacting_cell_results"]["np"]["m"]["cs"] = np.array( + adata.uns["interacting_cell_results"]["p"]["m"]["cs"] + ) + else: + WX2t = torch.sparse.mm(weights, counts_2.T) + WtX2t = torch.sparse.mm(weights.transpose(0, 1), counts_2.T) + cs_gp = (counts_1.T * WX2t) + (counts_1.T * WtX2t) + cs_gp[:, same_gene_mask] = cs_gp[:, same_gene_mask] / 2 + adata.uns["interacting_cell_results"]["np"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + cs_m = compute_metabolite_cs(cs_gp, gene_pair_dict, interacting_cell_scores=True) + adata.uns["interacting_cell_results"]["np"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + if compute_significance in ["non-parametric", "both"]: + perm_cs_gp_a = torch.zeros( + (n_cells, counts_1.shape[0], M), dtype=torch.float64, device=device + ) + perm_cs_gp_b = torch.zeros_like(perm_cs_gp_a) + perm_cs_m_a = torch.zeros( + (n_cells, len(gene_pair_dict), M), dtype=torch.float64, device=device + ) + perm_cs_m_b = torch.zeros_like(perm_cs_m_a) + + if check_analytic_null: + gp_zs_perm_array = torch.zeros_like(perm_cs_gp_a) + gp_pvals_perm_array = torch.zeros_like(perm_cs_gp_a) + m_zs_perm_array = torch.zeros_like(perm_cs_m_a) + m_pvals_perm_array = torch.zeros_like(perm_cs_m_a) + + torch.manual_seed(seed) + for i in tqdm(range(M), desc="Permutation test"): + idx = torch.randperm(n_cells, device=device) + + c1_perm_a = counts_1.clone() + c2_perm_a = counts_2[:, idx] + c1_perm_a[same_gene_mask] = counts_1[same_gene_mask, :][:, idx] + + WX2t_a = torch.sparse.mm(weights, c2_perm_a.T) + WtX2t_a = torch.sparse.mm(weights.transpose(0, 1), c2_perm_a.T) + cs_a = (c1_perm_a.T * WX2t_a) + (c1_perm_a.T * WtX2t_a) + cs_a[:, same_gene_mask] = cs_a[:, same_gene_mask] / 2 + perm_cs_gp_a[:, :, i] = cs_a + + cs_m_a = compute_metabolite_cs(cs_a, gene_pair_dict, interacting_cell_scores=True) + perm_cs_m_a[:, :, i] = cs_m_a + + c2_perm_b = counts_2.clone() + c1_perm_b = counts_1[:, idx] + c2_perm_b[same_gene_mask] = counts_2[same_gene_mask, :][:, idx] + + WX2t_b = torch.sparse.mm(weights, c2_perm_b.T) + WtX2t_b = torch.sparse.mm(weights.transpose(0, 1), c2_perm_b.T) + cs_b = (c1_perm_b.T * WX2t_b) + (c1_perm_b.T * WtX2t_b) + cs_b[:, same_gene_mask] = cs_b[:, same_gene_mask] / 2 + perm_cs_gp_b[:, :, i] = cs_b + + cs_m_b = compute_metabolite_cs(cs_b, gene_pair_dict, interacting_cell_scores=True) + perm_cs_m_b[:, :, i] = cs_m_b + + if check_analytic_null: + Z_gp_perm, Z_m_perm = compute_p_results( + (cs_a, cs_b), + (cs_m_a, cs_m_b), + gene_pairs_ind, + Wtot2, + eg2s_gp, + gene_pair_dict, + ) + gp_zs_perm_array[:, :, i] = Z_gp_perm + gp_pvals_perm_array[:, :, i] = torch.tensor( + norm.sf(Z_gp_perm.cpu().numpy()), device=device + ) + m_zs_perm_array[:, :, i] = Z_m_perm + m_pvals_perm_array[:, :, i] = torch.tensor( + norm.sf(Z_m_perm.cpu().numpy()), device=device + ) + + adata.uns["interacting_cell_results"]["np"]["gp"]["perm_cs_a"] = ( + perm_cs_gp_a.detach().cpu().numpy() + ) + adata.uns["interacting_cell_results"]["np"]["gp"]["perm_cs_b"] = ( + perm_cs_gp_b.detach().cpu().numpy() + ) + adata.uns["interacting_cell_results"]["np"]["m"]["perm_cs_a"] = ( + perm_cs_m_a.detach().cpu().numpy() + ) + adata.uns["interacting_cell_results"]["np"]["m"]["perm_cs_b"] = ( + perm_cs_m_b.detach().cpu().numpy() + ) + + x_gp_a = (perm_cs_gp_a > cs_gp[:, :, None]).sum(dim=2) + x_gp_b = (perm_cs_gp_b > cs_gp[:, :, None]).sum(dim=2) + x_m_a = (perm_cs_m_a > cs_m[:, :, None]).sum(dim=2) + x_m_b = (perm_cs_m_b > cs_m[:, :, None]).sum(dim=2) + + pvals_gp_a = (x_gp_a + 1).float() / (M + 1) + pvals_gp_b = (x_gp_b + 1).float() / (M + 1) + pvals_m_a = (x_m_a + 1).float() / (M + 1) + pvals_m_b = (x_m_b + 1).float() / (M + 1) + + pvals_gp = torch.where(pvals_gp_a > pvals_gp_b, pvals_gp_a, pvals_gp_b) + pvals_m = torch.where(pvals_m_a > pvals_m_b, pvals_m_a, pvals_m_b) + + pvals_gp = pvals_gp.cpu().numpy() + pvals_m = pvals_m.cpu().numpy() + + adata.uns["interacting_cell_results"]["np"]["gp"]["pval"] = pvals_gp + adata.uns["interacting_cell_results"]["np"]["gp"]["FDR"] = multipletests( + pvals_gp.flatten(), method="fdr_bh" + )[1].reshape(pvals_gp.shape) + adata.uns["interacting_cell_results"]["np"]["m"]["pval"] = pvals_m + adata.uns["interacting_cell_results"]["np"]["m"]["FDR"] = multipletests( + pvals_m.flatten(), method="fdr_bh" + )[1].reshape(pvals_m.shape) + + if check_analytic_null: + adata.uns["interacting_cell_results"]["np"]["analytic_null"] = { + "gp_zs_perm": gp_zs_perm_array.detach().cpu().numpy(), + "gp_pvals_perm": gp_pvals_perm_array.detach().cpu().numpy(), + "m_zs_perm": m_zs_perm_array.detach().cpu().numpy(), + "m_pvals_perm": m_pvals_perm_array.detach().cpu().numpy(), + } + + # P-value + mask_gp = adata.uns["interacting_cell_results"]["np"]["gp"]["pval"] < 0.05 + mask_m = adata.uns["interacting_cell_results"]["np"]["m"]["pval"] < 0.05 + + cs_gp_sig = adata.uns["interacting_cell_results"]["np"]["gp"]["cs"].copy() + cs_m_sig = adata.uns["interacting_cell_results"]["np"]["m"]["cs"].copy() + + cs_gp_sig[~mask_gp] = np.nan + cs_m_sig[~mask_m] = np.nan + adata.uns["interacting_cell_results"]["np"]["gp"]["cs_sig_pval"] = cs_gp_sig + adata.uns["interacting_cell_results"]["np"]["m"]["cs_sig_pval"] = cs_m_sig + + # FDR + mask_gp = adata.uns["interacting_cell_results"]["np"]["gp"]["FDR"] < 0.05 + mask_m = adata.uns["interacting_cell_results"]["np"]["m"]["FDR"] < 0.05 + + cs_gp_sig = adata.uns["interacting_cell_results"]["np"]["gp"]["cs"].copy() + cs_m_sig = adata.uns["interacting_cell_results"]["np"]["m"]["cs"].copy() + + cs_gp_sig[~mask_gp] = np.nan + cs_m_sig[~mask_m] = np.nan + adata.uns["interacting_cell_results"]["np"]["gp"]["cs_sig_FDR"] = cs_gp_sig + adata.uns["interacting_cell_results"]["np"]["m"]["cs_sig_FDR"] = cs_m_sig + + if verbose: + print("Non-parametric test finished.") + + if verbose: + print( + "Finished computing gene pair and metabolite scores in %.3f seconds" + % (time.time() - start) + ) + + return + + +def compute_ct_interacting_cell_scores( + adata: str | AnnData, + center_counts_for_np_test: bool | None = False, + test: Literal["parametric"] | Literal["non-parametric"] | Literal["both"] | None = "both", + restrict_significance: Literal["gene pairs"] + | Literal["metabolites"] + | Literal["both"] + | None = "both", + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + verbose: bool | None = False, +): + """ + Compute cell-type–aware interacting cell scores for gene pairs and metabolites. + + Parameters + ---------- + adata : AnnData or str + Must contain: + - ``uns['model']``, ``uns['mean']`` + - ``uns['cell_type_key']`` and ``obs[cell_type_key]`` for cell types + - ``uns['gene_pairs']``, ``uns['gene_pairs_per_ct_pair']`` + - ``uns['gene_pairs_per_metabolite']`` + - ``uns['ct_ccc_results']`` with significance information + - ``obsp['weights']`` (spatial proximity matrix) + center_counts_for_np_test : bool, default False + Whether to standardize counts before the non-parametric test. + test : {"parametric", "non-parametric", "both"} + Which statistical test(s) to run. + restrict_significance : {"gene pairs", "metabolites", "both"} + Only use cell-type-pair interactions that were significant in cell-type-aware CCC results. + device : torch.device + CPU or GPU device for PyTorch computations. + verbose : bool, default False + Print detailed progress messages. + """ + start = time.time() + if verbose: + print("Computing cell type-aware gene pair and metabolite scores...") + + adata.uns["ct_interacting_cell_results"] = {} + + model = adata.uns["model"] + mean = adata.uns["mean"] + + if test not in ["both", "parametric", "non-parametric"]: + raise ValueError( + 'The "test" variable should be one of ["both", "parametric", "non-parametric"].' + ) + + if restrict_significance not in ["both", "gene pairs", "metabolites"]: + raise ValueError( + 'The "restrict_significance" variable should be one of ["both", "gene pairs", "metabolites"].' + ) + + sample_specific = "sample_key" in adata.uns + + layer_key_p_test = adata.uns.get("layer_key_p_test", None) + layer_key_np_test = adata.uns.get("layer_key_np_test", None) + use_raw = (layer_key_p_test == "use_raw") and (layer_key_np_test == "use_raw") + + obs = adata.raw.obs if use_raw else adata.obs + cells = ( + adata.raw.obs.index.values.astype(str) if use_raw else adata.obs_names.values.astype(str) + ) + + gene_pairs = adata.uns.get("gene_pairs", None) + gene_pairs_per_ct_pair = adata.uns.get("gene_pairs_per_ct_pair", None) + gene_pairs_per_metabolite = adata.uns["gene_pairs_per_metabolite"] + + gp_metab = adata.uns["gene_pairs_per_metabolite"] + metabolite_gene_pair_df = ( + pd.DataFrame.from_dict(gp_metab, orient="index") + .rename_axis("metabolite") + .explode(["gene_pair", "gene_type"]) + .reset_index() + ) + + if "LR_database" in adata.uns: + merged = metabolite_gene_pair_df.merge( + adata.uns["LR_database"], left_on="metabolite", right_on="interaction_name", how="left" + ) + LR_df = merged.dropna(subset=["pathway_name"]) + metabolite_gene_pair_df.loc[ + metabolite_gene_pair_df.metabolite.isin(LR_df.metabolite), "metabolite" + ] = LR_df["pathway_name"].values + + cell_type_pairs = adata.uns.get("cell_type_pairs") + cell_type_pairs = [tuple(x) for x in cell_type_pairs] + + cell_com_gp_df = adata.uns["ct_ccc_results"]["cell_com_df_gp_sig"].copy() + if restrict_significance in ["both", "gene pairs"]: + ct_pairs_gp_set = set( + [tuple(x) for x in cell_com_gp_df[["Cell Type 1", "Cell Type 2"]].values] + ) + cell_type_pairs = [ct_pair for ct_pair in cell_type_pairs if ct_pair in ct_pairs_gp_set] + + cell_com_gp_df[["Gene 1", "Gene 2"]] = cell_com_gp_df[["Gene 1", "Gene 2"]].applymap( + lambda x: tuple(x) if isinstance(x, list) else x + ) + + gene_pairs_set = set([tuple(x) for x in cell_com_gp_df[["Gene 1", "Gene 2"]].values]) + metabolite_gene_pair_df = metabolite_gene_pair_df[ + metabolite_gene_pair_df["gene_pair"].isin(gene_pairs_set) + ] + + cell_com_m_df = adata.uns["ct_ccc_results"]["cell_com_df_m_sig"].copy() + if restrict_significance in ["both", "metabolites"]: + ct_pairs_m_set = set( + [tuple(x) for x in cell_com_m_df[["Cell Type 1", "Cell Type 2"]].values] + ) + missing_ct_pairs = [ + ct_pair for ct_pair in ct_pairs_m_set if ct_pair not in cell_type_pairs + ] + if len(missing_ct_pairs) > 0: + warnings.warn( + f'The following cell type pairs are not included in the "cell_type_pairs" set: {missing_ct_pairs}' + ) + + metabolite_set = set(cell_com_m_df["metabolite"].values) + metabolite_gene_pair_df = metabolite_gene_pair_df[ + metabolite_gene_pair_df["metabolite"].isin(metabolite_set) + ] + + if metabolite_gene_pair_df.empty: + if restrict_significance == "both": + raise ValueError( + "There are no significant gene pairs that belong to a significant metabolite." + ) + if restrict_significance == "gene pairs": + raise ValueError("There are no significant gene pairs.") + if restrict_significance == "metabolites": + raise ValueError("There are no significant metabolites.") + + genes = adata.uns["genes"] + gene_pairs_sig = [] + if gene_pairs: + for g1, g2 in gene_pairs: + g1 = tuple(g1) if isinstance(g1, list) else g1 + g2 = tuple(g2) if isinstance(g2, list) else g2 + if not metabolite_gene_pair_df[metabolite_gene_pair_df["gene_pair"] == (g1, g2)].empty: + gene_pairs_sig.append((g1, g2)) + + gene_pairs_sig_ind = [] + for pair in gene_pairs_sig: + idx1 = ( + [genes.index(g) for g in pair[0]] + if isinstance(pair[0], list) + else genes.index(pair[0]) + ) + idx2 = ( + [genes.index(g) for g in pair[1]] + if isinstance(pair[1], list) + else genes.index(pair[1]) + ) + gene_pairs_sig_ind.append((idx1, idx2)) + + cell_type_key = adata.uns.get("cell_type_key") + cell_types = obs[cell_type_key] + gene_pairs_per_ct_pair = adata.uns.get("gene_pairs_per_ct_pair", {}) + + weights = adata.obsp["weights"] + + used_ct_pairs = list(set(ct for cell_type_pair in cell_type_pairs for ct in cell_type_pair)) + all_cell_types = set(cell_types.unique()) + used_ct_pairs_set = set(used_ct_pairs) + if used_ct_pairs_set < all_cell_types: + keep_mask = cell_types[cells].isin(used_ct_pairs).values + keep_indices = np.where(keep_mask)[0] + weights = weights[keep_indices][:, keep_indices] + cells = cells[keep_indices] + cell_types = cell_types.loc[ + cells + ] # Eventually only keep the cell type pairs with at least one significant gene pair + + weights_ct_pairs = create_weights_ct_pairs( + weights.tocoo(), cell_types, cell_type_pairs, device + ) + + gene_pairs_per_ct_pair_sig = {} + for ct_pair in gene_pairs_per_ct_pair.keys(): + if ct_pair not in cell_type_pairs: + continue + cell_com_df_ct_pair = cell_com_gp_df[ + (cell_com_gp_df["Cell Type 1"] == ct_pair[0]) + & (cell_com_gp_df["Cell Type 2"] == ct_pair[1]) + ] + gene_pairs_per_ct_pair_sig[ct_pair] = [ + tuple(x) for x in cell_com_df_ct_pair[["Gene 1", "Gene 2"]].values + ] + + # Cell-type pair-specific indices + gene_pairs_ind_per_ct_pair_sig = defaultdict(list) + gene_pairs_per_ct_pair_sig_ind = defaultdict(list) + for ct_pair, gpairs in gene_pairs_per_ct_pair_sig.items(): + for pair in gpairs: + if pair not in gene_pairs_sig: + continue + idx = gene_pairs_sig.index(pair) + gene_pairs_ind_per_ct_pair_sig[ct_pair].append(gene_pairs_sig_ind[idx]) + gene_pairs_per_ct_pair_sig_ind[ct_pair].append(idx) + + adata.uns["gene_pairs_ind_per_ct_pair_sig"] = dict(gene_pairs_ind_per_ct_pair_sig) + adata.uns["gene_pairs_per_ct_pair_sig_ind"] = dict(gene_pairs_per_ct_pair_sig_ind) + + def make_hashable(pair): + return tuple(tuple(x) if isinstance(x, list) else x for x in pair) + + gene_pairs_sig_ind_set = {make_hashable(pair) for pair in gene_pairs_sig_ind} + ct_specific_gene_pairs = [ + i + for i, pairs in enumerate(gene_pairs_ind_per_ct_pair_sig.values()) + if {make_hashable(pair) for pair in pairs} < gene_pairs_sig_ind_set + ] + + gene_pair_dict = {} + for metabolite, group in metabolite_gene_pair_df.groupby("metabolite"): + idxs = ( + group["gene_pair"] + .apply(lambda gp: gene_pairs_sig.index(gp) if gp in gene_pairs_sig else None) + .dropna() + .tolist() + ) + idxs = [int(ind) for ind in idxs if ind is not None] + if idxs: + gene_pair_dict[metabolite] = idxs + metabolites = list(gene_pair_dict.keys()) + + gene_pair_to_metabolite_indices = defaultdict(set) + for met_idx, met in enumerate(metabolites): + for gene_pair_idx in gene_pair_dict.get(met, []): + gene_pair_to_metabolite_indices[gene_pair_idx].add(met_idx) + + ct_pair_to_metabolite_indices = {} + for ct_pair, gene_pair_indices in gene_pairs_per_ct_pair_sig_ind.items(): + met_indices = set() + for gi in gene_pair_indices: + met_indices.update(gene_pair_to_metabolite_indices.get(gi, [])) + ct_pair_to_metabolite_indices[ct_pair] = sorted(met_indices) + + def concat_tuple_elements(t, sep="_"): + return sep.join( + s for item in t for s in (item if isinstance(item, (list, tuple)) else [item]) + ) + + if test in ["parametric", "both"]: + if verbose: + print("Running the parametric test...") + + adata.uns["ct_interacting_cell_results"]["p"] = {"gp": {}, "m": {}} + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_sig_ind: + if isinstance(idx1, list): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, list): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + counts_1 = standardize_ct_counts( + adata[cells, :], counts_1, model, num_umi, sample_specific, cell_types + ) + counts_2 = standardize_ct_counts( + adata[cells, :], counts_2, model, num_umi, sample_specific, cell_types + ) + + cs_gp = torch.zeros( + (len(cell_type_pairs), counts_1.shape[1], counts_1.shape[0]), device=counts_1.device + ) + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + WX2t = torch.sparse.mm(W, counts_2.T) + cs_gp[ct_pair] = counts_1.T * WX2t + adata.uns["ct_interacting_cell_results"]["p"]["gp"]["cs"] = cs_gp.detach().cpu().numpy() + + cs_m = compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_sig_ind, + ct_specific_gene_pairs, + interacting_cell_scores=True, + ) + adata.uns["ct_interacting_cell_results"]["p"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + column_names = [] + scores = [] + for i, ct_pair in enumerate(gene_pairs_per_ct_pair_sig_ind.keys()): + gp_list = gene_pairs_per_ct_pair_sig_ind[ct_pair] + for gp in gp_list: + if np.all( + adata.uns["ct_interacting_cell_results"]["p"]["gp"]["cs"][i, :, gp] == 0 + ): + continue + column_names.append( + f"{' - '.join(ct_pair)}: {concat_tuple_elements(gene_pairs_sig[gp])}" + ) + scores.append(adata.uns["ct_interacting_cell_results"]["p"]["gp"]["cs"][i, :, gp]) + + cs_gp_df = pd.DataFrame( + {column_names[i]: array for i, array in enumerate(scores)}, index=cells + ) + if used_ct_pairs_set < all_cell_types: + cs_gp_df = cs_gp_df.reindex(adata.obs_names, fill_value=0) + adata.obsm["ct_interacting_cell_results_p_gp_cs_df"] = cs_gp_df + + column_names = [] + scores = [] + for i, ct_pair in enumerate(ct_pair_to_metabolite_indices.keys()): + metab_list = ct_pair_to_metabolite_indices[ct_pair] + for metab in metab_list: + if np.all( + adata.uns["ct_interacting_cell_results"]["p"]["m"]["cs"][i, :, metab] == 0 + ): + continue + column_names.append(f"{' - '.join(ct_pair)}: {metabolites[metab]}") + scores.append( + adata.uns["ct_interacting_cell_results"]["p"]["m"]["cs"][i, :, metab] + ) + + cs_m_df = pd.DataFrame( + {column_names[i]: array for i, array in enumerate(scores)}, index=cells + ) + if used_ct_pairs_set < all_cell_types: + cs_m_df = cs_m_df.reindex(adata.obs_names, fill_value=0) + adata.obsm["ct_interacting_cell_results_p_m_cs_df"] = cs_m_df + + if verbose: + print("Parametric test finished.") + + if test in ["non-parametric", "both"]: + if verbose: + print("Running the non-parametric test...") + + adata.uns["ct_interacting_cell_results"]["np"] = {"gp": {}, "m": {}} + + # Load counts + counts = counts_from_anndata(adata[cells, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + + # Prepare counts_1 and counts_2 + counts_1 = [] + counts_2 = [] + for idx1, idx2 in gene_pairs_sig_ind: + if isinstance(idx1, tuple): + c1 = ( + counts[idx1, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx1, :] + 1e-8).mean(dim=0)) + ) + else: + c1 = counts[idx1, :] + if isinstance(idx2, tuple): + c2 = ( + counts[idx2, :].mean(dim=0) + if mean == "algebraic" + else torch.exp(torch.log(counts[idx2, :] + 1e-8).mean(dim=0)) + ) + else: + c2 = counts[idx2, :] + counts_1.append(c1) + counts_2.append(c2) + + counts_1 = torch.stack(counts_1) + counts_2 = torch.stack(counts_2) + + if center_counts_for_np_test: + num_umi = counts.sum(dim=0) + _lazy_import_hotspot() + counts_1 = standardize_counts(adata, counts_1, model, num_umi, sample_specific) + _lazy_import_hotspot() + counts_2 = standardize_counts(adata, counts_2, model, num_umi, sample_specific) + + n_cells = counts_1.shape[1] + same_gene_mask = torch.tensor([g1 == g2 for g1, g2 in gene_pairs_sig], device=device) + + if center_counts_for_np_test and test == "both": + adata.uns["ct_interacting_cell_results"]["np"]["gp"]["cs"] = np.array( + adata.uns["ct_interacting_cell_results"]["p"]["gp"]["cs"] + ) + adata.uns["ct_interacting_cell_results"]["np"]["m"]["cs"] = np.array( + adata.uns["ct_interacting_cell_results"]["p"]["m"]["cs"] + ) + else: + cs_gp = torch.zeros( + (len(cell_type_pairs), counts_1.shape[1], counts_1.shape[0]), + device=counts_1.device, + ) + for ct_pair in range(len(cell_type_pairs)): + W = weights_ct_pairs[ct_pair].coalesce() + WX2t = torch.sparse.mm(W, counts_2.T) + cs_gp[ct_pair] = counts_1.T * WX2t + adata.uns["ct_interacting_cell_results"]["np"]["gp"]["cs"] = ( + cs_gp.detach().cpu().numpy() + ) + cs_m = compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_sig_ind, + ct_specific_gene_pairs, + interacting_cell_scores=True, + ) + adata.uns["ct_interacting_cell_results"]["np"]["m"]["cs"] = cs_m.detach().cpu().numpy() + + column_names = [] + scores = [] + for i, ct_pair in enumerate(gene_pairs_per_ct_pair_sig_ind.keys()): + gp_list = gene_pairs_per_ct_pair_sig_ind[ct_pair] + for gp in gp_list: + if np.all( + adata.uns["ct_interacting_cell_results"]["np"]["gp"]["cs"][i, :, gp] == 0 + ): + continue + column_names.append( + f"{' - '.join(ct_pair)}: {concat_tuple_elements(gene_pairs_sig[gp])}" + ) + scores.append(adata.uns["ct_interacting_cell_results"]["np"]["gp"]["cs"][i, :, gp]) + + cs_gp_df = pd.DataFrame( + {column_names[i]: array for i, array in enumerate(scores)}, index=cells + ) + if used_ct_pairs_set < all_cell_types: + cs_gp_df = cs_gp_df.reindex(adata.obs_names, fill_value=0) + adata.obsm["ct_interacting_cell_results_np_gp_cs_df"] = cs_gp_df + + column_names = [] + scores = [] + for i, ct_pair in enumerate(ct_pair_to_metabolite_indices.keys()): + metab_list = ct_pair_to_metabolite_indices[ct_pair] + for metab in metab_list: + if np.all( + adata.uns["ct_interacting_cell_results"]["np"]["m"]["cs"][i, :, metab] == 0 + ): + continue + column_names.append(f"{' - '.join(ct_pair)}: {metabolites[metab]}") + scores.append( + adata.uns["ct_interacting_cell_results"]["np"]["m"]["cs"][i, :, metab] + ) + + cs_m_df = pd.DataFrame( + {column_names[i]: array for i, array in enumerate(scores)}, index=cells + ) + if used_ct_pairs_set < all_cell_types: + cs_m_df = cs_m_df.reindex(adata.obs_names, fill_value=0) + adata.obsm["ct_interacting_cell_results_np_m_cs_df"] = cs_m_df + + if verbose: + print("Non-parametric test finished.") + + if verbose: + print( + "Finished computing cell type-aware gene pair and metabolite scores in %.3f seconds" + % (time.time() - start) + ) + + return + + +def compute_metabolite_cs( + cs_gp: torch.Tensor, gene_pair_dict: dict, interacting_cell_scores: bool = False +) -> torch.Tensor: + """ + Computes metabolite-level communication scores from gene-pair scores. + + Parameters + ---------- + cs_gp : torch.Tensor + - If interacting_cell_scores is False: shape (gene_pairs,) + - If interacting_cell_scores is True: shape (cells, gene_pairs) + gene_pair_dict : dict + Maps metabolite names to a list of indices (ints) referring to gene-pairs. + interacting_cell_scores : bool, optional + Whether cs_gp contains per-cell scores. + + Returns + ------- + cs_m : torch.Tensor + - If interacting_cell_scores is False: shape (num_metabolites,) + - If interacting_cell_scores is True: shape (cells, num_metabolites) + """ + device = cs_gp.device + scores = [] + + for indices in gene_pair_dict.values(): + idx_tensor = torch.tensor(indices, device=device, dtype=torch.long) + if interacting_cell_scores: + summed = cs_gp[:, idx_tensor].sum(dim=1) # shape: (cells,) + else: + summed = cs_gp[idx_tensor].sum() # scalar + scores.append(summed) + + if interacting_cell_scores: + cs_m = torch.stack(scores, dim=1) # shape: (cells, metabolites) + else: + cs_m = torch.stack(scores) # shape: (metabolites,) + + return cs_m + + +def compute_metabolite_cs_ct( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind=None, + ct_specific_gene_pairs=None, + interacting_cell_scores=False, +): + if cell_type_key and ct_specific_gene_pairs: + for i, ct_pair in enumerate(gene_pairs_per_ct_pair_ind.keys()): + if i not in ct_specific_gene_pairs: + continue + mask_dim = 2 if interacting_cell_scores else 1 + mask = np.ones(cs_gp.shape[mask_dim], dtype=bool) + mask[gene_pairs_per_ct_pair_ind[ct_pair]] = False + if interacting_cell_scores: + cs_gp[i, :, mask] = 0 + else: + cs_gp[i, mask] = 0 + + device = cs_gp.device + scores = [] + + for indices in gene_pair_dict.values(): + idx_tensor = torch.tensor(indices, device=device, dtype=torch.long) + if interacting_cell_scores: + summed = cs_gp[:, :, idx_tensor].sum(dim=2) # shape: (cells,) + else: + summed = cs_gp[:, idx_tensor].sum(dim=1) # scalar + scores.append(summed) + + if interacting_cell_scores: + cs_m = torch.stack(scores, dim=2) # shape: (cells, metabolites) + else: + cs_m = torch.stack(scores, dim=1) + + return cs_m + + +def compute_metabolite_cs_old( + cs_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind=None, + ct_specific_gene_pairs=None, + interacting_cell_scores=False, +): + if cell_type_key and ct_specific_gene_pairs: + for i, ct_pair in enumerate(gene_pairs_per_ct_pair_ind.keys()): + if i not in ct_specific_gene_pairs: + continue + mask = np.ones(cs_gp.shape[1], dtype=bool) + mask[gene_pairs_per_ct_pair_ind[ct_pair]] = False + cs_gp[i, mask] = 0 + + cells_metabolites = [] + for metabolite, gene_pair_indices in gene_pair_dict.items(): + if interacting_cell_scores: + summed_values = ( + cs_gp[:, :, gene_pair_indices].sum(axis=2) + if cell_type_key + else cs_gp[:, gene_pair_indices].sum(axis=1) + ) + cells_metabolites.append(summed_values) + else: + summed_values = ( + cs_gp[:, gene_pair_indices].sum(axis=1) + if cell_type_key + else cs_gp[gene_pair_indices].sum(axis=0) + ) + cells_metabolites.append(summed_values) + if interacting_cell_scores: + axis = 2 if cell_type_key else 1 + else: + axis = 1 if cell_type_key else 0 + cs_m = np.stack(cells_metabolites, axis=axis) + + return cs_m + + +def ensure_tuple(x): + return tuple(tuple(i) if isinstance(i, list) else i for i in x) + + +def compute_CCC_scores( + counts_1: np.array, + counts_2: np.array, + weights: sparse.COO, + gene_pairs: list, +): + + if len(weights.shape) == 3: + scores = (counts_1.T * np.tensordot(weights, counts_2.T, axes=([2], [0]))).sum(axis=1) + else: + same_gene_mask = np.array([pair1 == pair2 for pair1, pair2 in gene_pairs]) + scores = (counts_1.T * (weights @ counts_2.T)).sum(axis=0) + ( + counts_1.T * (weights.T @ counts_2.T) + ).sum(axis=0) + scores[same_gene_mask] = scores[same_gene_mask] / 2 + + return scores + + +def compute_int_CCC_scores( + counts_1: np.array, + counts_2: np.array, + weights: sparse.COO, + gene_pairs: list, +): + + if len(weights.shape) == 3: + scores = counts_1.T * np.tensordot(weights, counts_2.T, axes=([2], [0])) + else: + same_gene_mask = np.array([pair1 == pair2 for pair1, pair2 in gene_pairs]) + scores = (counts_1.T * (weights @ counts_2.T)) + (counts_1.T * (weights.T @ counts_2.T)) + scores[:, same_gene_mask] = scores[:, same_gene_mask] / 2 + + return scores + + +def get_ct_pair_weights(weights, cell_type_pairs, cell_types): + + w_nrow, w_ncol = weights.shape + n_ct_pairs = len(cell_type_pairs) + + extract_weights_results = partial( + extract_ct_pair_weights, + weights=weights, + cell_type_pairs=cell_type_pairs, + cell_types=cell_types, + ) + results = list(map(extract_weights_results, cell_type_pairs)) + + w_new_data_all = [x[0] for x in results] + w_new_coords_3d_all = [x[1] for x in results] + w_new_coords_3d_all = np.hstack(w_new_coords_3d_all) + w_new_data_all = np.concatenate(w_new_data_all) + + weights_ct_pairs = sparse.COO( + w_new_coords_3d_all, w_new_data_all, shape=(n_ct_pairs, w_nrow, w_ncol) + ) + + return weights_ct_pairs + + +def extract_ct_pair_weights(ct_pair, weights, cell_type_pairs, cell_types): + + i = cell_type_pairs.index(ct_pair) + + ct_t, ct_u = cell_type_pairs[i] + ct_t_mask = cell_types.values == ct_t + ct_t_mask_coords = np.argwhere(ct_t_mask) + n_ct_t = len(ct_t_mask_coords) + ct_u_mask = cell_types.values == ct_u + ct_u_mask_coords = np.argwhere(ct_u_mask) + n_ct_u = len(ct_u_mask_coords) + + w_old_coords = weights.coords + + w_row_coords, w_col_coords = np.meshgrid(ct_t_mask_coords, ct_u_mask_coords, indexing="ij") + w_row_coords = w_row_coords.ravel() + w_col_coords = w_col_coords.ravel() + w_new_coords = np.vstack((w_row_coords, w_col_coords)) + + # w_matching_indices = np.where(np.all(np.isin(w_old_coords.T, w_new_coords.T), axis=1))[0] + w_matching_indices = np.isin(w_old_coords[0], ct_t_mask_coords.flatten()) & np.isin( + w_old_coords[1], ct_u_mask_coords.flatten() + ) + w_new_data = weights.data[w_matching_indices] + w_new_coords = w_old_coords[:, w_matching_indices] + + w_coord_3d = np.full(w_new_coords.shape[1], fill_value=i) + w_new_coords_3d = np.vstack((w_coord_3d, w_new_coords)) + + return (w_new_data, w_new_coords_3d) + + +def get_interacting_cell_type_pairs(x, weights, cell_types): + ct_1, ct_2 = x + + ct_1_bin = cell_types == ct_1 + ct_2_bin = cell_types == ct_2 + + weights = weights.tocsc() + cell_types_weights = weights[ct_1_bin,][:, ct_2_bin] + + return bool(cell_types_weights.nnz) + + +def conditional_eg2_cellcom_gp(counts_1, counts_2, weights): + + if len(weights.shape) == 3: + counts_1_sq = counts_1**2 + counts_2_sq = counts_2**2 + weights_sq_data = weights.data**2 + weights_sq = sparse.COO(weights.coords, weights_sq_data, shape=weights.shape) + out_eg2_a = np.tensordot(counts_1_sq, weights_sq, axes=([1], [1])).sum(axis=2).T + out_eg2_b = np.tensordot(weights_sq, counts_2_sq, axes=([2], [1])).sum(axis=1) + out_eg2s = (out_eg2_a, out_eg2_b) + else: + out_eg2_a = (((weights + weights.T) @ counts_1.T) ** 2).sum(axis=0) + out_eg2_b = (((weights + weights.T) @ counts_2.T) ** 2).sum(axis=0) + out_eg2s = (out_eg2_a, out_eg2_b) + + return out_eg2s + + +def conditional_eg2_gp_score(counts, weights): + + counts_sq = counts**2 + if len(weights.shape) == 3: + # weights_t = weights.transpose(axes=(0, 2, 1)) + # weights = weights + weights_t + weights_sq_data = weights.data**2 + weights_sq = sparse.COO(weights.coords, weights_sq_data, shape=weights.shape) + out_eg2_a = np.transpose(np.tensordot(counts_sq, weights_sq, axes=([1], [1])), (1, 2, 0)) + out_eg2_b = np.tensordot(weights_sq, counts_sq, axes=([2], [1])) + out_eg2s = (out_eg2_a, out_eg2_b) + else: + out_eg2s = ((weights + weights.T) @ counts.T) ** 2 + + return out_eg2s + + +def compute_ct_p_results( + C_gp, + C_m, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + EG2_gp, + cell_type_key, + gene_pair_dict, +): + + EG2_gp = EG2_gp.unsqueeze(1).expand(-1, C_gp.shape[1]) if len(EG2_gp.shape) == 1 else EG2_gp + + stdG = torch.sqrt(EG2_gp) + stdG[stdG == 0] = 1 + + Z_gp = C_gp / stdG + + EG2_m = compute_metabolite_cs_ct( + EG2_gp, + cell_type_key, + gene_pair_dict, + gene_pairs_per_ct_pair_ind, + ct_specific_gene_pairs, + interacting_cell_scores=False, + ) + if not isinstance(EG2_m, torch.Tensor): + device = EG2_gp.device + EG2_m = torch.tensor(EG2_m, device=device, dtype=torch.float64) + + stdG_m = torch.sqrt(EG2_m) + stdG_m[stdG_m == 0] = 1 + + Z_m = C_m / stdG_m + + return Z_gp, Z_m + + +def compute_p_results(C_gp, C_m, gene_pairs_ind, Wtot2, eg2s_gp, gene_pair_dict): + + device = Wtot2.device + n_gp = len(gene_pairs_ind) + + # Convert indices + same_gene_mask = torch.tensor( + [ + (isinstance(g1, int) and isinstance(g2, int) and g1 == g2) + or (isinstance(g1, list) and isinstance(g2, list) and sorted(g1) == sorted(g2)) + for g1, g2 in gene_pairs_ind + ], + device=device, + ) + + # Unpack second moments + EG2_a = eg2s_gp[0].clone() + EG2_b = eg2s_gp[1].clone() + EG2_a[same_gene_mask] = Wtot2 + EG2_b[same_gene_mask] = Wtot2 + + stdG_a = torch.sqrt(EG2_a) + stdG_b = torch.sqrt(EG2_b) + stdG_a[stdG_a == 0] = 1 + stdG_b[stdG_b == 0] = 1 + + # Compute gene-pair Z-scores + if isinstance(C_gp, tuple): + C_gp_0, C_gp_1 = C_gp + z_0 = C_gp_0 / stdG_a + z_1 = C_gp_1 / stdG_b + mask = torch.abs(z_0) < torch.abs(z_1) + Z_gp = torch.where(mask, z_0, z_1) + EG2_gp = torch.where(mask, EG2_a, EG2_b) + else: + C_gp = C_gp + z_a = C_gp / stdG_a + z_b = C_gp / stdG_b + mask = torch.abs(z_a) < torch.abs(z_b) + Z_gp = torch.where(mask, z_a, z_b) + EG2_gp = torch.where(mask, EG2_a, EG2_b) + + # Compute metabolite-level expected variance + EG2_m = compute_metabolite_cs(EG2_gp, gene_pair_dict, interacting_cell_scores=False) + if not isinstance(EG2_m, torch.Tensor): + EG2_m = torch.tensor(EG2_m, device=device, dtype=torch.float64) + + stdG_m = torch.sqrt(EG2_m) + stdG_m[stdG_m == 0] = 1 + + # Compute metabolite Z-scores + if isinstance(C_m, tuple): + C_m_0, C_m_1 = C_m + z_0 = C_m_0 / stdG_m + z_1 = C_m_1 / stdG_m + Z_m = torch.where(torch.abs(z_0) < torch.abs(z_1), z_0, z_1) + else: + Z_m = C_m / stdG_m + + return Z_gp, Z_m + + +def compute_p_results_old( + ct_pair, + cell_type_pairs, + cs_gp, + cs_m, + gene_pairs_ind, + gene_pairs_ind_per_ct_pair, + Wtot2, + eg2s_gp, + cell_type_key, + gene_pair_dict, +): + i = cell_type_pairs.index(ct_pair) + gene_pair_cor_ct = ( + (cs_gp[0][i, :], cs_gp[1][i, :]) if isinstance(cs_gp, tuple) else cs_gp[i, :] + ) + C_m = (cs_m[0][i, :], cs_m[1][i, :]) if isinstance(cs_m, tuple) else cs_m[i, :] + gene_pairs_ind_ct_pair = gene_pairs_ind_per_ct_pair[ + ct_pair + ] # If we consider all the gene pairs (irrespective of the cell type pair) use 'gene_pairs_ind' directly + + eg2s_a, eg2s_b = eg2s_gp + C_gp = [] + EG2_a = [] + EG2_b = [] + for gene_pair_ind_ct_pair in gene_pairs_ind_ct_pair: + idx = gene_pairs_ind.index(gene_pair_ind_ct_pair) + g1_ind, g2_ind = gene_pair_ind_ct_pair + lc_gp = ( + (gene_pair_cor_ct[0][idx], gene_pair_cor_ct[1][idx]) + if isinstance(gene_pair_cor_ct, tuple) + else gene_pair_cor_ct[idx] + ) + eg2_a = eg2_b = Wtot2[i] + # if g1_ind == g2_ind: + # eg2_a = eg2_b = Wtot2[i] + # else: + # eg2_a = eg2s_a[i, idx] + # eg2_b = eg2s_b[i, idx] + C_gp.append(lc_gp) + EG2_a.append(eg2_a) + EG2_b.append(eg2_b) + + # Gene pairs + + EG = [0 for i in range(len(gene_pairs_ind_ct_pair))] + + stdG_a = [(EG2_a[i] - EG[i] ** 2) ** 0.5 for i in range(len(gene_pairs_ind_ct_pair))] + stdG_a = [1 if stdG_a[i] == 0 else stdG_a[i] for i in range(len(stdG_a))] + + stdG_b = [(EG2_b[i] - EG[i] ** 2) ** 0.5 for i in range(len(gene_pairs_ind_ct_pair))] + stdG_b = [1 if stdG_b[i] == 0 else stdG_b[i] for i in range(len(stdG_b))] + + if isinstance(C_gp[0], tuple): + Z_gp_a = [(C_gp[i][0] - EG[i]) / stdG_a[i] for i in range(len(gene_pairs_ind_ct_pair))] + Z_gp_b = [(C_gp[i][1] - EG[i]) / stdG_b[i] for i in range(len(gene_pairs_ind_ct_pair))] + else: + Z_gp_a = [(C_gp[i] - EG[i]) / stdG_a[i] for i in range(len(gene_pairs_ind_ct_pair))] + Z_gp_b = [(C_gp[i] - EG[i]) / stdG_b[i] for i in range(len(gene_pairs_ind_ct_pair))] + + EG2_m_a = compute_metabolite_cs_old( + np.array(EG2_a), + cell_type_key=None, + gene_pair_dict=gene_pair_dict, + interacting_cell_scores=False, + ) + EG2_m_b = compute_metabolite_cs_old( + np.array(EG2_b), + cell_type_key=None, + gene_pair_dict=gene_pair_dict, + interacting_cell_scores=False, + ) + + # Metabolites + + EG = [0 for i in range(len(gene_pair_dict.keys()))] + + stdG_a = [(EG2_m_a[i] - EG[i] ** 2) ** 0.5 for i in range(len(gene_pair_dict.keys()))] + stdG_a = [1 if stdG_a[i] == 0 else stdG_a[i] for i in range(len(stdG_a))] + + stdG_b = [(EG2_m_b[i] - EG[i] ** 2) ** 0.5 for i in range(len(gene_pair_dict.keys()))] + stdG_b = [1 if stdG_b[i] == 0 else stdG_b[i] for i in range(len(stdG_b))] + + if isinstance(C_m, tuple): + Z_m_a = [(C_m[0][i] - EG[i]) / stdG_a[i] for i in range(len(gene_pair_dict.keys()))] + Z_m_b = [(C_m[1][i] - EG[i]) / stdG_b[i] for i in range(len(gene_pair_dict.keys()))] + else: + Z_m_a = [(C_m[i] - EG[i]) / stdG_a[i] for i in range(len(gene_pair_dict.keys()))] + Z_m_b = [(C_m[i] - EG[i]) / stdG_b[i] for i in range(len(gene_pair_dict.keys()))] + + return (C_gp, Z_gp_a, Z_gp_b, C_m, Z_m_a, Z_m_b) + + +def compute_p_int_cell_results( + ct_pair, + cell_type_pairs, + cs_gp, + cs_m, + gene_pairs_ind, + gene_pairs_ind_per_ct_pair, + Wtot2, + eg2s_gp, + cell_type_key, + gene_pair_dict, +): + i = cell_type_pairs.index(ct_pair) + gene_pair_cor_ct = cs_gp[i, :, :] + C_m = cs_m[i, :, :] + gene_pairs_ind_ct_pair = gene_pairs_ind_per_ct_pair[ + ct_pair + ] # If we consider all the gene pairs (irrespective of the cell type pair) use 'gene_pairs_ind' directly + + eg2s_a, eg2s_b = eg2s_gp + C_gp = [] + EG2_a = [] + EG2_b = [] + for gene_pair_ind_ct_pair in gene_pairs_ind_ct_pair: + idx = gene_pairs_ind.index(gene_pair_ind_ct_pair) + g1_ind, g2_ind = gene_pair_ind_ct_pair + lc_gp = gene_pair_cor_ct[:, idx] + if g1_ind == g2_ind: + eg2_a = eg2_b = Wtot2[i, :] + else: + eg2_a = ( + eg2s_a[i, :, g1_ind] + if type(g1_ind) is not list + else np.max(eg2s_a[i, :, g1_ind], axis=0) + ) + eg2_b = ( + eg2s_b[i, :, g2_ind] + if type(g2_ind) is not list + else np.max(eg2s_b[i, :, g2_ind], axis=0) + ) + C_gp.append(lc_gp) + EG2_a.append(eg2_a) + EG2_b.append(eg2_b) + C_gp = np.column_stack(C_gp) + EG2_a = np.column_stack(EG2_a) + EG2_b = np.column_stack(EG2_b) + + # Gene pairs + + EG = np.zeros(C_gp.shape) + + stdG_a = (EG2_a - EG**2) ** 0.5 + stdG_a[stdG_a == 0] = 1 + + stdG_b = (EG2_b - EG**2) ** 0.5 + stdG_b[stdG_b == 0] = 1 + + Z_gp = np.where( + np.abs((C_gp - EG) / stdG_a) < np.abs((C_gp - EG) / stdG_b), + (C_gp - EG) / stdG_a, + (C_gp - EG) / stdG_b, + ) + + EG2_gp = np.where(np.abs((C_gp - EG) / stdG_a) < np.abs((C_gp - EG) / stdG_b), EG2_a, EG2_b) + + EG2_m = compute_metabolite_cs( + EG2_gp, cell_type_key=None, gene_pair_dict=gene_pair_dict, interacting_cell_scores=True + ) + + # Metabolites + + EG = np.zeros(C_m.shape) + + stdG = (EG2_m - EG**2) ** 0.5 + stdG[stdG == 0] = 1 + + Z_m = (C_m - EG) / stdG + + return (C_gp, Z_gp, C_m, Z_m) + + +def compute_p_int_cell_results_no_ct(C_gp, C_m, gene_pairs_ind, Wtot2, eg2s_gp, gene_pair_dict): + + device = Wtot2.device + n_gp = len(gene_pairs_ind) + + # Convert indices + same_gene_mask = torch.tensor( + [ + (isinstance(g1, int) and isinstance(g2, int) and g1 == g2) + or (isinstance(g1, list) and isinstance(g2, list) and sorted(g1) == sorted(g2)) + for g1, g2 in gene_pairs_ind + ], + device=device, + ) + + # Unpack second moments + EG2_a = eg2s_gp[0].clone() + EG2_b = eg2s_gp[1].clone() + EG2_a[:, same_gene_mask] = Wtot2 + EG2_b[:, same_gene_mask] = Wtot2 + + stdG_a = torch.sqrt(EG2_a) + stdG_b = torch.sqrt(EG2_b) + stdG_a[stdG_a == 0] = 1 + stdG_b[stdG_b == 0] = 1 + + # Compute gene-pair Z-scores + if isinstance(C_gp, tuple): + C_gp_0, C_gp_1 = C_gp + z_0 = C_gp_0 / stdG_a + z_1 = C_gp_1 / stdG_b + mask = torch.abs(z_0) < torch.abs(z_1) + Z_gp = torch.where(mask, z_0, z_1) + EG2_gp = torch.where(mask, EG2_a, EG2_b) + else: + C_gp = C_gp + z_a = C_gp / stdG_a + z_b = C_gp / stdG_b + mask = torch.abs(z_a) < torch.abs(z_b) + Z_gp = torch.where(mask, z_a, z_b) + EG2_gp = torch.where(mask, EG2_a, EG2_b) + + # Compute metabolite-level expected variance + EG2_m = compute_metabolite_cs(EG2_gp, gene_pair_dict, interacting_cell_scores=True) + if not isinstance(EG2_m, torch.Tensor): + EG2_m = torch.tensor(EG2_m, device=device, dtype=torch.float64) + + stdG_m = torch.sqrt(EG2_m) + stdG_m[stdG_m == 0] = 1 + + # Compute metabolite Z-scores + if isinstance(C_m, tuple): + C_m_0, C_m_1 = C_m + z_0 = C_m_0 / stdG_m + z_1 = C_m_1 / stdG_m + Z_m = torch.where(torch.abs(z_0) < torch.abs(z_1), z_0, z_1) + else: + Z_m = C_m / stdG_m + + return (Z_gp, Z_m) + + +def compute_np_results( + ct_pair, + cell_type_pairs, + cs_gp, + cs_m, + pvals_gp, + pvals_m, + gene_pair_dict, + gene_pairs_ind, + gene_pairs_ind_per_ct_pair, +): + i = cell_type_pairs.index(ct_pair) + gene_pair_cor_gp_ct = cs_gp[i, :] + C_m = cs_m[i, :] + pvals_gp_a, pvals_gp_b = pvals_gp + pvals_gp_a_ct = pvals_gp_a[i, :] + pvals_gp_b_ct = pvals_gp_b[i, :] + pvals_m_a, pvals_m_b = pvals_m + p_values_m_a = pvals_m_a[i, :] + p_values_m_b = pvals_m_b[i, :] + gene_pairs_ind_ct_pair = gene_pairs_ind_per_ct_pair[ + ct_pair + ] # If we consider all the gene pairs (irrespective of the cell type pair) use 'gene_pairs_ind' directly + + C_gp = [] + p_values_gp_a = [] + p_values_gp_b = [] + for gene_pair_ind_ct_pair in gene_pairs_ind_ct_pair: + idx = gene_pairs_ind.index(gene_pair_ind_ct_pair) + lc_gp = gene_pair_cor_gp_ct[idx] + p_value_gp_a = pvals_gp_a_ct[idx] + p_value_gp_b = pvals_gp_b_ct[idx] + + C_gp.append(lc_gp.reshape(1)) + p_values_gp_a.append(p_value_gp_a.reshape(1)) + p_values_gp_b.append(p_value_gp_b.reshape(1)) + + C_gp = list(np.concatenate(C_gp)) + p_values_gp_a = list(np.concatenate(p_values_gp_a)) + p_values_gp_b = list(np.concatenate(p_values_gp_b)) + + # C_m = compute_metabolite_cs(np.array(C_gp), cell_type_key=None, gene_pair_dict=gene_pair_dict, interacting_cell_scores=False) + + return (C_gp, p_values_gp_a, p_values_gp_b, C_m, p_values_m_a, p_values_m_a) + + +def get_ct_cell_communication_results( + adata, + genes, + cells, + layer_key_p_test, + layer_key_np_test, + model, + cell_types, + cell_type_pairs, + D, + test, + device, +): + + gene_pairs_ind_per_ct_pair = adata.uns["gene_pairs_ind_per_ct_pair"] + gene_pair_dict = adata.uns["gene_pair_dict"] + genes = adata.uns["genes"] + + sample_specific = "sample_key" in adata.uns + + if isinstance(D, np.ndarray): + D = torch.tensor(D, dtype=torch.float64, device=device) + + def idx_to_gene(idx): + return [genes[i] for i in idx] if isinstance(idx, list) else genes[idx] + + records = [ + { + "Cell Type 1": ct1, + "Cell Type 2": ct2, + "Gene 1": idx_to_gene(gp[0]), + "Gene 2": idx_to_gene(gp[1]), + } + for (ct1, ct2), gp_list in gene_pairs_ind_per_ct_pair.items() + for gp in gp_list + ] + cell_com_df_gp = pd.DataFrame.from_records(records) + + # Generate metabolite interaction table + ct_pairs = list(gene_pairs_ind_per_ct_pair.keys()) + metabolites = list(gene_pair_dict.keys()) + cell_com_df_m = pd.DataFrame( + [ + {"Cell Type 1": ct1, "Cell Type 2": ct2, "metabolite": m} + for (ct1, ct2), m in itertools.product(ct_pairs, metabolites) + ] + ) + + if test in ["parametric", "both"]: + suffix = "p" + # Gene pair + c_values = adata.uns["ct_ccc_results"][suffix]["gp"]["cs"] + z_values = adata.uns["ct_ccc_results"][suffix]["gp"]["Z"] + p_values = adata.uns["ct_ccc_results"][suffix]["gp"]["Z_pval"] + fdr_values = adata.uns["ct_ccc_results"][suffix]["gp"]["Z_FDR"] + cell_com_df_gp[f"C_{suffix}"] = c_values.flatten() + cell_com_df_gp["Z"] = z_values.flatten() + cell_com_df_gp["Z_pval"] = p_values.flatten() + cell_com_df_gp["Z_FDR"] = fdr_values.flatten() + + counts = counts_from_anndata(adata[:, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + counts_std = standardize_ct_counts( + adata, counts, model, num_umi, sample_specific, cell_types + ) + + c_values_norm = normalize_ct_values( + counts_std, cell_types, cell_type_pairs, gene_pairs_ind_per_ct_pair, c_values, D + ) + adata.uns["ct_ccc_results"][suffix]["gp"]["cs_norm"] = c_values_norm.cpu().numpy() + cell_com_df_gp[f"C_norm_{suffix}"] = c_values_norm.cpu().numpy().flatten() + + # Metabolite + c_values = adata.uns["ct_ccc_results"][suffix]["m"]["cs"] + z_values = adata.uns["ct_ccc_results"][suffix]["m"]["Z"] + p_values = adata.uns["ct_ccc_results"][suffix]["m"]["Z_pval"] + fdr_values = adata.uns["ct_ccc_results"][suffix]["m"]["Z_FDR"] + cell_com_df_m[f"C_{suffix}"] = c_values.flatten() + cell_com_df_m["Z"] = z_values.flatten() + cell_com_df_m["Z_pval"] = p_values.flatten() + cell_com_df_m["Z_FDR"] = fdr_values.flatten() + + if test in ["non-parametric", "both"]: + suffix = "np" + # Gene pair + c_values = adata.uns["ct_ccc_results"][suffix]["gp"]["cs"] + p_values = adata.uns["ct_ccc_results"][suffix]["gp"]["pval"] + fdr_values = adata.uns["ct_ccc_results"][suffix]["gp"]["FDR"] + cell_com_df_gp[f"C_{suffix}"] = c_values.flatten() + cell_com_df_gp[f"pval_{suffix}"] = p_values.flatten() + cell_com_df_gp[f"FDR_{suffix}"] = fdr_values.flatten() + + counts = counts_from_anndata(adata[:, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + if adata.uns.get("center_counts_for_np_test", False): + num_umi = counts.sum(dim=0) + counts = standardize_ct_counts( + adata, counts, model, num_umi, sample_specific, cell_types + ) + + c_values_norm = normalize_ct_values( + counts, cell_types, cell_type_pairs, gene_pairs_ind_per_ct_pair, c_values, D + ) + adata.uns["ct_ccc_results"][suffix]["gp"]["cs_norm"] = c_values_norm.cpu().numpy() + cell_com_df_gp[f"C_norm_{suffix}"] = c_values_norm.cpu().numpy().flatten() + + # Metabolite + c_values = adata.uns["ct_ccc_results"][suffix]["m"]["cs"] + p_values = adata.uns["ct_ccc_results"][suffix]["m"]["pval"] + fdr_values = adata.uns["ct_ccc_results"][suffix]["m"]["FDR"] + cell_com_df_m[f"C_{suffix}"] = c_values.flatten() + cell_com_df_m[f"pval_{suffix}"] = p_values.flatten() + cell_com_df_m[f"FDR_{suffix}"] = fdr_values.flatten() + + adata.uns["ct_ccc_results"]["cell_com_df_gp"] = cell_com_df_gp + adata.uns["ct_ccc_results"]["cell_com_df_m"] = cell_com_df_m + + return + + +def get_cell_communication_results_old( + adata, + genes, + cells, + layer_key_p_test, + layer_key_np_test, + model, + cell_types, + cell_type_pairs, + D, + test, +): + + gene_pairs_ind_per_ct_pair = adata.uns["gene_pairs_ind_per_ct_pair"] + gene_pair_dict = adata.uns["gene_pair_dict"] + genes = adata.uns["genes"] + + def map_to_genes(x): + if isinstance(x, list): + return [genes[i] for i in x] + else: + return genes[x] + + cell_com_df_gp = ( + pd.DataFrame.from_dict(gene_pairs_ind_per_ct_pair, orient="index") + .stack() + .to_frame() + .reset_index() + ) + cell_com_df_gp = cell_com_df_gp.drop(["level_1"], axis=1) + cell_com_df_gp = cell_com_df_gp.rename(columns={"level_0": "cell_type_pair", 0: "gene_pair"}) + cell_com_df_gp["Cell Type 1"], cell_com_df_gp["Cell Type 2"] = zip( + *cell_com_df_gp["cell_type_pair"] + ) + cell_com_df_gp["Gene 1"], cell_com_df_gp["Gene 2"] = zip(*cell_com_df_gp["gene_pair"]) + cell_com_df_gp["Gene 1"] = cell_com_df_gp["Gene 1"].apply(map_to_genes) + cell_com_df_gp["Gene 2"] = cell_com_df_gp["Gene 2"].apply(map_to_genes) + cell_com_df_gp = cell_com_df_gp.drop(["cell_type_pair", "gene_pair"], axis=1) + + ct_pair_metab = list( + itertools.product(gene_pairs_ind_per_ct_pair.keys(), gene_pair_dict.keys()) + ) + cell_com_df_m = pd.DataFrame(ct_pair_metab, columns=["cell_type_pair", "metabolite"]) + cell_com_df_m["Cell Type 1"], cell_com_df_m["Cell Type 2"] = zip( + *cell_com_df_m["cell_type_pair"] + ) + cell_com_df_m = cell_com_df_m.drop(["cell_type_pair"], axis=1) + + if test in ["parametric", "both"]: + # Gene pair + c_values = adata.uns["ct_ccc_results"]["p"]["gp"]["cs"] + cell_com_df_gp["C_p"] = c_values.flatten() + z_values_a = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_a"] + cell_com_df_gp["Z_a"] = z_values_a.flatten() + z_values_b = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_b"] + cell_com_df_gp["Z_b"] = z_values_b.flatten() + p_values_a = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_pval_a"] + cell_com_df_gp["Z_pval_a"] = p_values_a.flatten() + p_values_b = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_pval_b"] + cell_com_df_gp["Z_pval_b"] = p_values_b.flatten() + FDR_values_a = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_FDR_a"] + cell_com_df_gp["Z_FDR_a"] = FDR_values_a.flatten() + FDR_values_b = adata.uns["ct_ccc_results"]["p"]["gp"]["Z_FDR_b"] + cell_com_df_gp["Z_FDR_b"] = FDR_values_b.flatten() + + counts = counts_from_anndata(adata[cells, genes], layer_key_p_test, dense=True) + num_umi = counts.sum(axis=0) + counts_std = counts_std = create_centered_counts_ct(counts, model, num_umi, cell_types) + counts_std = np.nan_to_num(counts_std) + + c_values_norm = normalize_values_old( + counts_std, cell_types, cell_type_pairs, gene_pairs_ind_per_ct_pair, c_values, D + ) + adata.uns["ct_ccc_results"]["p"]["gp"]["cs_norm"] = c_values_norm + cell_com_df_gp["C_norm_p"] = c_values_norm.flatten() + + # Metabolite + c_values = adata.uns["ct_ccc_results"]["p"]["m"]["cs"] + cell_com_df_m["C_p"] = c_values.flatten() + z_values_a = adata.uns["ct_ccc_results"]["p"]["m"]["Z_a"] + cell_com_df_m["Z_a"] = z_values_a.flatten() + z_values_b = adata.uns["ct_ccc_results"]["p"]["m"]["Z_b"] + cell_com_df_m["Z_b"] = z_values_b.flatten() + p_values_a = adata.uns["ct_ccc_results"]["p"]["m"]["Z_pval_a"] + cell_com_df_m["Z_pval_a"] = p_values_a.flatten() + p_values_b = adata.uns["ct_ccc_results"]["p"]["m"]["Z_pval_b"] + cell_com_df_m["Z_pval_b"] = p_values_b.flatten() + FDR_values_a = adata.uns["ct_ccc_results"]["p"]["m"]["Z_FDR_a"] + cell_com_df_m["Z_FDR_a"] = FDR_values_a.flatten() + FDR_values_b = adata.uns["ct_ccc_results"]["p"]["m"]["Z_FDR_b"] + cell_com_df_m["Z_FDR_b"] = FDR_values_b.flatten() + + if test in ["non-parametric", "both"]: + # Gene pair + c_values = adata.uns["ct_ccc_results"]["np"]["gp"]["cs"] + cell_com_df_gp["C_np"] = c_values.flatten() + p_values_a = adata.uns["ct_ccc_results"]["np"]["gp"]["pval_a"] + cell_com_df_gp["pval_np_a"] = p_values_a.flatten() + p_values_b = adata.uns["ct_ccc_results"]["np"]["gp"]["pval_b"] + cell_com_df_gp["pval_np_b"] = p_values_b.flatten() + FDR_values_a = adata.uns["ct_ccc_results"]["np"]["gp"]["FDR_a"] + cell_com_df_gp["FDR_np_a"] = FDR_values_a.flatten() + FDR_values_b = adata.uns["ct_ccc_results"]["np"]["gp"]["FDR_b"] + cell_com_df_gp["FDR_np_b"] = FDR_values_b.flatten() + + counts = counts_from_anndata(adata[:, genes], layer_key_np_test, dense=True) + if adata.uns["center_counts_for_np_test"]: + num_umi = counts.sum(axis=0) + counts = create_centered_counts(counts, model, num_umi) + counts = np.nan_to_num(counts) + c_values_norm = normalize_values_old( + counts, cell_types, cell_type_pairs, gene_pairs_ind_per_ct_pair, c_values, D + ) + adata.uns["ct_ccc_results"]["np"]["gp"]["cs_norm"] = c_values_norm + cell_com_df_gp["C_norm_np"] = c_values_norm.flatten() + + # Metabolite + c_values = adata.uns["ct_ccc_results"]["np"]["m"]["cs"] + cell_com_df_m["C_np"] = c_values.flatten() + p_values_a = adata.uns["ct_ccc_results"]["np"]["m"]["pval_a"] + cell_com_df_m["pval_np_a"] = p_values_a.flatten() + p_values_b = adata.uns["ct_ccc_results"]["np"]["m"]["pval_b"] + cell_com_df_m["pval_np_b"] = p_values_b.flatten() + FDR_values_a = adata.uns["ct_ccc_results"]["np"]["m"]["FDR_a"] + cell_com_df_m["FDR_np_a"] = FDR_values_a.flatten() + FDR_values_b = adata.uns["ct_ccc_results"]["np"]["m"]["FDR_b"] + cell_com_df_m["FDR_np_b"] = FDR_values_b.flatten() + + adata.uns["ct_ccc_results"]["cell_com_df_gp"] = cell_com_df_gp + adata.uns["ct_ccc_results"]["cell_com_df_m"] = cell_com_df_m + + return + + +def get_cell_communication_results( + adata, + genes, + layer_key_p_test, + layer_key_np_test, + model, + D, + test, + device, +): + + gene_pairs = adata.uns["gene_pairs"] + gene_pairs_ind = adata.uns["gene_pairs_ind"] + gene_pair_dict = adata.uns["gene_pair_dict"] + + sample_specific = "sample_key" in adata.uns + + if isinstance(D, np.ndarray): + D = torch.tensor(D, dtype=torch.float64, device=device) + + # Initialize dataframes + cell_com_df_gp = pd.DataFrame(gene_pairs, columns=["Gene 1", "Gene 2"]) + cell_com_df_m = pd.DataFrame({"Metabolite": list(gene_pair_dict.keys())}) + + if test in ["parametric", "both"]: + suffix = "p" + # Gene pair + c_values = adata.uns["ccc_results"][suffix]["gp"]["cs"] + z_values = adata.uns["ccc_results"][suffix]["gp"]["Z"] + p_values = adata.uns["ccc_results"][suffix]["gp"]["Z_pval"] + fdr_values = adata.uns["ccc_results"][suffix]["gp"]["Z_FDR"] + cell_com_df_gp[f"C_{suffix}"] = c_values + cell_com_df_gp["Z"] = z_values + cell_com_df_gp["Z_pval"] = p_values + cell_com_df_gp["Z_FDR"] = fdr_values + + counts = counts_from_anndata(adata[:, genes], layer_key_p_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + num_umi = counts.sum(dim=0) + _lazy_import_hotspot() + counts_std = standardize_counts(adata, counts, model, num_umi, sample_specific) + + c_values_norm = normalize_values(counts_std, gene_pairs_ind, c_values, D) + adata.uns["ccc_results"][suffix]["gp"]["cs_norm"] = c_values_norm.cpu().numpy() + cell_com_df_gp[f"C_norm_{suffix}"] = c_values_norm.cpu().numpy() + + # Metabolite + c_values = adata.uns["ccc_results"][suffix]["m"]["cs"] + z_values = adata.uns["ccc_results"][suffix]["m"]["Z"] + p_values = adata.uns["ccc_results"][suffix]["m"]["Z_pval"] + fdr_values = adata.uns["ccc_results"][suffix]["m"]["Z_FDR"] + cell_com_df_m[f"C_{suffix}"] = c_values + cell_com_df_m["Z"] = z_values + cell_com_df_m["Z_pval"] = p_values + cell_com_df_m["Z_FDR"] = fdr_values + + if test in ["non-parametric", "both"]: + suffix = "np" + # Gene pair + c_values = adata.uns["ccc_results"][suffix]["gp"]["cs"] + p_values = adata.uns["ccc_results"][suffix]["gp"]["pval"] + fdr_values = adata.uns["ccc_results"][suffix]["gp"]["FDR"] + cell_com_df_gp[f"C_{suffix}"] = c_values + cell_com_df_gp[f"pval_{suffix}"] = p_values + cell_com_df_gp[f"FDR_{suffix}"] = fdr_values + + counts = counts_from_anndata(adata[:, genes], layer_key_np_test, dense=True) + counts = torch.tensor(counts, dtype=torch.float64, device=device) + if adata.uns.get("center_counts_for_np_test", False): + num_umi = counts.sum(dim=0) + _lazy_import_hotspot() + counts = standardize_counts(adata, counts, model, num_umi, sample_specific) + + c_values_norm = normalize_values(counts, gene_pairs_ind, c_values, D) + adata.uns["ccc_results"][suffix]["gp"]["cs_norm"] = c_values_norm.cpu().numpy() + cell_com_df_gp[f"C_norm_{suffix}"] = c_values_norm.cpu().numpy() + + # Metabolite + c_values = adata.uns["ccc_results"][suffix]["m"]["cs"] + p_values = adata.uns["ccc_results"][suffix]["m"]["pval"] + fdr_values = adata.uns["ccc_results"][suffix]["m"]["FDR"] + cell_com_df_m[f"C_{suffix}"] = c_values + cell_com_df_m[f"pval_{suffix}"] = p_values + cell_com_df_m[f"FDR_{suffix}"] = fdr_values + + adata.uns["ccc_results"]["cell_com_df_gp"] = cell_com_df_gp + adata.uns["ccc_results"]["cell_com_df_m"] = cell_com_df_m + + return + + +def normalize_ct_values( + counts, + cell_types, + cell_type_pairs, + gene_pairs_per_ct_pair_ind, + lcs, + D, +): + + if isinstance(cell_types, pd.Series): + cell_types = cell_types.values + + if isinstance(lcs, np.ndarray): + lcs = torch.tensor(lcs, dtype=counts.dtype, device=counts.device) + + c_values_norm = torch.empty_like(lcs, dtype=counts.dtype, device=counts.device) + + for i, ct_pair in enumerate(cell_type_pairs): + ct_t, _ = ct_pair + + ct_mask = cell_types == ct_t + if isinstance(ct_mask, np.ndarray): + ct_mask = torch.tensor(ct_mask, device=counts.device) + + counts_ct = counts[:, ct_mask] + D_ct = D[i][ct_mask] + gene_pairs_ind = gene_pairs_per_ct_pair_ind[ct_pair] + + lc_maxs = compute_max_cs(D_ct, counts_ct, gene_pairs_ind) + lc_maxs = torch.where(lc_maxs == 0, torch.tensor(1.0, device=counts.device), lc_maxs) + + c_values = lcs[i] if lcs.ndim == 2 else lcs[i : i + 1] # allow 1D or 2D lcs + c_values_norm[i] = c_values / lc_maxs + c_values_norm[i] = torch.where( + torch.isinf(c_values_norm[i]), + torch.tensor(1.0, device=counts.device), + c_values_norm[i], + ) + + return c_values_norm + + +def normalize_values(counts, gene_pairs_ind, lcs, D): + """ + Normalize communication scores (lcs) using maximum possible score estimates. + """ + lc_maxs = compute_max_cs(D, counts, gene_pairs_ind) + lc_maxs = torch.where(lc_maxs == 0, torch.tensor(1.0, device=lc_maxs.device), lc_maxs) + if isinstance(lcs, np.ndarray): + lcs = torch.tensor(lcs, dtype=lc_maxs.dtype, device=lc_maxs.device) + c_values_norm = lcs / lc_maxs + c_values_norm = torch.where( + torch.isinf(c_values_norm), torch.tensor(1.0, device=c_values_norm.device), c_values_norm + ) + return c_values_norm + + +def normalize_values_old( + counts, + cell_types, + cell_type_pairs, + gene_pairs_per_ct_pair_ind, + lcs, + D, +): + + c_values_norm = np.zeros(lcs.shape) + for i in range(len(cell_type_pairs)): + ct_pair = cell_type_pairs[i] + ct_t, ct_u = ct_pair + cell_type_t_mask = [ct == ct_t for ct in cell_types] + counts_ct_t = counts[:, cell_type_t_mask] + D_ct_t = D[i][cell_type_t_mask] + gene_pairs_ind = gene_pairs_per_ct_pair_ind[ct_pair] + lc_maxs = compute_max_cs_old(D_ct_t, counts_ct_t, gene_pairs_ind) + lc_maxs[lc_maxs == 0] = 1 + c_values_norm[i] = lcs[i] / lc_maxs + c_values_norm[i][np.isinf(c_values_norm[i])] = 1 + + return c_values_norm + + +def compute_max_cs(node_degrees, counts, gene_pairs_ind): + """ + Compute max communication scores per gene pair. + """ + result = torch.empty(len(gene_pairs_ind), dtype=counts.dtype, device=counts.device) + + for i, (g1, _) in enumerate(gene_pairs_ind): + if isinstance(g1, list): + vals = counts[g1].mean(dim=0) + else: + vals = counts[g1] + result[i] = compute_max_cs_gp(vals, node_degrees) + + return result + + +def compute_max_cs_old(node_degrees, counts, gene_pairs_ind): + + result = np.zeros(len(gene_pairs_ind)) + for i, gene_pair_ind in enumerate(gene_pairs_ind): + vals = ( + counts[gene_pair_ind[0]] + if type(gene_pair_ind[0]) is not list + else np.mean(counts[gene_pair_ind[0]], axis=0) + ) + result[i] = compute_max_cs_gp_old(vals, node_degrees) + + return result + + +def compute_max_cs_gp(vals, node_degrees): + """ + Compute max communication score for a single gene (vector). + """ + return 0.5 * torch.sum(node_degrees * vals**2) + + +@jit(nopython=True) +def compute_max_cs_gp_old(vals, node_degrees): + tot = 0.0 + + for i in range(node_degrees.size): + tot += node_degrees[i] * (vals[i] ** 2) + + return tot / 2 + + +def center_ct_counts_torch(counts, num_umi, model, cell_types): + """ + counts: Tensor [genes, cells] + num_umi: Tensor [cells] + model: 'bernoulli', 'danb', 'normal', or 'none' + + Returns + ------- + Centered counts within cell types: Tensor [genes, cells] + """ + # Binarize if using Bernoulli + if model == "bernoulli": + counts = (counts > 0).double() + mu, var, _ = models.apply_model_per_cell_type( + models.bernoulli_model_torch, counts, num_umi, cell_types + ) + elif model == "danb": + mu, var, _ = models.apply_model_per_cell_type( + models.danb_model_torch, counts, num_umi, cell_types + ) + elif model == "normal": + mu, var, _ = models.apply_model_per_cell_type( + models.normal_model_torch, counts, num_umi, cell_types + ) + elif model == "none": + mu, var, _ = models.apply_model_per_cell_type( + models.none_model_torch, counts, num_umi, cell_types + ) + else: + raise ValueError(f"Unsupported model type: {model}") + + # Avoid division by zero + std = torch.sqrt(var) + std[std == 0] = 1.0 + + centered = (counts - mu) / std + centered[centered == 0] = 0 # Optional: to match old behavior + + return centered + + +def create_centered_counts(counts, model, num_umi): + """ + Creates a matrix of centered/standardized counts given + the selected statistical model + """ + out = np.zeros_like(counts, dtype="double") + + for i in range(out.shape[0]): + vals_x = counts[i] + + out_x = create_centered_counts_row(vals_x, model, num_umi) + + out[i] = out_x + + return out + + +def create_centered_counts_ct(counts, model, num_umi, cell_types): + """ + Creates a matrix of centered/standardized counts given + the selected statistical model + """ + out = np.zeros_like(counts, dtype="double") + + for i in range(out.shape[0]): + vals_x = counts[i] + + out_x = create_centered_counts_row_ct(vals_x, model, num_umi, cell_types) + + out[i] = out_x + + return out + + +def create_centered_counts_row_ct(vals_x, model, num_umi, cell_types): + if model == "bernoulli": + vals_x = (vals_x > 0).astype("double") + mu_x, var_x, x2_x = models.bernoulli_model(vals_x, num_umi) + + elif model == "danb": + mu_x, var_x, x2_x = models.ct_danb_model(vals_x, num_umi, cell_types) + + elif model == "normal": + mu_x, var_x, x2_x = models.normal_model(vals_x, num_umi) + + elif model == "none": + mu_x, var_x, x2_x = models.none_model(vals_x, num_umi) + + else: + raise Exception(f"Invalid Model: {model}") + + var_x[var_x == 0] = 1 + out_x = (vals_x - mu_x) / (var_x**0.5) + out_x[out_x == 0] = 0 + + return out_x + + +def create_centered_counts_row(vals_x, model, num_umi): + if model == "bernoulli": + vals_x = (vals_x > 0).astype("double") + mu_x, var_x, x2_x = models.bernoulli_model(vals_x, num_umi) + + elif model == "danb": + mu_x, var_x, x2_x = models.danb_model(vals_x, num_umi) + + elif model == "normal": + mu_x, var_x, x2_x = models.normal_model(vals_x, num_umi) + + elif model == "none": + mu_x, var_x, x2_x = models.none_model(vals_x, num_umi) + + else: + raise Exception(f"Invalid Model: {model}") + + var_x[var_x == 0] = 1 + out_x = (vals_x - mu_x) / (var_x**0.5) + out_x[out_x == 0] = 0 + + return out_x + + +def compute_Z_scores_cellcom_p( + ct_pair, cell_type_pairs, gene_pair_cor, gene_pairs_per_ct_pair_ind, Wtot2, eg2s +): + i = cell_type_pairs.index(ct_pair) + gene_pair_cor_ct = gene_pair_cor[i, :, :] + gene_pairs_ind = gene_pairs_per_ct_pair_ind[ + ct_pair + ] # If we consider all the gene pairs (irrespective of the cell type pair) use 'gene_pairs_ind' directly + + C = [] + EG2 = [] + for gene_pair_ind in gene_pairs_ind: + g1_ind, g2_ind = gene_pair_ind + lc = gene_pair_cor_ct[g1_ind, g2_ind] + if g1_ind == g2_ind: + eg2 = Wtot2[i] + else: + eg2 = eg2s[i][g1_ind] + C.append(lc) + EG2.append(eg2) + + EG = [0 for i in range(len(gene_pairs_ind))] + + stdG = [(EG2[i] - EG[i] ** 2) ** 0.5 for i in range(len(gene_pairs_ind))] + stdG = [1 if stdG[i] == 0 else stdG[i] for i in range(len(stdG))] + + Z = [(C[i] - EG[i]) / stdG[i] for i in range(len(gene_pairs_ind))] + + return (C, Z) + + +def extract_results_cellcom_np( + ct_pair, cell_type_pairs, gene_pair_cor, pvals, gene_pairs_per_ct_pair_ind +): + i = cell_type_pairs.index(ct_pair) + gene_pair_cor_ct = gene_pair_cor[i, :, :] + pvals_ct = pvals[i, :, :] + gene_pairs_ind = gene_pairs_per_ct_pair_ind[ + ct_pair + ] # If we consider all the gene pairs (irrespective of the cell type pair) use 'gene_pairs_ind' directly + + C = [] + p_values = [] + for gene_pair_ind in gene_pairs_ind: + g1_ind, g2_ind = gene_pair_ind + lc = gene_pair_cor_ct[g1_ind, g2_ind] + p_value = pvals_ct[g1_ind, g2_ind] + + C.append(lc.reshape(1)) + p_values.append(p_value.reshape(1)) + + C = list(np.concatenate(C)) + p_values = list(np.concatenate(p_values)) + + return (C, p_values) + + +# @njit +def expand_ct_pairs_cellcom(pairs, vals, N): + out = [np.zeros((N, N)) for k in range(len(vals))] + + for k in range(len(out)): + for i in range(len(pairs)): + x = pairs[i, 0] + y = pairs[i, 1] + v = vals[k][i] + + out[k][x, y] = v + + return out + + +def compute_local_cov_pairs_max(node_degrees, counts): + """ + For a Genes x Cells count matrix, compute the maximal pair-wise correlation + between any two genes + """ + N_GENES = counts.shape[0] + + gene_maxs = np.zeros(N_GENES) + for i in range(N_GENES): + gene_maxs[i] = compute_local_cov_max(counts[i].todense(), node_degrees) + + result = gene_maxs.reshape((-1, 1)) + gene_maxs.reshape((1, -1)) + result = result / 2 + return result + + +@jit(nopython=True) +def compute_local_cov_max(vals, node_degrees): + tot = 0.0 + + for i in range(node_degrees.size): + tot += node_degrees[i] * (vals[i] ** 2) + + return tot / 2 + + +def get_ct_pair_counts_and_weights( + counts, weights, cell_type_pairs, cell_types, gene_pairs_per_ct_pair_ind +): + + c_nrow, c_ncol = counts.shape + w_nrow, w_ncol = weights.shape + n_ct_pairs = len(cell_type_pairs) + + extract_counts_weights_results = partial( + extract_ct_pair_counts_weights, + counts=counts, + weights=weights, + cell_type_pairs=cell_type_pairs, + cell_types=cell_types, + gene_pairs_per_ct_pair_ind=gene_pairs_per_ct_pair_ind, + ) + results = list(map(extract_counts_weights_results, cell_type_pairs)) + + c_new_data_t_all = [x[0] for x in results] + c_new_coords_3d_t_all = [x[1] for x in results] + c_new_coords_3d_t_all = np.hstack(c_new_coords_3d_t_all) + c_new_data_t_all = np.concatenate(c_new_data_t_all) + + c_new_data_u_all = [x[2] for x in results] + c_new_coords_3d_u_all = [x[3] for x in results] + c_new_coords_3d_u_all = np.hstack(c_new_coords_3d_u_all) + c_new_data_u_all = np.concatenate(c_new_data_u_all) + + w_new_data_all = [x[4] for x in results] + w_new_coords_3d_all = [x[5] for x in results] + w_new_coords_3d_all = np.hstack(w_new_coords_3d_all) + w_new_data_all = np.concatenate(w_new_data_all) + + counts_ct_pairs_t = sparse.COO( + c_new_coords_3d_t_all, c_new_data_t_all, shape=(n_ct_pairs, c_nrow, c_ncol) + ) + counts_ct_pairs_u = sparse.COO( + c_new_coords_3d_u_all, c_new_data_u_all, shape=(n_ct_pairs, c_nrow, c_ncol) + ) + weights_ct_pairs = sparse.COO( + w_new_coords_3d_all, w_new_data_all, shape=(n_ct_pairs, w_nrow, w_ncol) + ) + + return counts_ct_pairs_t, counts_ct_pairs_u, weights_ct_pairs + + +def get_ct_pair_counts_and_weights_null( + counts_ct_pairs_t, counts_ct_pairs_u, weights, cell_type_pairs, cell_types, M +): + + n_cells = counts_ct_pairs_t.shape[2] + cell_permutations = np.vstack([np.random.permutation(n_cells) for _ in range(M)]) + + n_ct_pairs, c_nrow, c_ncol = counts_ct_pairs_t.shape + w_nrow, w_ncol = weights.shape + + extract_counts_weights_results_null = partial( + extract_ct_pair_counts_weights_null, + permutations=cell_permutations, + counts_ct_pairs_t=counts_ct_pairs_t, + counts_ct_pairs_u=counts_ct_pairs_u, + weights=weights, + cell_type_pairs=cell_type_pairs, + cell_types=cell_types, + ) + results_null = list(map(extract_counts_weights_results_null, cell_permutations)) + + c_null_data_t_all = [x[0] for x in results_null] + c_null_coords_4d_t_all = [x[1] for x in results_null] + c_null_coords_4d_t_all = np.hstack(c_null_coords_4d_t_all) + c_null_data_t_all = np.concatenate(c_null_data_t_all) + + c_null_data_u_all = [x[2] for x in results_null] + c_null_coords_4d_u_all = [x[3] for x in results_null] + c_null_coords_4d_u_all = np.hstack(c_null_coords_4d_u_all) + c_null_data_u_all = np.concatenate(c_null_data_u_all) + + w_null_data_all = [x[4] for x in results_null] + w_null_coords_4d_all = [x[5] for x in results_null] + w_null_coords_4d_all = np.hstack(w_null_coords_4d_all) + w_null_data_all = np.concatenate(w_null_data_all) + + counts_ct_pairs_t_null = sparse.COO( + c_null_coords_4d_t_all, c_null_data_t_all, shape=(M, n_ct_pairs, c_nrow, c_ncol) + ) + counts_ct_pairs_u_null = sparse.COO( + c_null_coords_4d_u_all, c_null_data_u_all, shape=(M, n_ct_pairs, c_nrow, c_ncol) + ) + weights_ct_pairs_null = sparse.COO( + w_null_coords_4d_all, w_null_data_all, shape=(M, n_ct_pairs, w_nrow, w_ncol) + ) + + return counts_ct_pairs_t_null, counts_ct_pairs_u_null, weights_ct_pairs_null + + +def extract_ct_pair_counts_weights( + ct_pair, counts, weights, cell_type_pairs, cell_types, gene_pairs_per_ct_pair_ind +): + + i = cell_type_pairs.index(ct_pair) + + ct_t, ct_u = cell_type_pairs[i] + gene_pairs_per_ct_pair_ind_i = gene_pairs_per_ct_pair_ind[(ct_t, ct_u)] + ct_t_mask = cell_types.values == ct_t + ct_t_mask_coords = np.argwhere(ct_t_mask) + ct_u_mask = cell_types.values == ct_u + ct_u_mask_coords = np.argwhere(ct_u_mask) + + ct_t_genes = np.unique([t[0] for t in gene_pairs_per_ct_pair_ind_i]) + ct_u_genes = np.unique([t[1] for t in gene_pairs_per_ct_pair_ind_i]) + + c_old_coords = counts.coords + w_old_coords = weights.coords + + # Counts + + c_row_coords_t, c_col_coords_t = np.meshgrid(ct_t_genes, ct_t_mask_coords, indexing="ij") + c_row_coords_t = c_row_coords_t.ravel() + c_col_coords_t = c_col_coords_t.ravel() + c_new_coords_t = np.vstack((c_row_coords_t, c_col_coords_t)) + + c_row_coords_u, c_col_coords_u = np.meshgrid(ct_u_genes, ct_u_mask_coords, indexing="ij") + c_row_coords_u = c_row_coords_u.ravel() + c_col_coords_u = c_col_coords_u.ravel() + c_new_coords_u = np.vstack((c_row_coords_u, c_col_coords_u)) + + c_matching_indices_t = np.where(np.all(np.isin(c_old_coords.T, c_new_coords_t.T), axis=1))[0] + c_new_data_t = counts.data[c_matching_indices_t] + c_new_coords_t = c_old_coords[:, c_matching_indices_t] + + c_matching_indices_u = np.where(np.all(np.isin(c_old_coords.T, c_new_coords_u.T), axis=1))[0] + c_new_data_u = counts.data[c_matching_indices_u] + c_new_coords_u = c_old_coords[:, c_matching_indices_u] + + c_coord_3d_t = np.full(c_new_coords_t.shape[1], fill_value=i) + c_new_coords_3d_t = np.vstack((c_coord_3d_t, c_new_coords_t)) + + c_coord_3d_u = np.full(c_new_coords_u.shape[1], fill_value=i) + c_new_coords_3d_u = np.vstack((c_coord_3d_u, c_new_coords_u)) + + # Weights + + w_row_coords, w_col_coords = np.meshgrid(ct_t_mask_coords, ct_u_mask_coords, indexing="ij") + w_row_coords = w_row_coords.ravel() + w_col_coords = w_col_coords.ravel() + w_new_coords = np.vstack((w_row_coords, w_col_coords)) + + w_matching_indices = np.where(np.all(np.isin(w_old_coords.T, w_new_coords.T), axis=1))[0] + w_new_data = weights.data[w_matching_indices] + w_new_coords = w_old_coords[:, w_matching_indices] + + w_coord_3d = np.full(w_new_coords.shape[1], fill_value=i) + w_new_coords_3d = np.vstack((w_coord_3d, w_new_coords)) + + return ( + c_new_data_t, + c_new_coords_3d_t, + c_new_data_u, + c_new_coords_3d_u, + w_new_data, + w_new_coords_3d, + ) + + +def extract_ct_pair_counts_weights_null( + permutation, + permutations, + counts_ct_pairs_t, + counts_ct_pairs_u, + weights, + cell_type_pairs, + cell_types, +): + + i = np.where(np.all(permutations == permutation, axis=1))[0] + cell_types_perm = pd.Series(cell_types[permutation]) + + counts_ct_pairs_t_perm = counts_ct_pairs_t[:, :, permutation] + c_perm_coords_t = counts_ct_pairs_t_perm.coords + c_perm_data_t = counts_ct_pairs_t_perm.data + counts_ct_pairs_u_perm = counts_ct_pairs_u[:, :, permutation] + c_perm_coords_u = counts_ct_pairs_u_perm.coords + c_perm_data_u = counts_ct_pairs_u_perm.data + + extract_weights_results = partial( + extract_ct_pair_weights, + weights=weights, + cell_type_pairs=cell_type_pairs, + cell_types=cell_types_perm, + ) + weights_results = list(map(extract_weights_results, cell_type_pairs)) + + w_perm_data_all = [x[0] for x in weights_results] + w_perm_coords_3d_all = [x[1] for x in weights_results] + w_perm_coords_3d_all = np.hstack(w_perm_coords_3d_all) + w_perm_data_all = np.concatenate(w_perm_data_all) + + c_coord_4d_t = np.full(c_perm_coords_t.shape[1], fill_value=i) + c_perm_coords_4d_t = np.vstack((c_coord_4d_t, c_perm_coords_t)) + + c_coord_4d_u = np.full(c_perm_coords_u.shape[1], fill_value=i) + c_perm_coords_4d_u = np.vstack((c_coord_4d_u, c_perm_coords_u)) + + w_coord_4d_u = np.full(w_perm_coords_3d_all.shape[1], fill_value=i) + w_perm_coords_4d_u = np.vstack((w_coord_4d_u, w_perm_coords_3d_all)) + + return ( + c_perm_data_t, + c_perm_coords_4d_t, + c_perm_data_u, + c_perm_coords_4d_u, + w_perm_data_all, + w_perm_coords_4d_u, + ) + + +def compute_interaction_module_correlation( + adata: AnnData, + cor_method: Literal["pearson"] | Literal["spearman"] | None = "pearson", + test: Literal["parametric"] | Literal["non-parametric"] | None = None, + interaction_type: Literal["metabolite"] | Literal["gene_pair"] | None = "metabolite", + only_sig_values: bool | None = False, + normalize_values: bool | None = False, + use_FDR: bool | None = True, + use_super_modules: bool | None = False, +): + """ + Compute correlations between interacting cell scores and module scores. + + Parameters + ---------- + adata : AnnData + Must contain: + - ``uns['interacting_cell_results']`` (parametric or non-parametric interacting cell scores) + - ``obsm['module_scores']`` or ``obsm['super_module_scores']`` + - ``uns['metabolites']`` or ``uns['gene_pairs_sig_names']`` + cor_method : {"pearson", "spearman"}, default "pearson" + Statistical method used to compute correlations. + test : {"parametric", "non-parametric"} + Which interacting-cell score set to use. + - `"parametric"` → uses ``uns['interacting_cell_results']['p']`` + - `"non-parametric"` → uses ``uns['interacting_cell_results']['np']`` + interaction_type : {"metabolite", "gene_pair"}, default "metabolite" + Select whether to correlate: + - metabolite scores, or + - gene pair scores. + only_sig_values : bool, default False + If True, use only significant interacting cell score values (`cs_sig_pval` or `cs_sig_FDR`). + normalize_values : bool, default False + Apply min–max normalization to interacting cell score values per interaction. + use_FDR : bool, default True + If ``only_sig_values=True``, determines whether to filter by FDR or raw p-values. + use_super_modules : bool, default False + Whether to use super-module scores (``obsm['super_module_scores']``) instead of module scores. + """ + MODULE_KEY = "super_module_scores" if use_super_modules else "module_scores" + + if cor_method not in ["pearson", "spearman"]: + raise ValueError(f'Invalid method: {cor_method}. Choose either "pearson" or "spearman".') + + adata.uns["cor_method"] = cor_method + + if test not in ["parametric", "non-parametric"]: + raise ValueError('The "test" variable should be one of ["parametric", "non-parametric"].') + + test_str = "p" if test == "parametric" else "np" + + if interaction_type not in ["metabolite", "gene_pair"]: + raise ValueError( + 'The "interaction_type" variable should be one of ["metabolite", "gene_pair"].' + ) + + interaction_type_str = "m" if interaction_type == "metabolite" else "gp" + + if only_sig_values: + sig_str = "FDR" if use_FDR else "pval" + interaction_scores = adata.uns["interacting_cell_results"][test_str][interaction_type_str][ + f"cs_sig_{sig_str}" + ] + else: + interaction_scores = adata.uns["interacting_cell_results"][test_str][interaction_type_str][ + "cs" + ] + + if normalize_values: + interaction_scores = interaction_scores.apply( + lambda x: (x - x.min()) / (x.max() - x.min()), axis=0 + ) # We apply min-max normalization + + interaction_type_names_key = ( + "metabolites" if interaction_type == "metabolite" else "gene_pairs_sig_names" + ) + interaction_scores = pd.DataFrame( + interaction_scores, index=adata.obs_names, columns=adata.uns[interaction_type_names_key] + ) + + metabolites = interaction_scores.columns.tolist() + modules = adata.obsm[MODULE_KEY].columns.tolist() + + cor_pval_df = pd.DataFrame(index=modules) + cor_coef_df = pd.DataFrame(index=modules) + + for metab in metabolites: + correlation_values = [] + pvals = [] + + for module in modules: + metab_df = interaction_scores[metab] + module_df = adata.obsm[MODULE_KEY][module] + + if cor_method == "pearson": + correlation_value, pval = pearsonr(metab_df, module_df) + elif cor_method == "spearman": + correlation_value, pval = spearmanr(metab_df, module_df) + + correlation_values.append(correlation_value) + pvals.append(pval) + + cor_coef_df[metab] = correlation_values + cor_pval_df[metab] = pvals + + cor_pval_df = cor_pval_df.replace(np.nan, 1) + cor_coef_df = cor_coef_df.replace(np.nan, 0) + cor_FDR_values = multipletests(cor_pval_df.values.flatten(), method="fdr_bh")[1] + cor_FDR_df = pd.DataFrame( + cor_FDR_values.reshape(cor_pval_df.shape), + index=cor_pval_df.index, + columns=cor_pval_df.columns, + ) + + adata.uns["interaction_module_correlation_coefs"] = cor_coef_df + adata.uns["interaction_module_correlation_pvals"] = cor_pval_df + adata.uns["interaction_module_correlation_FDR"] = cor_FDR_df + + return diff --git a/src/scvi/external/harreman/tools/knn.py b/src/scvi/external/harreman/tools/knn.py new file mode 100755 index 0000000000..a598d9528e --- /dev/null +++ b/src/scvi/external/harreman/tools/knn.py @@ -0,0 +1,419 @@ +import itertools +import time +from math import ceil + +import numpy as np +import pandas as pd +import sparse +from anndata import AnnData +from scipy.sparse import csr_matrix, lil_matrix +from sklearn.neighbors import NearestNeighbors, radius_neighbors_graph +from sklearn.preprocessing import normalize +from tqdm import tqdm + + +def compute_knn_graph( + adata: AnnData, + compute_neighbors_on_key: str | None = None, + distances_obsp_key: str | None = None, + weighted_graph: bool | None = False, + neighborhood_radius: int | None = None, + n_neighbors: int | None = None, + neighborhood_factor: int | None = 3, + sample_key: str | None = None, + tree=None, + verbose: bool | None = False, +): + """Computes the spatial proximity graph. + + Parameters + ---------- + adata + AnnData object. + compute_neighbors_on_key + Key in `adata.obsm` to use for computing neighbors. If `None`, use neighbors stored in `adata`. If no neighbors have been previously computed an error will be raised. + distances_obsp_key + Distances encoding cell-cell similarities directly. Shape is (cells x cells). Input is key in `adata.obsp`. + weighted_graph + Whether or not to create a weighted graph. + neighborhood_radius + Neighborhood radius. + n_neighbors + Neighborhood size. + neighborhood_factor + Used when creating a weighted graph. Sets how quickly weights decay relative to the distances within the neighborhood. The weight for a cell with a distance d will decay as exp(-d^2/D) where D is the distance to the `n_neighbors`/`neighborhood_factor`-th neighbor. + sample_key + Sample information in case the data contains different samples or samples from different conditions. Input is key in `adata.obs`. + tree + Root tree node. Can be created using ete3.Tree + verbose + Whether to print progress and status messages. + """ + start = time.time() + + if n_neighbors is None and neighborhood_radius is None: + raise ValueError("Either 'n_neighbors' or 'neighborhood_radius' needs to be provided.") + + if tree is not None: + try: + all_leaves = [] + for x in tree: + if x.is_leaf(): + all_leaves.append(x.name) + except: + raise ValueError("Can't parse supplied tree") + + if len(all_leaves) != adata.shape[0] or len(set(all_leaves) & set(adata.obs_names)) != len( + all_leaves + ): + raise ValueError("Tree leaf labels don't match observations in supplied AnnData") + + if weighted_graph: + raise ValueError( + "When using `tree` as the metric space, `weighted_graph=True` is not supported" + ) + tree_neighbors_and_weights(adata, tree, n_neighbors=n_neighbors) + + if compute_neighbors_on_key is not None: + if verbose: + print("Computing the neighborhood graph...") + compute_neighbors( + adata=adata, + compute_neighbors_on_key=compute_neighbors_on_key, + n_neighbors=n_neighbors, + neighborhood_radius=neighborhood_radius, + sample_key=sample_key, + verbose=verbose, + ) + else: + if distances_obsp_key is not None and distances_obsp_key in adata.obsp: + if verbose: + print("Computing the neighborhood graph from distances...") + compute_neighbors_from_distances( + adata, + distances_obsp_key, + n_neighbors, + sample_key, + verbose, + ) + + if "distances" in adata.obsp: + if verbose: + print("Computing the weights...") + compute_weights( + adata, + weighted_graph, + neighborhood_factor, + ) + + if verbose: + print("Finished computing the KNN graph in %.3f seconds" % (time.time() - start)) + + return + + +def compute_neighbors( + adata: AnnData, + compute_neighbors_on_key: str = None, + n_neighbors: int | None = None, + neighborhood_radius: int | None = None, + sample_key: str | None = None, + verbose: bool | None = False, +) -> None: + """ + Computes a nearest-neighbors graph on the AnnData object using either + radius-based or k-nearest neighbors. + + Parameters + ---------- + adata : AnnData + Annotated data object (AnnData). + compute_neighbors_on_key : str, optional + Key in `adata.obsm` to compute neighbors on (e.g., spatial coordinates or PCA). + If None, defaults to 'spatial'. + n_neighbors : int, optional + Number of nearest neighbors to compute (for kNN graph). + neighborhood_radius : int, optional + Radius to use for radius-based graph computation (in Euclidean space). + Only used if `n_neighbors` is not provided. + sample_key : str, optional + Key in `adata.obs` indicating batch/sample identity. + Ensures neighbors are only computed within each sample. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + """ + if compute_neighbors_on_key not in adata.obsm: + raise ValueError(f"{compute_neighbors_on_key} not found in adata.obsm") + + coords = adata.obsm[compute_neighbors_on_key] + n_cells = adata.n_obs + distances = lil_matrix((n_cells, n_cells)) + + if sample_key is not None: + if verbose: + print(f"Restricting graph within samples using '{sample_key}'...") + adata.uns["sample_key"] = sample_key + samples = adata.obs[sample_key].unique() + for sample in tqdm(samples): + sample_mask = adata.obs[sample_key] == sample + sample_indices = np.where(sample_mask)[0] + sample_coords = coords[sample_mask] + if len(sample_indices) == 0: + continue + if n_neighbors is not None: + nn = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm="ball_tree").fit( + sample_coords + ) + dist = nn.kneighbors_graph(sample_coords, mode="distance") + elif neighborhood_radius is not None: + dist = radius_neighbors_graph( + sample_coords, radius=neighborhood_radius, mode="distance", include_self=False + ) + else: + raise ValueError("Either n_neighbors or neighborhood_radius must be specified.") + dist = dist.tocoo() + distances[sample_indices[dist.row], sample_indices[dist.col]] = dist.data + else: + if n_neighbors is not None: + nn = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm="ball_tree").fit(coords) + distances = nn.kneighbors_graph(coords, mode="distance") + elif neighborhood_radius is not None: + distances = radius_neighbors_graph( + coords, radius=neighborhood_radius, mode="distance", include_self=False + ) + else: + raise ValueError("Either n_neighbors or neighborhood_radius must be specified.") + + # Deconvolution-aware neighborhood + if adata.uns.get("deconv_data", False): + if verbose: + print("Adding intra-spot connections...") + spot_diameter = adata.uns["spot_diameter"] + idx = adata.obs.groupby("barcodes").indices + rows, cols = [], [] + for barcode, inds in idx.items(): + if len(inds) < 2: + continue + # Efficient pairwise combinations without permutations + inds = np.array(inds) + i, j = np.meshgrid(inds, inds, indexing="ij") + mask = i != j + rows.extend(i[mask]) + cols.extend(j[mask]) + extra_distances = csr_matrix( + (np.full(len(rows), spot_diameter / 2), (rows, cols)), shape=(n_cells, n_cells) + ) + distances += extra_distances + + adata.obsp["distances"] = distances.tocsr() + + return + + +def compute_neighbors_from_distances( + adata: AnnData, + distances_obsp_key: str = "distances", + sample_key: str | None = None, + verbose: bool | None = False, +) -> None: + """ + Builds a neighborhood graph using a precomputed distance matrix. + + Parameters + ---------- + adata : AnnData + Annotated data object containing a full distance matrix in `obsp`. + distances_obsp_key : str + Key in `adata.obsp` with the full distance matrix. + sample_key : str, optional + Key in `adata.obs` to enforce neighbors only within samples. + verbose : bool, optional (default: False) + Whether to print progress and status messages. + """ + if distances_obsp_key not in adata.obsp: + raise ValueError(f"{distances_obsp_key} not found in adata.obsp") + + distances_raw = adata.obsp[distances_obsp_key] + distances_raw = ( + csr_matrix(distances_raw) if isinstance(distances_raw, np.ndarray) else distances_raw + ) + + n_cells = adata.shape[0] + distances = lil_matrix((n_cells, n_cells)) + + # Restrict by sample + if sample_key is not None: + if verbose: + print(f"Restricting graph within samples using '{sample_key}'...") + adata.uns["sample_key"] = sample_key + samples = adata.obs[sample_key].unique().tolist() + for sample in tqdm(adata.obs[sample_key].unique(), desc="Samples"): + idx = np.where(adata.obs[sample_key] == sample)[0] + sub_dist = distances_raw[idx, :][:, idx] + distances[np.ix_(idx, idx)] = sub_dist + else: + distances = distances_raw.copy().tolil() + + if adata.uns.get("deconv_data", False) and "barcodes" in adata.obs: + if verbose: + print("Adding intra-spot connections...") + for barcode, inds in adata.obs.groupby("barcodes").indices.items(): + if len(inds) <= 1: + continue + for i, j in itertools.permutations(inds, 2): + distances[i, j] = spot_diameter / 2 + + adata.obsp["distances"] = distances.tocsr() + + return + + +def compute_weights( + adata: AnnData, + weighted_graph: bool, + neighborhood_factor: int, +) -> None: + """ + Computes weights on the neighbors based on a + gaussian kernel and their distances. + + Parameters + ---------- + adata : AnnData + Annotated data object containing a full distance matrix in `obsp`. + weighted_graph : bool + Whether or not to create a weighted graph. + neighborhood_factor : int + Used when creating a weighted graph. Sets how quickly weights decay relative to the distances within the neighborhood. + The weight for a cell with a distance d will decay as exp(-d^2/D) where D is the distance to the `n_neighbors`/`neighborhood_factor`-th neighbor. + """ + # Load distance matrix and remove diagonal entries + distances = sparse.COO.from_scipy_sparse(adata.obsp["distances"]) + i, j = distances.coords + non_diag_mask = i != j + i, j, data = i[non_diag_mask], j[non_diag_mask], distances.data[non_diag_mask] + distances = sparse.COO(coords=[i, j], data=data, shape=distances.shape) + + if not weighted_graph: + # Unweighted: convert all non-zero distances to 1 + weights = sparse.COO(coords=[i, j], data=np.ones_like(data), shape=distances.shape) + adata.obsp["weights"] = weights.tocsr() + return + + # Weighted: Gaussian kernel + n_cells = distances.shape[0] + row_starts = np.searchsorted(i, np.arange(n_cells), side="left") + row_ends = np.searchsorted(i, np.arange(n_cells) + 1, side="left") + + sigmas = np.ones(n_cells, dtype=float) + for idx in range(n_cells): + start, end = row_starts[idx], row_ends[idx] + row_data = data[start:end] + if row_data.size == 0: + continue + radius_idx = ceil(len(row_data) / neighborhood_factor) - 1 + sigmas[idx] = np.partition(row_data, radius_idx)[radius_idx] + + # Build exp(-d^2 / sigma^2) weights + sigma_lookup = sigmas[i] # map sigma per row + gaussian_weights = np.exp(-1.0 * data**2 / sigma_lookup**2) + + weights = sparse.COO(coords=[i, j], data=gaussian_weights, shape=distances.shape) + weights_csr = weights.tocsr() + weights_norm = normalize(weights_csr, norm="l1", axis=1) + + adata.obsp["weights"] = weights_norm + + return + + +def make_weights_non_redundant(weights): + + w_no_redundant = weights.copy() + + rows, cols = w_no_redundant.nonzero() + upper_diag_mask = rows < cols + upper_rows, upper_cols = rows[upper_diag_mask], cols[upper_diag_mask] + + w_no_redundant[upper_rows, upper_cols] += w_no_redundant[upper_cols, upper_rows] + w_no_redundant[upper_cols, upper_rows] = 0 + w_no_redundant.eliminate_zeros() + + return w_no_redundant + + +def tree_neighbors_and_weights(adata, tree, n_neighbors): + """ + Computes nearest neighbors and associated weights for data + Uses distance along the tree object + + Names of the leaves of the tree must match the columns in counts + + Parameters + ---------- + adata + AnnData object. + tree: ete3.TreeNode + The root of the tree + n_neighbors: int + Number of neighbors to find + + """ + K = n_neighbors + cell_labels = adata.obs_names + + all_leaves = [] + for x in tree: + if x.is_leaf(): + all_leaves.append(x) + + all_neighbors = {} + + for leaf in tqdm(all_leaves): + neighbors = _knn(leaf, K) + all_neighbors[leaf.name] = neighbors + + cell_ix = {c: i for i, c in enumerate(cell_labels)} + + knn_ix = lil_matrix((len(all_neighbors), len(all_neighbors)), dtype=np.int8) + for cell in all_neighbors: + row = cell_ix[cell] + nn_ix = [cell_ix[x] for x in all_neighbors[cell]] + knn_ix[row, nn_ix] = 1 + + weights = knn_ix.tocsr() + + adata.obsp["weights"] = weights + + return + + +def _knn(leaf, K): + + dists = _search(leaf, None, 0) + dists = pd.Series(dists) + dists = dists + np.random.rand(len(dists)) * 0.9 # to break ties randomly + + neighbors = dists.sort_values().index[0:K].tolist() + + return neighbors + + +def _search(current_node, previous_node, distance): + + if current_node.is_root(): + nodes_to_search = current_node.children + else: + nodes_to_search = current_node.children + [current_node.up] + nodes_to_search = [x for x in nodes_to_search if x != previous_node] + + if len(nodes_to_search) == 0: + return {current_node.name: distance} + + result = {} + for new_node in nodes_to_search: + res = _search(new_node, current_node, distance + 1) + for k, v in res.items(): + result[k] = v + + return result diff --git a/tests/external/harreman/__init__.py b/tests/external/harreman/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/external/harreman/test_harreman.py b/tests/external/harreman/test_harreman.py new file mode 100644 index 0000000000..ce292efeb7 --- /dev/null +++ b/tests/external/harreman/test_harreman.py @@ -0,0 +1,73 @@ +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +import scvi.external.harreman.hotspot as hs +import scvi.external.harreman.tools as tl + + +@pytest.fixture +def adata_spatial(): + n_obs = 50 + n_vars = 20 + np.random.seed(42) + X = np.random.poisson(1.0, size=(n_obs, n_vars)).astype(float) + obs = pd.DataFrame( + {"sample": pd.Categorical(np.random.choice(["s1", "s2"], size=n_obs))}, + index=[f"cell{i}" for i in range(n_obs)], + ) + var = pd.DataFrame(index=[f"gene{i}" for i in range(n_vars)]) + obsm = {"spatial": np.random.rand(n_obs, 2) * 100} + return AnnData(X=X, obs=obs, var=var, obsm=obsm) + + +def test_compute_knn_graph_n_neighbors(adata_spatial): + tl.compute_knn_graph(adata_spatial, compute_neighbors_on_key="spatial", n_neighbors=5) + assert "distances" in adata_spatial.obsp + assert "weights" in adata_spatial.obsp + assert adata_spatial.obsp["distances"].shape == (adata_spatial.n_obs, adata_spatial.n_obs) + + +def test_compute_knn_graph_weighted(adata_spatial): + tl.compute_knn_graph( + adata_spatial, compute_neighbors_on_key="spatial", n_neighbors=5, weighted_graph=True + ) + assert "weights" in adata_spatial.obsp + assert adata_spatial.obsp["weights"].nnz > 0 + + +def test_compute_knn_graph_with_sample_key(adata_spatial): + tl.compute_knn_graph( + adata_spatial, compute_neighbors_on_key="spatial", n_neighbors=5, sample_key="sample" + ) + assert "distances" in adata_spatial.obsp + + +def test_compute_knn_graph_missing_key(adata_spatial): + with pytest.raises(ValueError, match="not found in adata.obsm"): + tl.compute_knn_graph( + adata_spatial, compute_neighbors_on_key="nonexistent_key", n_neighbors=5 + ) + + +def test_compute_knn_graph_no_neighbors_raises(adata_spatial): + with pytest.raises(ValueError, match="Either 'n_neighbors' or 'neighborhood_radius'"): + tl.compute_knn_graph(adata_spatial, compute_neighbors_on_key="spatial") + + +def test_compute_local_autocorrelation(adata_spatial): + tl.compute_knn_graph(adata_spatial, compute_neighbors_on_key="spatial", n_neighbors=5) + hs.compute_local_autocorrelation(adata_spatial, model="danb") + assert "gene_autocorrelation_results" in adata_spatial.uns + results = adata_spatial.uns["gene_autocorrelation_results"] + assert len(results) == adata_spatial.n_vars + + +def test_compute_local_correlation(adata_spatial): + tl.compute_knn_graph(adata_spatial, compute_neighbors_on_key="spatial", n_neighbors=5) + hs.compute_local_autocorrelation(adata_spatial, model="danb") + # Pass genes explicitly to avoid empty selection from FDR filtering + genes = adata_spatial.var_names[:5].tolist() + hs.compute_local_correlation(adata_spatial, genes=genes) + assert "lc_zs" in adata_spatial.uns