22
33import warnings
44from dataclasses import dataclass
5+ from functools import singledispatch
56from inspect import signature
67from typing import TYPE_CHECKING , TypedDict , cast
78
1213from fast_array_utils import stats
1314
1415from .. import logging as logg
15- from .._compat import CSBase , DaskArray , old_positionals , warn
16+ from .._compat import CSBase , CSRBase , DaskArray , old_positionals , warn
1617from .._settings import Verbosity , settings
17- from .._utils import check_nonnegative_integers , sanitize_anndata
18+ from .._utils import (
19+ check_nonnegative_integers ,
20+ raise_if_dask_feature_axis_chunked ,
21+ sanitize_anndata ,
22+ )
1823from ..get import _get_obs_rep
1924from ._distributed import materialize_as_ndarray
2025from ._simple import filter_genes
2833 from .._types import HVGFlavor
2934
3035
@singledispatch
def clip_square_sum(
    data_batch: np.ndarray, clip_val: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Clip `data_batch` at `clip_val` and reduce the result along axis 0.

    Every entry of `data_batch` that is greater than the matching entry of
    `clip_val` is replaced by that entry.  The clipped values are then summed
    along axis 0, once squared and once as they are.  The input array itself
    is not modified.

    Parameters
    ----------
    data_batch
        The data to be clipped.
    clip_val
        Upper bounds to clip by (must be broadcastable to the input data).

    Returns
    -------
    A tuple ``(squared_sum, plain_sum)``: the sum over axis 0 of the squared
    clipped values, and the sum over axis 0 of the clipped values.  The
    clipping and both reductions are carried out on a float64 copy.
    """
    # `astype` already hands back a newly allocated array by default, so the
    # in-place clipping below never touches the caller's data and no extra
    # `.copy()` is required.
    batch_counts = data_batch.astype(np.float64)
    clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
    # Only entries strictly above their bound are overwritten; comparisons
    # involving NaN are False, so such entries are left as they are.
    np.putmask(
        batch_counts,
        batch_counts > clip_val_broad,
        clip_val_broad,
    )

    squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
    batch_counts_sum = batch_counts.sum(axis=0)
    return squared_batch_counts_sum, batch_counts_sum
64+
65+
@clip_square_sum.register(DaskArray)
def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Dask variant of `clip_square_sum`: reduce each block, then add the partial results.

    `clip_square_sum` is applied to every block of `data_batch`, the per-block
    results are summed over the block axis and the total is computed eagerly.

    Parameters
    ----------
    data_batch
        The dask array to be clipped.
    clip_val
        Upper bounds to clip by, passed unchanged to the per-block call.

    Returns
    -------
    A tuple ``(squared_sum, plain_sum)`` of computed (non-lazy) results.
    """

    def _partial_sums(chunk):
        # Stack (sum of squares, sum) and prepend a length-1 axis so that the
        # per-block results can be summed along axis 0 afterwards.
        return np.vstack(clip_square_sum(chunk, clip_val))[None, ...]

    block_count = data_batch.blocks.size
    n_features = data_batch.shape[1]
    # The declared output chunking has one chunk spanning all columns.
    # NOTE(review): this presumably relies on the feature axis not being
    # chunked in `data_batch` - confirm against the caller.
    per_block = data_batch.map_blocks(
        _partial_sums,
        new_axis=(1,),
        chunks=((1,) * block_count, (2,), (n_features,)),
        meta=np.array([]),
        dtype=np.float64,
    )
    sum_of_squares, plain_sum = per_block.sum(axis=0).compute()
    return sum_of_squares, plain_sum
85+
86+
@clip_square_sum.register(CSBase)
def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Sparse variant of `clip_square_sum`.

    The input is used as is when it is already a `CSRBase` instance and is
    converted with `.tocsr()` otherwise.  Its `indices` and `data` arrays are
    then handed to `_sum_and_sum_squares_clipped`, so only explicitly stored
    entries take part in the sums.

    Parameters
    ----------
    data_batch
        The sparse data to be clipped.
    clip_val
        Upper bounds to clip by, indexed by column.

    Returns
    -------
    A tuple ``(squared_sum, plain_sum)`` as returned by
    `_sum_and_sum_squares_clipped`.
    """
    if isinstance(data_batch, CSRBase):
        # already in CSR layout, no conversion needed
        csr = data_batch
    else:
        csr = data_batch.tocsr()

    return _sum_and_sum_squares_clipped(
        csr.indices,
        csr.data,
        n_cols=csr.shape[1],
        clip_val=clip_val,
        nnz=csr.nnz,
    )
98+
99+
# parallel=False needed for accuracy
@numba.njit(cache=True, parallel=False)  # noqa: TID251
def _sum_and_sum_squares_clipped(
    indices: NDArray[np.integer],
    data: NDArray[np.floating],
    *,
    n_cols: int,
    clip_val: NDArray[np.float64],
    nnz: int,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Accumulate per-column sums of clipped stored values.

    Walks over the first `nnz` entries of `indices`/`data`.  Each value is
    converted to float64 and capped at ``clip_val[column]``, where the column
    is taken from `indices`; the capped value and its square are added to
    that column's accumulators.

    Parameters
    ----------
    indices
        Column index of every stored value.
    data
        The stored values.
    n_cols
        Length of the two output vectors.
    clip_val
        Upper bound per column.
    nnz
        Number of stored values to process.

    Returns
    -------
    A tuple ``(squared_sum, plain_sum)`` of float64 vectors of length `n_cols`.
    """
    plain_acc = np.zeros(n_cols, dtype=np.float64)
    squared_acc = np.zeros(n_cols, dtype=np.float64)
    for pos in numba.prange(nnz):
        col = indices[pos]
        capped = min(np.float64(data[pos]), clip_val[col])
        plain_acc[col] += capped
        squared_acc[col] += capped ** 2

    return squared_acc, plain_acc
119+
120+
31121def _highly_variable_genes_seurat_v3 ( # noqa: PLR0912, PLR0915
32122 adata : AnnData ,
33123 * ,
@@ -70,23 +160,28 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
70160 raise ImportError (msg ) from e
71161 df = pd .DataFrame (index = adata .var_names )
72162 data = _get_obs_rep (adata , layer = layer )
163+ raise_if_dask_feature_axis_chunked (data )
73164
74165 if check_values and not check_nonnegative_integers (data ):
75166 msg = f"`{ flavor = !r} ` expects raw count data, but non-integers were found."
76167 warn (msg , UserWarning )
77168
78169 df ["means" ], df ["variances" ] = stats .mean_var (data , axis = 0 , correction = 1 )
79170
80- if batch_key is None :
81- batch_info = pd .Categorical (np .zeros (adata .shape [0 ], dtype = int ))
82- else :
83- batch_info = adata .obs [batch_key ].to_numpy ()
171+ batch_info = (
172+ pd .Categorical (np .zeros (adata .shape [0 ], dtype = int ))
173+ if batch_key is None
174+ else adata .obs [batch_key ].to_numpy ()
175+ )
84176
85177 norm_gene_vars = []
86178 for b in np .unique (batch_info ):
87179 data_batch = data [batch_info == b ]
88180
89181 mean , var = stats .mean_var (data_batch , axis = 0 , correction = 1 )
182+ # These get computed anyway for loess
183+ if isinstance (mean , DaskArray ):
184+ mean , var = mean .compute (), var .compute ()
90185 not_const = var > 0
91186 estimat_var = np .zeros (data .shape [1 ], dtype = np .float64 )
92187
@@ -99,28 +194,10 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
99194
100195 # clip large values as in Seurat
101196 n_obs = data_batch .shape [0 ]
102- vmax = np .sqrt (n_obs )
103- clip_val = reg_std * vmax + mean
104- if isinstance (data_batch , CSBase ):
105- batch_counts = data_batch .tocsr ()
106- squared_batch_counts_sum , batch_counts_sum = _sum_and_sum_squares_clipped (
107- batch_counts .indices ,
108- batch_counts .data ,
109- n_cols = batch_counts .shape [1 ],
110- clip_val = clip_val ,
111- nnz = batch_counts .nnz ,
112- )
113- else :
114- batch_counts = data_batch .astype (np .float64 ).copy ()
115- clip_val_broad = np .broadcast_to (clip_val , batch_counts .shape )
116- np .putmask (
117- batch_counts ,
118- batch_counts > clip_val_broad ,
119- clip_val_broad ,
120- )
121-
122- squared_batch_counts_sum = np .square (batch_counts ).sum (axis = 0 )
123- batch_counts_sum = batch_counts .sum (axis = 0 )
197+ clip_val = reg_std * np .sqrt (n_obs ) + mean
198+ squared_batch_counts_sum , batch_counts_sum = clip_square_sum (
199+ data_batch , clip_val
200+ )
124201
125202 norm_gene_var = (1 / ((n_obs - 1 ) * np .square (reg_std ))) * (
126203 (n_obs * np .square (mean ))
@@ -142,10 +219,12 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
142219 ma_ranked = np .ma .masked_invalid (ranked_norm_gene_vars )
143220 median_ranked = np .ma .median (ma_ranked , axis = 0 ).filled (np .nan )
144221
145- df ["gene_name" ] = df .index
146- df ["highly_variable_nbatches" ] = num_batches_high_var
147- df ["highly_variable_rank" ] = median_ranked
148- df ["variances_norm" ] = np .mean (norm_gene_vars , axis = 0 )
222+ df = df .assign (
223+ gene_name = df .index ,
224+ highly_variable_nbatches = num_batches_high_var ,
225+ highly_variable_rank = median_ranked ,
226+ variances_norm = np .mean (norm_gene_vars , axis = 0 ),
227+ )
149228 if flavor == "seurat_v3" :
150229 sort_cols = ["highly_variable_rank" , "highly_variable_nbatches" ]
151230 sort_ascending = [True , False ]
@@ -173,10 +252,13 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
173252 " 'variances', float vector (adata.var)\n "
174253 " 'variances_norm', float vector (adata.var)"
175254 )
176- adata .var ["highly_variable" ] = df ["highly_variable" ].to_numpy ()
177- adata .var ["highly_variable_rank" ] = df ["highly_variable_rank" ].to_numpy ()
178- adata .var ["means" ] = df ["means" ].to_numpy ()
179- adata .var ["variances" ] = df ["variances" ].to_numpy ()
255+ for to_numpy_key in [
256+ "highly_variable" ,
257+ "highly_variable_rank" ,
258+ "means" ,
259+ "variances" ,
260+ ]:
261+ adata .var [to_numpy_key ] = df [to_numpy_key ].to_numpy ()
180262 adata .var ["variances_norm" ] = (
181263 df ["variances_norm" ].to_numpy ().astype ("float64" , copy = False )
182264 )
@@ -193,27 +275,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
193275 df = df .iloc [df ["highly_variable" ].to_numpy (), :]
194276
195277 return df
196-
197-
198- # parallel=False needed for accuracy
199- @numba .njit (cache = True , parallel = False ) # noqa: TID251
200- def _sum_and_sum_squares_clipped (
201- indices : NDArray [np .integer ],
202- data : NDArray [np .floating ],
203- * ,
204- n_cols : int ,
205- clip_val : NDArray [np .float64 ],
206- nnz : int ,
207- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
208- squared_batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
209- batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
210- for i in numba .prange (nnz ):
211- idx = indices [i ]
212- element = min (np .float64 (data [i ]), clip_val [idx ])
213- squared_batch_counts_sum [idx ] += element ** 2
214- batch_counts_sum [idx ] += element
215-
216- return squared_batch_counts_sum , batch_counts_sum
278+ return None
217279
218280
219281@dataclass
0 commit comments