
first commit with new neighborhood_connectivity function into _shap.py #66

Status: Open — wants to merge 13 commits into base: main
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -12,4 +12,4 @@ and this project adheres to [Semantic Versioning][].

### Added

- Basic tool, preprocessing and plotting functions
22 changes: 11 additions & 11 deletions README.md
@@ -39,24 +39,24 @@ CellCharter is able to automatically identify spatial domains and offers a suite

## Features

- **Identify niches for multiple samples**: By combining the power of scVI and scArches, CellCharter can identify domains for multiple samples simultaneously, even in the presence of batch effects.
- **Scalability**: CellCharter can handle large datasets with millions of cells and thousands of features. The possibility to run it on GPUs makes it even faster.
- **Flexibility**: CellCharter can be used with different types of spatial omics data, such as spatial transcriptomics, proteomics, epigenomics and multiomics data. The only difference is the method used for dimensionality reduction and batch effect removal.
    - Spatial transcriptomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with a zero-inflated negative binomial distribution.
    - Spatial proteomics: CellCharter has been tested on a version of [scArches](https://docs.scarches.org/en/latest/api/models.html#scarches.models.TRVAE), modified to use Mean Squared Error loss instead of the default Negative Binomial loss.
    - Spatial epigenomics: CellCharter has been tested on [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI) with a Poisson distribution.
    - Spatial multiomics: it's possible to use multi-omics models such as [MultiVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.MULTIVI.html#scvi.model.MULTIVI), or use the concatenation of the results from the different models.
- **Best candidates for the number of domains**: CellCharter offers a [method to find multiple best candidates](https://cellcharter.readthedocs.io/en/latest/generated/cellcharter.tl.ClusterAutoK.html) for the number of domains, based on the stability of a certain number of domains across multiple runs.
- **Domain characterization**: CellCharter provides a set of tools to characterize and compare the spatial domains, such as domain proportion, cell type enrichment, (differential) neighborhood enrichment, and domain shape characterization.
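The stability idea behind `ClusterAutoK` can be sketched in plain Python: cluster the data several times per candidate number of domains and score each candidate by how consistently the runs partition the same points. The helper names and the simple pairwise agreement score below are illustrative only, not CellCharter's actual API or metric:

```python
from itertools import combinations


def pair_agreement(labels_a, labels_b):
    """Fraction of point pairs on which two labelings agree
    (same cluster in both, or different clusters in both)."""
    idx = range(len(labels_a))
    agree = total = 0
    for i, j in combinations(idx, 2):
        same_a = labels_a[i] == labels_a[j]
        same_b = labels_b[i] == labels_b[j]
        agree += same_a == same_b
        total += 1
    return agree / total


def stability(runs_per_k):
    """runs_per_k: {k: [labeling, ...]} -> {k: mean pairwise agreement across runs}."""
    return {
        k: sum(pair_agreement(a, b) for a, b in combinations(runs, 2))
        / max(1, len(runs) * (len(runs) - 1) // 2)
        for k, runs in runs_per_k.items()
    }


runs = {
    2: [[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]],  # consistent partitions (labels may permute)
    3: [[0, 1, 2, 2], [0, 0, 1, 2], [2, 1, 0, 0]],  # less consistent partitions
}
scores = stability(runs)
best_k = max(scores, key=scores.get)  # most stable candidate
```

Note how the label-permutation in the third run for k=2 does not hurt the score: pairwise agreement only looks at whether two points are co-clustered, not at the label values themselves.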

Since CellCharter 0.3.0, we have moved the implementation of the Gaussian Mixture Model (GMM) from [PyCave](https://github.com/borchero/pycave), which is no longer maintained, to [TorchGMM](https://github.com/CSOgroup/torchgmm), a fork of PyCave maintained by the CSOgroup. This change gives us a more stable and maintained GMM implementation that is compatible with the most recent versions of PyTorch.
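The GMM at the core of this step is the standard EM-fitted mixture. A minimal one-dimensional sketch in pure Python (deliberately unrelated to TorchGMM's actual API) makes the two alternating steps explicit:

```python
import math


def em_gmm_1d(xs, n_components=2, n_iter=50):
    """Fit a 1-D Gaussian mixture with vanilla EM; returns (weights, means, variances)."""
    xs_sorted = sorted(xs)
    # Deterministic init: spread the initial means across the data range.
    means = [xs_sorted[i * (len(xs) - 1) // (n_components - 1)] for i in range(n_components)]
    variances = [1.0] * n_components
    weights = [1.0 / n_components] * n_components
    for _ in range(n_iter):
        # E-step: responsibilities r[i][k] proportional to weight_k * N(x_i | mean_k, var_k)
        resp = []
        for x in xs:
            dens = [
                w * math.exp(-((x - m) ** 2) / (2 * v)) / math.sqrt(2 * math.pi * v)
                for w, m, v in zip(weights, means, variances)
            ]
            total = sum(dens)
            resp.append([d / total for d in dens])
        # M-step: re-estimate weights, means and variances from the responsibilities
        for k in range(n_components):
            nk = sum(r[k] for r in resp)
            weights[k] = nk / len(xs)
            means[k] = sum(r[k] * x for r, x in zip(resp, xs)) / nk
            variances[k] = max(
                sum(r[k] * (x - means[k]) ** 2 for r, x in zip(resp, xs)) / nk, 1e-6
            )
    return weights, means, variances


data = [0.1, -0.2, 0.05, 10.1, 9.8, 10.2]  # two well-separated groups
w, m, v = em_gmm_1d(data)
```

With well-separated groups the responsibilities saturate quickly and the fitted means land near the per-group averages; a production implementation such as TorchGMM additionally batches this over tensors on GPU.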

## Getting started

Please refer to the [documentation][link-docs]. In particular, the

- [API documentation][link-api]
- [Tutorials][link-tutorial]

## Installation

18 changes: 9 additions & 9 deletions docs/contributing.md
@@ -99,11 +99,11 @@ Specify `vX.X.X` as a tag name and create a release. For more information, see [

Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:

- the [myst][] extension allows writing documentation in markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension)
- Jupyter notebooks as tutorials through [myst-nb][] (see [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)

See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
@@ -120,10 +120,10 @@ repository.

#### Hints

- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
  if you do so can sphinx automatically create a link to the external documentation.
- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
  the `nitpick_ignore` list in `docs/conf.py`.

#### Building the docs locally

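The two hints above can be sketched as a `docs/conf.py` fragment; the package names and the `nitpick_ignore` entry below are illustrative, not taken from CellCharter's actual configuration:

```python
# docs/conf.py -- map external package names to their documentation roots,
# so references like :class:`anndata.AnnData` resolve to the external docs.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
}

# Silence warnings for link targets outside our control that sphinx cannot resolve.
nitpick_ignore = [
    ("py:class", "some_private_module.SomeClass"),  # illustrative entry
]
```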
241 changes: 131 additions & 110 deletions src/cellcharter/pl/_shape.py
@@ -1,19 +1,18 @@
from __future__ import annotations

import warnings
from itertools import combinations
from pathlib import Path

import anndata as ad
import geopandas
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sps
import seaborn as sns
import spatialdata as sd
import spatialdata_plot # noqa: F401
from anndata import AnnData
from scipy.stats import ttest_ind
from squidpy._docs import d

from ._utils import adjust_box_widths
@@ -74,6 +73,7 @@ def boundaries(
component_key: str = "component",
alpha_boundary: float = 0.5,
show_cells: bool = True,
cells_radius: float | None = None,
save: str | Path | None = None,
) -> None:
"""
@@ -97,6 +97,9 @@
-------
%(plotting_returns)s
"""
if show_cells and cells_radius is None:
    raise ValueError("cells_radius must be provided when show_cells is True")

adata = adata[adata.obs[library_key] == sample].copy()
del adata.raw
clusters = adata.obs[component_key].unique()
@@ -113,7 +116,7 @@
adata.obs["region"] = "cells"

xy = adata.obsm["spatial"]
cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=3000, index=adata.obs["instance_id"])
cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=cells_radius, index=adata.obs["instance_id"])

obs = pd.DataFrame(list(boundaries.keys()), columns=[component_key], index=np.arange(len(boundaries)).astype(str))
adata_obs = ad.AnnData(X=pd.DataFrame(index=obs.index, columns=adata.var_names), obs=obs)
@@ -122,6 +125,10 @@
adata_obs.obs["instance_id"] = np.arange(len(boundaries))
adata_obs.obs[component_key] = pd.Categorical(adata_obs.obs[component_key])

if sps.issparse(adata.X):
    # adata.X is sparse, so give adata_obs an empty sparse matrix as well,
    # so that the outer concat below keeps a sparse X
    adata_obs.X = sps.csr_matrix((len(adata_obs.obs), len(adata.var_names)))

adata = ad.concat((adata, adata_obs), join="outer")

adata.obs["region"] = adata.obs["region"].astype("category")
@@ -218,15 +225,56 @@ def plot_shape_metrics(
)


def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -> None:
fig = plt.figure(figsize=figsize)
ax = sns.boxplot(
data=data,
x=x,
hue=hue,
y=y,
showfliers=False,
hue_order=hue_order,
)
adjust_box_widths(fig, 0.9)

ax = sns.stripplot(
data=data,
x=x,
hue=hue,
y=y,
palette="dark:0.08",
size=4,
jitter=0.13,
dodge=True,
hue_order=hue_order,
)

if len(data[hue].unique()) > 1:
handles, labels = ax.get_legend_handles_labels()
if len(handles) > 1:
plt.legend(
handles[0 : len(data[hue].unique())],
labels[0 : len(data[hue].unique())],
bbox_to_anchor=(1.0, 1.03),
title=hue,
)
else:
if ax.get_legend() is not None:
ax.get_legend().remove()
plt.ylim(-0.05, 1.05)
plt.title(title)
plt.show()


@d.dedent
def shape_metrics(
adata: AnnData,
condition_key: str,
condition_key: str | None = None,
condition_groups: list[str] | None = None,
cluster_key: str | None = None,
cluster_id: list[str] | None = None,
cluster_id: str | list[str] | None = None,
component_key: str = "component",
metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
metrics: str | tuple[str] | list[str] | None = None,
fontsize: str | int = "small",
figsize: tuple[float, float] = (8, 7),
title: str | None = None,
@@ -248,7 +296,7 @@ def shape_metrics(
component_key
Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
metrics
List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``.
List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``, ``ncc``. If `None`, all computed metrics are plotted.
figsize
Figure size.
title
@@ -263,115 +311,88 @@
elif isinstance(metrics, tuple):
metrics = list(metrics)

metrics_df = {metric: adata.uns[f"shape_{component_key}"][metric] for metric in metrics}
metrics_df[condition_key] = (
adata[~adata.obs[condition_key].isna()]
.obs[[component_key, condition_key]]
.drop_duplicates()
.set_index(component_key)
.to_dict()[condition_key]
)
if cluster_id is not None and not isinstance(cluster_id, (list, np.ndarray)):
    cluster_id = [cluster_id]

metrics_df[cluster_key] = (
adata[~adata.obs[condition_key].isna()]
.obs[[component_key, cluster_key]]
.drop_duplicates()
.set_index(component_key)
.to_dict()[cluster_key]
)
if condition_groups is None:
    if condition_key is not None:
        condition_groups = list(adata.obs[condition_key].cat.categories)
elif not isinstance(condition_groups, (list, np.ndarray)):
    condition_groups = [condition_groups]

metrics_df = pd.DataFrame(metrics_df)
if cluster_id is not None:
metrics_df = metrics_df[metrics_df[cluster_key].isin(cluster_id)]
if metrics is None:
metrics = [metric for metric in adata.uns[f"shape_{component_key}"].keys() if metric != "boundary"]

metrics_df = pd.melt(
metrics_df[metrics + [condition_key]],
id_vars=[condition_key],
var_name="metric",
)
keys = []
if condition_key is not None:
keys.append(condition_key)
if cluster_key is not None:
keys.append(cluster_key)

conditions = (
enumerate(combinations(adata.obs[condition_key].cat.categories, 2))
if condition_groups is None
else [condition_groups]
)
metrics_df = adata.obs[[component_key] + keys].drop_duplicates().dropna().set_index(component_key)

for condition1, condition2 in conditions:
fig = plt.figure(figsize=figsize)
metrics_condition_pair = metrics_df[metrics_df[condition_key].isin([condition1, condition2])]
ax = sns.boxplot(
data=metrics_condition_pair,
x="metric",
hue=condition_key,
y="value",
showfliers=False,
hue_order=[condition1, condition2],
)
for metric in metrics:
metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric])

ax.tick_params(labelsize=fontsize)
ax.set_xlabel(ax.get_xlabel(), fontsize=fontsize)
ax.tick_params(labelsize=fontsize)
ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)

adjust_box_widths(fig, 0.9)

ax = sns.stripplot(
data=metrics_condition_pair,
x="metric",
hue=condition_key,
y="value",
color="0.08",
size=4,
jitter=0.13,
dodge=True,
hue_order=condition_groups if condition_groups else None,
)
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[0 : len(metrics_condition_pair[condition_key].unique())],
labels[0 : len(metrics_condition_pair[condition_key].unique())],
bbox_to_anchor=(1.24, 1.02),
fontsize=fontsize,
if cluster_id is not None:
metrics_df = metrics_df[metrics_df[cluster_key].isin(cluster_id)]

metrics_melted = pd.melt(
metrics_df,
id_vars=keys,
value_vars=metrics,
var_name="metric",
)

for count, metric in enumerate(["linearity", "curl"]):
pvalue = ttest_ind(
metrics_condition_pair[
(metrics_condition_pair[condition_key] == condition1) & (metrics_condition_pair["metric"] == metric)
]["value"],
metrics_condition_pair[
(metrics_condition_pair[condition_key] == condition2) & (metrics_condition_pair["metric"] == metric)
]["value"],
)[1]
x1, x2 = count, count
y, h, col = (
metrics_condition_pair[(metrics_condition_pair["metric"] == metric)]["value"].max()
+ 0.02
+ 0.05 * count,
0.01,
"k",
if cluster_key is not None:
    metrics_melted[cluster_key] = metrics_melted[cluster_key].cat.remove_unused_categories()

if cluster_key is not None:
    domains = ", ".join(str(cluster) for cluster in cluster_id) if cluster_id is not None else "all"
    plot_shapes(
        metrics_melted,
        "metric",
        "value",
        cluster_key,
        cluster_id,
        figsize,
        f"Spatial domains: {domains} by domain",
    )

if condition_key is not None:
    domains = ", ".join(str(cluster) for cluster in cluster_id) if cluster_id is not None else "all"
    plot_shapes(
        metrics_melted,
        "metric",
        "value",
        condition_key,
        condition_groups,
        figsize,
        f"Spatial domains: {domains} by condition",
    )
else:
for metric in metrics:
plot_shapes(
metrics_df,
cluster_key if cluster_key is not None else condition_key,
metric,
condition_key if condition_key is not None else cluster_key,
condition_groups if condition_groups is not None else None,
figsize,
f"Spatial domains: {metric}",
)
plt.plot([x1 - 0.2, x1 - 0.2, x2 + 0.2, x2 + 0.2], [y, y + h, y + h, y], lw=1.5, c=col)
if pvalue < 0.05:
plt.text(
(x1 + x2) * 0.5,
y + h * 2,
f"p = {pvalue:.2e}",
ha="center",
va="bottom",
color=col,
fontdict={"fontsize": fontsize},
)
else:
plt.text(
(x1 + x2) * 0.5,
y + h * 2,
"ns",
ha="center",
va="bottom",
color=col,
fontdict={"fontsize": fontsize},
)
if title is not None:
plt.title(title, fontdict={"fontsize": fontsize})
plt.show()
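The reshaping step in the new `shape_metrics` produces one long-form row per (grouping, metric, value) before plotting. `melt_metrics` below is an illustrative pure-Python stand-in for the `pd.melt(metrics_df, id_vars=keys, value_vars=metrics)` call, not part of CellCharter:

```python
def melt_metrics(rows, id_vars, value_vars):
    """Turn wide rows (dicts) into long form: one record per id/metric pair,
    mirroring what pd.melt(..., id_vars=..., value_vars=...) produces."""
    long_rows = []
    for row in rows:
        for metric in value_vars:
            record = {key: row[key] for key in id_vars}
            record["metric"] = metric
            record["value"] = row[metric]
            long_rows.append(record)
    return long_rows


wide = [
    {"condition": "A", "linearity": 0.9, "curl": 0.1},
    {"condition": "B", "linearity": 0.4, "curl": 0.7},
]
long_form = melt_metrics(wide, id_vars=["condition"], value_vars=["linearity", "curl"])
```

The long form is what lets a single `plot_shapes` call draw all metrics side by side on one axis, with `metric` on x and the grouping column as hue.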