Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 165 additions & 67 deletions src/rapids_singlecell/tools/_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

import cudf
import cupy as cp
import numpy as np
import pandas as pd
from natsort import natsorted
Expand All @@ -15,7 +16,6 @@
if TYPE_CHECKING:
from collections.abc import Sequence

import cupy as cp
from anndata import AnnData
from scipy import sparse

Expand Down Expand Up @@ -63,9 +63,64 @@ def _create_graph(adjacency, dtype=np.float64, *, use_weights=True):
return g


def _create_graph_dask(adjacency, dtype=np.float64, *, use_weights=True):
    """Build a multi-GPU (dask) cugraph ``Graph`` from a CSR adjacency matrix.

    The CSR matrix is expanded into a COO edge list, split into one
    partition per visible GPU, and loaded through
    ``Graph.from_dask_cudf_edgelist``.

    Parameters
    ----------
    adjacency
        Sparse CSR adjacency matrix (``indptr``/``indices``/``data``).
    dtype
        Data type used for the edge weights.
    use_weights
        If ``True``, ``adjacency.data`` is attached as the edge weight
        column; otherwise the graph is built unweighted.

    Returns
    -------
    A ``cugraph.Graph`` backed by a persisted dask-cudf edge list.
    """
    import cudf
    import dask.dataframe as dd
    from cugraph import Graph

    # Expand the CSR row pointer into an explicit per-edge row index (COO).
    rows = np.repeat(np.arange(adjacency.shape[0]), np.diff(adjacency.indptr)).astype(
        np.int32
    )
    cols = adjacency.indices
    weights = adjacency.data

    n_devices = cp.cuda.runtime.getDeviceCount()
    # One partition per GPU. Ceil-divide with pure integer arithmetic so very
    # large nnz values do not lose precision through float division, and clamp
    # to >= 1 so an empty adjacency does not produce a zero range step.
    chunksize = max(1, -(-adjacency.nnz // n_devices))

    boundaries = list(range(0, adjacency.nnz, chunksize))
    pairs = [(start, min(start + chunksize, adjacency.nnz)) for start in boundaries]

    def mapper(pair):
        # Materialize one edge-list partition as a cuDF frame.
        start, end = pair
        return cudf.DataFrame(
            {
                "src": rows[start:end].astype(np.int64),
                "dst": cols[start:end].astype(np.int64),
                "weight": weights[start:end].astype(dtype),
            }
        )

    # meta must match the columns actually produced by ``mapper``.
    meta = {
        "src": np.int64,
        "dst": np.int64,
        "weight": dtype,
    }

    ddf = dd.from_map(mapper, pairs, meta=meta).to_backend("cudf").persist()
    import cugraph.dask.comms.comms as Comms

    # NOTE(review): the matching ``Comms.destroy()`` is expected to be issued
    # by the caller once clustering has finished.
    Comms.initialize(p2p=True)
    g = Graph()
    if use_weights:
        g.from_dask_cudf_edgelist(
            ddf,
            source="src",
            destination="dst",
            weight="weight",
        )
    else:
        g.from_dask_cudf_edgelist(
            ddf,
            source="src",
            destination="dst",
        )
    return g


def leiden(
adata: AnnData,
resolution: float = 1.0,
resolution: float | list[float] = 1.0,
*,
random_state: int | None = 0,
theta: float = 1.0,
Expand All @@ -77,6 +132,7 @@ def leiden(
neighbors_key: str | None = None,
obsp: str | None = None,
dtype: str | np.dtype | cp.dtype = np.float32,
use_dask: bool = False,
copy: bool = False,
) -> AnnData | None:
"""
Expand All @@ -93,9 +149,9 @@ def leiden(
annData object

resolution
A parameter value controlling the coarseness of the clustering.
A parameter value or a list of parameter values controlling the coarseness of the clustering.
(called gamma in the modularity formula). Higher values lead to
more clusters.
more clusters. If a list of values is provided, the Leiden algorithm will be run for each value in the list.

random_state
Change the initialization of the optimization. Defaults to 0.
Expand Down Expand Up @@ -140,11 +196,13 @@ def leiden(
dtype
Data type to use for the adjacency matrix.

use_dask
If `True`, use Dask to create the graph and cluster. This will use all GPUs available. This feature is experimental. For datasets with less than 10 Million cells, it is recommended to use `use_dask=False`.

copy
Whether to copy `adata` or modify it in place.
"""
# Adjacency graph
from cugraph import leiden as culeiden

adata = adata.copy() if copy else adata

Expand All @@ -160,40 +218,61 @@ def leiden(
restrict_categories=restrict_categories,
adjacency=adjacency,
)
if use_dask:
from cugraph.dask import leiden as culeiden

g = _create_graph_dask(adjacency, dtype, use_weights=use_weights)
else:
from cugraph import leiden as culeiden

g = _create_graph(adjacency, dtype, use_weights=use_weights)
g = _create_graph(adjacency, dtype, use_weights=use_weights)
# Cluster
leiden_parts, _ = culeiden(
g,
resolution=resolution,
random_state=random_state,
theta=theta,
max_iter=n_iterations,
)
if isinstance(resolution, float | int):
resolutions = [resolution]
else:
resolutions = resolution
for resolution in resolutions:
leiden_parts, _ = culeiden(
g,
resolution=resolution,
random_state=random_state,
theta=theta,
max_iter=n_iterations,
)
if use_dask:
leiden_parts = leiden_parts.to_backend("pandas").compute()
else:
leiden_parts = leiden_parts.to_pandas()

# Format output
groups = leiden_parts.sort_values("vertex")[["partition"]].to_numpy().ravel()
key_added_to_use = key_added
if restrict_to is not None:
if key_added == "leiden":
key_added_to_use += "_R"
groups = rename_groups(
adata,
key_added=key_added_to_use,
restrict_key=restrict_key,
restrict_categories=restrict_categories,
restrict_indices=restrict_indices,
groups=groups,
)
if len(resolutions) > 1:
key_added_to_use += f"_{resolution}"

# Format output
groups = (
leiden_parts.to_pandas().sort_values("vertex")[["partition"]].to_numpy().ravel()
)
if restrict_to is not None:
if key_added == "leiden":
key_added += "_R"
groups = rename_groups(
adata,
key_added=key_added,
restrict_key=restrict_key,
restrict_categories=restrict_categories,
restrict_indices=restrict_indices,
groups=groups,
adata.obs[key_added_to_use] = pd.Categorical(
values=groups.astype("U"),
categories=natsorted(map(str, np.unique(groups))),
)
adata.obs[key_added] = pd.Categorical(
values=groups.astype("U"),
categories=natsorted(map(str, np.unique(groups))),
)
if use_dask:
import cugraph.dask.comms.comms as Comms

Comms.destroy()
# store information on the clustering parameters
adata.uns[key_added] = {}
adata.uns[key_added]["params"] = {
"resolution": resolution,
"resolution": resolutions,
"random_state": random_state,
"n_iterations": n_iterations,
}
Expand All @@ -202,7 +281,7 @@ def leiden(

def louvain(
adata: AnnData,
resolution: float = 1.0,
resolution: float | list[float] = 1.0,
*,
restrict_to: tuple[str, Sequence[str]] | None = None,
key_added: str = "louvain",
Expand All @@ -213,6 +292,7 @@ def louvain(
neighbors_key: int | None = None,
obsp: str | None = None,
dtype: str | np.dtype | cp.dtype = np.float32,
use_dask: bool = False,
copy: bool = False,
) -> AnnData | None:
"""
Expand All @@ -229,9 +309,9 @@ def louvain(
annData object

resolution
A parameter value controlling the coarseness of the clustering
A parameter value or a list of parameter values controlling the coarseness of the clustering.
(called gamma in the modularity formula). Higher values lead to
more clusters.
more clusters. If a list of values is provided, the Louvain algorithm will be run for each value in the list.

restrict_to
Restrict the clustering to the categories within the key for
Expand Down Expand Up @@ -275,13 +355,14 @@ def louvain(
dtype
Data type to use for the adjacency matrix.

use_dask
If `True`, use Dask to create the graph and cluster. This will use all GPUs available. This feature is experimental. For datasets with less than 10 Million cells, it is recommended to use `use_dask=False`.

copy
Whether to copy `adata` or modify it in place.

"""
# Adjacency graph
from cugraph import louvain as culouvain

dtype = _check_dtype(dtype)

adata = adata.copy() if copy else adata
Expand All @@ -295,43 +376,60 @@ def louvain(
restrict_categories=restrict_categories,
adjacency=adjacency,
)
# Cluster
if use_dask:
from cugraph.dask import louvain as culouvain

g = _create_graph(adjacency, dtype, use_weights=use_weights)
g = _create_graph_dask(adjacency, dtype, use_weights=use_weights)
else:
from cugraph import louvain as culouvain

# Cluster
louvain_parts, _ = culouvain(
g,
resolution=resolution,
max_level=n_iterations,
threshold=threshold,
)
g = _create_graph(adjacency, dtype, use_weights=use_weights)

# Format output
groups = (
louvain_parts.to_pandas()
.sort_values("vertex")[["partition"]]
.to_numpy()
.ravel()
)
if restrict_to is not None:
if key_added == "louvain":
key_added += "_R"
groups = rename_groups(
adata,
key_added=key_added,
restrict_key=restrict_key,
restrict_categories=restrict_categories,
restrict_indices=restrict_indices,
groups=groups,
if isinstance(resolution, float | int):
resolutions = [resolution]
else:
resolutions = resolution
for resolution in resolutions:
louvain_parts, _ = culouvain(
g,
resolution=resolution,
max_level=n_iterations,
threshold=threshold,
)
if use_dask:
louvain_parts = louvain_parts.to_backend("pandas").compute()
else:
louvain_parts = louvain_parts.to_pandas()

# Format output
groups = louvain_parts.sort_values("vertex")[["partition"]].to_numpy().ravel()
key_added_to_use = key_added
if restrict_to is not None:
if key_added == "louvain":
key_added_to_use += "_R"
groups = rename_groups(
adata,
key_added=key_added_to_use,
restrict_key=restrict_key,
restrict_categories=restrict_categories,
restrict_indices=restrict_indices,
groups=groups,
)
if len(resolutions) > 1:
key_added_to_use += f"_{resolution}"

adata.obs[key_added] = pd.Categorical(
values=groups.astype("U"),
categories=natsorted(map(str, np.unique(groups))),
)
adata.obs[key_added_to_use] = pd.Categorical(
values=groups.astype("U"),
categories=natsorted(map(str, np.unique(groups))),
)
if use_dask:
import cugraph.dask.comms.comms as Comms

Comms.destroy()
adata.uns[key_added] = {}
adata.uns[key_added]["params"] = {
"resolution": resolution,
"resolution": resolutions,
"n_iterations": n_iterations,
"threshold": threshold,
}
Expand Down
36 changes: 36 additions & 0 deletions tests/dask/test_dask_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import pytest
from scanpy.datasets import pbmc3k_processed
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

import rapids_singlecell as rsc


@pytest.mark.parametrize("clustering_function", [rsc.tl.leiden, rsc.tl.louvain])
def test_dask_clustering(client, clustering_function):
    """Dask-backed clustering should closely agree with the single-GPU run."""
    adata = pbmc3k_processed()
    clustering_function(adata, use_dask=True, key_added="test_dask")
    clustering_function(adata, key_added="test_no_dask")

    dask_labels = adata.obs["test_dask"]
    ref_labels = adata.obs["test_no_dask"]
    ari = adjusted_rand_score(dask_labels, ref_labels)
    nmi = normalized_mutual_info_score(dask_labels, ref_labels)
    assert ari > 0.9
    assert nmi > 0.9


@pytest.mark.parametrize("clustering_function", [rsc.tl.leiden, rsc.tl.louvain])
@pytest.mark.parametrize("resolution", [0.1, [0.5, 1.0]])
def test_dask_clustering_resolution(client, clustering_function, resolution):
    """A list of resolutions yields one suffixed obs column per value."""
    adata = pbmc3k_processed()
    clustering_function(
        adata, use_dask=True, key_added="test_dask", resolution=resolution
    )
    if isinstance(resolution, list):
        # Multiple resolutions: each run is stored under "<key_added>_<resolution>".
        for r in resolution:
            assert f"test_dask_{r}" in adata.obs.columns
    else:
        # A single resolution keeps the plain key.
        assert "test_dask" in adata.obs.columns
Loading