@@ -182,9 +182,16 @@ def initializer(neighbors, weights, num_umi, model, centered, Wtot2, D):
182182 g_D = D
183183
184184def compute_hs (
185- counts , neighbors , weights , num_umi , model , genes , centered = False , jobs = 1
185+ counts , neighbors , weights , num_umi , model , genes , centered = False , jobs = 1 ,
186+ use_gpu = False
186187):
187188
189+ if use_gpu :
190+ results = _compute_hs_gpu (
191+ counts , neighbors , weights , num_umi , model , genes , centered
192+ )
193+ return _postprocess_results (results )
194+
188195 neighbors = neighbors .values
189196 weights = weights .values
190197 num_umi = num_umi .values
@@ -202,16 +209,16 @@ def data_iter():
202209
203210 if jobs > 1 :
204211 with multiprocessing .Pool (
205- processes = jobs ,
206- initializer = initializer ,
212+ processes = jobs ,
213+ initializer = initializer ,
207214 initargs = [neighbors , weights , num_umi , model , centered , Wtot2 , D ]
208215 ) as pool :
209216 results = list (
210217 tqdm (
211218 pool .imap (
212219 _map_fun_parallel ,
213220 data_iter ()
214- ),
221+ ),
215222 total = counts .shape [0 ]
216223 )
217224 )
@@ -226,26 +233,23 @@ def _map_fun(vals):
226233
227234 results = pd .DataFrame (results , index = genes , columns = ["G" , "EG" , "stdG" , "Z" , "C" ])
228235
236+ return _postprocess_results (results )
237+
238+
def _postprocess_results(results):
    """Turn raw per-gene statistics into the final reported table.

    Adds one-sided normal p-values (from the Z column) and Benjamini-Hochberg
    FDR values, sorts genes by decreasing Z, and keeps only the columns the
    public API reports.

    Parameters
    ----------
    results : pandas.DataFrame
        Indexed by gene, must contain at least "Z" and "C" columns.
        NOTE: mutated in place (gains "Pval" and "FDR" columns).

    Returns
    -------
    pandas.DataFrame
        Columns ["C", "Z", "Pval", "FDR"], index named "Gene",
        sorted by Z descending.
    """
    # Upper-tail survival function: test is one-sided for positive autocorrelation.
    pvals = norm.sf(results["Z"].values)
    results["Pval"] = pvals
    results["FDR"] = multipletests(pvals, method="fdr_bh")[1]

    ordered = results.sort_values("Z", ascending=False)
    ordered.index.name = "Gene"

    # Drop intermediate columns (G, EG, stdG) from the reported table.
    return ordered[["C", "Z", "Pval", "FDR"]]
