Skip to content

Commit c817e05

Browse files
feat: seurat v3 with dask csr (#3340)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 1472a5f commit c817e05

7 files changed

Lines changed: 180 additions & 87 deletions

File tree

docs/release-notes/3340.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{func}`scanpy.pp.highly_variable_genes` flavors `seurat_v3` and `seurat_v3_paper` are now `dask`-compatible {smaller}`I Gold`

src/scanpy/_utils/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,15 @@ def dematrix[SA: _SupportedArray](x: SA | np.matrix) -> SA:
768768
return x
769769

770770

771+
def raise_if_dask_feature_axis_chunked(x: Any):
772+
if isinstance(x, DaskArray) and x.chunksize[1] != x.shape[1]:
773+
msg = (
774+
"Only dask arrays with chunking along the first axis are supported. "
775+
f"Got chunksize {x.chunksize} with shape {x.shape}. "
776+
)
777+
raise ValueError(msg)
778+
779+
771780
def select_groups(
772781
adata: AnnData,
773782
groups_order_subset: Iterable[str] | Literal["all"] = "all",

src/scanpy/experimental/pp/_normalization.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88

99
from ... import logging as logg
1010
from ..._compat import CSBase, warn
11-
from ..._utils import (
12-
_doc_params,
13-
_empty,
14-
check_nonnegative_integers,
15-
view_to_actual,
16-
)
11+
from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual
1712
from ...experimental._docs import (
1813
doc_adata,
1914
doc_check_values,

src/scanpy/preprocessing/_highly_variable_genes.py

Lines changed: 119 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
from dataclasses import dataclass
5+
from functools import singledispatch
56
from inspect import signature
67
from typing import TYPE_CHECKING, TypedDict, cast
78

@@ -12,9 +13,13 @@
1213
from fast_array_utils import stats
1314

1415
from .. import logging as logg
15-
from .._compat import CSBase, DaskArray, old_positionals, warn
16+
from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn
1617
from .._settings import Verbosity, settings
17-
from .._utils import check_nonnegative_integers, sanitize_anndata
18+
from .._utils import (
19+
check_nonnegative_integers,
20+
raise_if_dask_feature_axis_chunked,
21+
sanitize_anndata,
22+
)
1823
from ..get import _get_obs_rep
1924
from ._distributed import materialize_as_ndarray
2025
from ._simple import filter_genes
@@ -28,6 +33,91 @@
2833
from .._types import HVGFlavor
2934

3035

36+
@singledispatch
37+
def clip_square_sum(
38+
data_batch: np.ndarray, clip_val: np.ndarray
39+
) -> tuple[np.ndarray, np.ndarray]:
40+
"""Clip data_batch by clip_val.
41+
42+
Parameters
43+
----------
44+
data_batch
45+
The data to be clipped
46+
clip_val
47+
Clip by these values (must be broadcastable to the input data)
48+
49+
Returns
50+
-------
51+
The clipeed data
52+
"""
53+
batch_counts = data_batch.astype(np.float64).copy()
54+
clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
55+
np.putmask(
56+
batch_counts,
57+
batch_counts > clip_val_broad,
58+
clip_val_broad,
59+
)
60+
61+
squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
62+
batch_counts_sum = batch_counts.sum(axis=0)
63+
return squared_batch_counts_sum, batch_counts_sum
64+
65+
66+
@clip_square_sum.register(DaskArray)
67+
def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
68+
n_blocks = data_batch.blocks.size
69+
70+
def sum_and_sum_squares_clipped_from_block(block):
71+
return np.vstack(clip_square_sum(block, clip_val))[None, ...]
72+
73+
squared_batch_counts_sum, batch_counts_sum = (
74+
data_batch.map_blocks(
75+
sum_and_sum_squares_clipped_from_block,
76+
new_axis=(1,),
77+
chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
78+
meta=np.array([]),
79+
dtype=np.float64,
80+
)
81+
.sum(axis=0)
82+
.compute()
83+
)
84+
return squared_batch_counts_sum, batch_counts_sum
85+
86+
87+
@clip_square_sum.register(CSBase)
88+
def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
89+
batch_counts = data_batch if isinstance(data_batch, CSRBase) else data_batch.tocsr()
90+
91+
return _sum_and_sum_squares_clipped(
92+
batch_counts.indices,
93+
batch_counts.data,
94+
n_cols=batch_counts.shape[1],
95+
clip_val=clip_val,
96+
nnz=batch_counts.nnz,
97+
)
98+
99+
100+
# parallel=False needed for accuracy
101+
@numba.njit(cache=True, parallel=False) # noqa: TID251
102+
def _sum_and_sum_squares_clipped(
103+
indices: NDArray[np.integer],
104+
data: NDArray[np.floating],
105+
*,
106+
n_cols: int,
107+
clip_val: NDArray[np.float64],
108+
nnz: int,
109+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
110+
squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
111+
batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
112+
for i in numba.prange(nnz):
113+
idx = indices[i]
114+
element = min(np.float64(data[i]), clip_val[idx])
115+
squared_batch_counts_sum[idx] += element**2
116+
batch_counts_sum[idx] += element
117+
118+
return squared_batch_counts_sum, batch_counts_sum
119+
120+
31121
def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
32122
adata: AnnData,
33123
*,
@@ -70,23 +160,28 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
70160
raise ImportError(msg) from e
71161
df = pd.DataFrame(index=adata.var_names)
72162
data = _get_obs_rep(adata, layer=layer)
163+
raise_if_dask_feature_axis_chunked(data)
73164

74165
if check_values and not check_nonnegative_integers(data):
75166
msg = f"`{flavor=!r}` expects raw count data, but non-integers were found."
76167
warn(msg, UserWarning)
77168

78169
df["means"], df["variances"] = stats.mean_var(data, axis=0, correction=1)
79170

80-
if batch_key is None:
81-
batch_info = pd.Categorical(np.zeros(adata.shape[0], dtype=int))
82-
else:
83-
batch_info = adata.obs[batch_key].to_numpy()
171+
batch_info = (
172+
pd.Categorical(np.zeros(adata.shape[0], dtype=int))
173+
if batch_key is None
174+
else adata.obs[batch_key].to_numpy()
175+
)
84176

85177
norm_gene_vars = []
86178
for b in np.unique(batch_info):
87179
data_batch = data[batch_info == b]
88180

89181
mean, var = stats.mean_var(data_batch, axis=0, correction=1)
182+
# These get computed anyway for loess
183+
if isinstance(mean, DaskArray):
184+
mean, var = mean.compute(), var.compute()
90185
not_const = var > 0
91186
estimat_var = np.zeros(data.shape[1], dtype=np.float64)
92187

@@ -99,28 +194,10 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
99194

100195
# clip large values as in Seurat
101196
n_obs = data_batch.shape[0]
102-
vmax = np.sqrt(n_obs)
103-
clip_val = reg_std * vmax + mean
104-
if isinstance(data_batch, CSBase):
105-
batch_counts = data_batch.tocsr()
106-
squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped(
107-
batch_counts.indices,
108-
batch_counts.data,
109-
n_cols=batch_counts.shape[1],
110-
clip_val=clip_val,
111-
nnz=batch_counts.nnz,
112-
)
113-
else:
114-
batch_counts = data_batch.astype(np.float64).copy()
115-
clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
116-
np.putmask(
117-
batch_counts,
118-
batch_counts > clip_val_broad,
119-
clip_val_broad,
120-
)
121-
122-
squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
123-
batch_counts_sum = batch_counts.sum(axis=0)
197+
clip_val = reg_std * np.sqrt(n_obs) + mean
198+
squared_batch_counts_sum, batch_counts_sum = clip_square_sum(
199+
data_batch, clip_val
200+
)
124201

125202
norm_gene_var = (1 / ((n_obs - 1) * np.square(reg_std))) * (
126203
(n_obs * np.square(mean))
@@ -142,10 +219,12 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
142219
ma_ranked = np.ma.masked_invalid(ranked_norm_gene_vars)
143220
median_ranked = np.ma.median(ma_ranked, axis=0).filled(np.nan)
144221

145-
df["gene_name"] = df.index
146-
df["highly_variable_nbatches"] = num_batches_high_var
147-
df["highly_variable_rank"] = median_ranked
148-
df["variances_norm"] = np.mean(norm_gene_vars, axis=0)
222+
df = df.assign(
223+
gene_name=df.index,
224+
highly_variable_nbatches=num_batches_high_var,
225+
highly_variable_rank=median_ranked,
226+
variances_norm=np.mean(norm_gene_vars, axis=0),
227+
)
149228
if flavor == "seurat_v3":
150229
sort_cols = ["highly_variable_rank", "highly_variable_nbatches"]
151230
sort_ascending = [True, False]
@@ -173,10 +252,13 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
173252
" 'variances', float vector (adata.var)\n"
174253
" 'variances_norm', float vector (adata.var)"
175254
)
176-
adata.var["highly_variable"] = df["highly_variable"].to_numpy()
177-
adata.var["highly_variable_rank"] = df["highly_variable_rank"].to_numpy()
178-
adata.var["means"] = df["means"].to_numpy()
179-
adata.var["variances"] = df["variances"].to_numpy()
255+
for to_numpy_key in [
256+
"highly_variable",
257+
"highly_variable_rank",
258+
"means",
259+
"variances",
260+
]:
261+
adata.var[to_numpy_key] = df[to_numpy_key].to_numpy()
180262
adata.var["variances_norm"] = (
181263
df["variances_norm"].to_numpy().astype("float64", copy=False)
182264
)
@@ -193,27 +275,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
193275
df = df.iloc[df["highly_variable"].to_numpy(), :]
194276

195277
return df
196-
197-
198-
# parallel=False needed for accuracy
199-
@numba.njit(cache=True, parallel=False) # noqa: TID251
200-
def _sum_and_sum_squares_clipped(
201-
indices: NDArray[np.integer],
202-
data: NDArray[np.floating],
203-
*,
204-
n_cols: int,
205-
clip_val: NDArray[np.float64],
206-
nnz: int,
207-
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
208-
squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
209-
batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
210-
for i in numba.prange(nnz):
211-
idx = indices[i]
212-
element = min(np.float64(data[i]), clip_val[idx])
213-
squared_batch_counts_sum[idx] += element**2
214-
batch_counts_sum[idx] += element
215-
216-
return squared_batch_counts_sum, batch_counts_sum
278+
return None
217279

218280

219281
@dataclass

src/scanpy/preprocessing/_pca/_dask.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import scipy.linalg
88
from fast_array_utils import stats
99

10+
from scanpy._utils import raise_if_dask_feature_axis_chunked
1011
from scanpy._utils._doctests import doctest_needs
1112

1213
from ..._compat import CSBase
@@ -52,13 +53,7 @@ def fit(self, x: DaskArray) -> PCAEighDaskFit:
5253
f"Got {x._meta.format} as meta."
5354
)
5455
raise ValueError(msg)
55-
if x.chunksize[1] != x.shape[1]:
56-
msg = (
57-
"Only dask arrays with chunking along the first axis are supported. "
58-
f"Got chunksize {x.chunksize} with shape {x.shape}. "
59-
"Rechunking should be simple and cost nothing from AnnData's on-disk format when the on-disk layout has this chunking."
60-
)
61-
raise ValueError(msg)
56+
raise_if_dask_feature_axis_chunked(x)
6257
self.__class__ = PCAEighDaskFit
6358
self = cast("PCAEighDaskFit", self) # noqa: PLW0642
6459

src/testing/scanpy/_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def as_dense_dask_array(*args, **kwargs) -> DaskArray:
129129
from anndata.tests.helpers import as_dense_dask_array
130130

131131
a = as_dense_dask_array(*args, **kwargs)
132+
# Newer versions of as_dense_dask_array chunk all axes by halve when the input is not a dask array.
132133
if (
133134
pkg_version("anndata") < Version("0.11")
134-
and a.chunksize == a.shape
135135
and not isinstance(args[0], DaskArray) # keep chunksize intact
136136
):
137137
from anndata.tests.helpers import _half_chunk_size

0 commit comments

Comments
 (0)