238246
239247
240- def _compute_hs_inner (vals , neighbors , weights , num_umi , model , centered , Wtot2 , D ):
241- """
242- Note, since this is an inner function, for parallelization to work well
243- none of the contents of the function can use MKL or OPENBLAS threads.
244- Or else we open too many. Because of this, some simple numpy operations
245- are re-implemented using numba instead as it's difficult to control
246- the number of threads in numpy after it's imported
247- """
248+ def _fit_gene (vals , model , num_umi ):
249+ """Fit a gene model and return (vals, mu, var, x2).
248250
251+ For the bernoulli model, vals is binarized before fitting.
252+ """
249253 if model == "bernoulli" :
250254 vals = (vals > 0 ).astype ("double" )
251255 mu , var , x2 = bernoulli_model .fit_gene_model (vals , num_umi )
@@ -256,7 +260,20 @@ def _compute_hs_inner(vals, neighbors, weights, num_umi, model, centered, Wtot2,
256260 elif model == "none" :
257261 mu , var , x2 = none_model .fit_gene_model (vals , num_umi )
258262 else :
259- raise Exception ("Invalid Model: {}" .format (model ))
263+ raise ValueError ("Invalid Model: {}" .format (model ))
264+ return vals , mu , var , x2
265+
266+
267+ def _compute_hs_inner (vals , neighbors , weights , num_umi , model , centered , Wtot2 , D ):
268+ """
269+ Note, since this is an inner function, for parallelization to work well
270+ none of the contents of the function can use MKL or OPENBLAS threads.
271+ Or else we open too many. Because of this, some simple numpy operations
272+ are re-implemented using numba instead as it's difficult to control
273+ the number of threads in numpy after it's imported
274+ """
275+
276+ vals , mu , var , x2 = _fit_gene (vals , model , num_umi )
260277
261278 if centered :
262279 vals = center_values (vals , mu , var )
@@ -289,3 +306,112 @@ def _map_fun_parallel(vals):
289306 return _compute_hs_inner (
290307 vals , g_neighbors , g_weights , g_num_umi , g_model , g_centered , g_Wtot2 , g_D
291308 )
309+
310+
311+ def _local_cov_weights_gpu (vals_gpu , W ):
312+ """GPU batch of local_cov_weights: G[g] = vals[g] . (W @ vals[g]) for all genes."""
313+ smoothed_T = W @ vals_gpu .T
314+ return (vals_gpu * smoothed_T .T ).sum (axis = 1 )
315+
316+
317+ def _compute_moments_weights_gpu (mu_gpu , x2_gpu , W , W_sq ):
318+ """GPU batch of compute_moments_weights for all genes at once."""
319+ # EG[g] = mu[g] . (W @ mu[g])
320+ EG = (mu_gpu * (W @ mu_gpu .T ).T ).sum (axis = 1 )
321+
322+ # t1[g] = (W + W.T) @ mu[g], t2[g] = (W_sq + W_sq.T) @ mu[g]^2
323+ W_sym = W + W .T
324+ W_sq_sym = W_sq + W_sq .T
325+ mu2_gpu = mu_gpu ** 2
326+
327+ t1_T = W_sym @ mu_gpu .T
328+ t2_T = W_sq_sym @ mu2_gpu .T
329+
330+ # Contribution 1: sum_i (x2[i] - mu[i]^2) * (t1[i]^2 - t2[i])
331+ diff_var = (x2_gpu - mu2_gpu ).T
332+ eg2_c1 = (diff_var * (t1_T ** 2 - t2_T )).sum (axis = 0 )
333+
334+ # Contribution 2: sum_{edges} w^2 * (x2[i]*x2[j] - mu[i]^2*mu[j]^2)
335+ eg2_c2 = (x2_gpu .T * (W_sq @ x2_gpu .T )).sum (axis = 0 )
336+ eg2_c2 -= (mu2_gpu .T * (W_sq @ mu2_gpu .T )).sum (axis = 0 )
337+
338+ EG2 = eg2_c1 + eg2_c2 + EG ** 2
339+ return EG , EG2
340+
341+
342+ def _compute_local_cov_max_gpu (D_gpu , vals_gpu ):
343+ """GPU batch of compute_local_cov_max: G_max[g] = sum_i D[i]*vals[g,i]^2 / 2."""
344+ return (D_gpu * vals_gpu ** 2 ).sum (axis = 1 ) / 2
345+
346+
def _compute_hs_gpu(counts, neighbors, weights, num_umi, model, genes, centered):
    """
    GPU-accelerated version of _compute_hs_inner, batched over all genes.

    Gene models are fit one at a time on the CPU; the autocorrelation
    statistics are then evaluated for every gene simultaneously via sparse
    matrix products on the GPU.

    Returns a DataFrame indexed by ``genes`` with columns G/EG/stdG/Z/C
    (post-processing to Pval/FDR happens in _postprocess_results).
    """
    import cupy as cp
    from .gpu import _require_gpu, _build_sparse_weight_matrix

    _require_gpu()

    nbrs = neighbors.values
    wts = weights.values
    umi = num_umi.values

    n_genes, n_cells = counts.shape

    node_degree = compute_node_degree(nbrs, wts)
    wtot2 = (wts ** 2).sum()

    dense = counts.toarray() if issparse(counts) else np.asarray(counts)

    expr = np.zeros((n_genes, n_cells), dtype="double")
    if not centered:
        # Moments are only needed on-GPU for the uncentered branch below.
        mu_all = np.zeros((n_genes, n_cells), dtype="double")
        x2_all = np.zeros((n_genes, n_cells), dtype="double")

    # Model fitting stays on the CPU; only the heavy linear algebra moves over.
    for g in range(n_genes):
        vals, mu, var, x2 = _fit_gene(dense[g].astype("double"), model, umi)
        if centered:
            vals = center_values(vals, mu, var)
        else:
            mu_all[g] = mu
            x2_all[g] = x2
        expr[g] = vals

    expr_gpu = cp.asarray(expr)
    degree_gpu = cp.asarray(node_degree)
    W = _build_sparse_weight_matrix(nbrs, wts, shape=(n_cells, n_cells))

    G_stats = _local_cov_weights_gpu(expr_gpu, W)

    if centered:
        # Centered branch: moments are constant across genes (EG = 0,
        # EG2 = sum of squared weights), so no per-gene computation needed.
        EG = cp.zeros(n_genes, dtype="double")
        EG2 = cp.full(n_genes, wtot2, dtype="double")
    else:
        W_sq = _build_sparse_weight_matrix(
            nbrs, wts, shape=(n_cells, n_cells), square=True
        )
        EG, EG2 = _compute_moments_weights_gpu(
            cp.asarray(mu_all), cp.asarray(x2_all), W, W_sq
        )

    stdG = (EG2 - EG * EG) ** 0.5
    Z = (G_stats - EG) / stdG
    C = (G_stats - EG) / _compute_local_cov_max_gpu(degree_gpu, expr_gpu)

    return pd.DataFrame(
        {
            "G": cp.asnumpy(G_stats),
            "EG": cp.asnumpy(EG),
            "stdG": cp.asnumpy(stdG),
            "Z": cp.asnumpy(Z),
            "C": cp.asnumpy(C),
        },
        index=genes,
    )
0 commit comments