Skip to content

Commit f985599

Browse files
authored
Merge pull request #56 from justjhong/jhong/hotspotgpu
Add GPU support for compute_autocorrelations and compute_local_correlations
2 parents b810831 + 5595196 commit f985599

File tree

8 files changed

+822
-57
lines changed

8 files changed

+822
-57
lines changed

hotspot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name = "hotspot"
22

33
from .hotspot import Hotspot
4+
from .gpu import is_gpu_available
45

56
# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
67
# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302

hotspot/gpu.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
GPU utilities for Hotspot. Provides CuPy availability checks and
3+
sparse matrix construction helpers used by the GPU paths in
4+
local_stats.py and local_stats_pairs.py.
5+
"""
6+
7+
import numpy as np
8+
9+
try:
10+
import cupy as cp
11+
import cupyx.scipy.sparse as cp_sparse
12+
13+
HAS_CUPY = True
14+
except ImportError:
15+
HAS_CUPY = False
16+
17+
18+
def is_gpu_available():
    """Return True when CuPy is importable and a CUDA device responds.

    Probes device 0's compute capability; any failure (missing driver,
    no device, CUDA runtime error) is treated as "GPU unavailable".
    """
    if not HAS_CUPY:
        return False
    try:
        cp.cuda.Device(0).compute_capability
    except Exception:
        return False
    return True
27+
28+
29+
def _require_gpu():
    """Fail loudly when the GPU stack cannot be used.

    Raises
    ------
    ImportError
        When CuPy is not installed.
    RuntimeError
        When CuPy imports but no CUDA-capable device can be queried.
    """
    if not HAS_CUPY:
        msg = (
            "CuPy is required for GPU acceleration. "
            "Install with: pip install hotspotsc[gpu] "
            "(or install cupy separately for your CUDA version, "
            "e.g. pip install cupy-cuda12x)"
        )
        raise ImportError(msg)
    try:
        # Touching a device attribute forces CUDA initialization.
        cp.cuda.Device(0).compute_capability
    except Exception as err:
        msg = (
            "No CUDA-capable GPU device found. "
            "GPU acceleration requires an NVIDIA GPU with CUDA support."
        )
        raise RuntimeError(msg) from err
45+
46+
47+
def _build_sparse_weight_matrix(neighbors, weights, shape, square=False):
    """Assemble a CuPy CSR matrix W with W[i, neighbors[i, k]] = weights[i, k].

    Zero-weight entries are dropped before upload. With ``square=True``
    the squared weights are stored instead (used for second-moment
    computations).
    """
    n_rows, n_neighbors = neighbors.shape
    row_idx = np.repeat(np.arange(n_rows, dtype=np.int32), n_neighbors)
    col_idx = neighbors.ravel().astype(np.int32)
    data = weights.ravel().astype(np.float64)
    if square:
        data = data ** 2

    # Keep only real edges; padding/zero weights carry no information.
    keep = data != 0
    coo = (
        cp.asarray(data[keep]),
        (cp.asarray(row_idx[keep]), cp.asarray(col_idx[keep])),
    )
    return cp_sparse.csr_matrix(coo, shape=shape)

hotspot/hotspot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from . import modules
1818
from .plots import local_correlation_plot
19+
from .gpu import is_gpu_available
1920
from tqdm import tqdm
2021

2122

@@ -29,6 +30,7 @@ def __init__(
2930
distances_obsp_key=None,
3031
tree=None,
3132
umi_counts_obs_key=None,
33+
use_gpu=False,
3234
):
3335
"""Initialize a Hotspot object for analysis
3436
@@ -60,6 +62,10 @@ def __init__(
6062
umi_counts_obs_key : str
6163
Total umi count per cell. Used as a size factor.
6264
If omitted, the sum over genes in the counts matrix is used
65+
use_gpu : bool, optional
66+
Whether to use GPU acceleration via CuPy for embarrassingly
67+
parallel operations. Requires CuPy and a CUDA-capable GPU.
68+
Default is False.
6369
"""
6470
counts = self._counts_from_anndata(adata, layer_key)
6571
distances = (
@@ -166,6 +172,14 @@ def __init__(
166172
self.linkage = None
167173
self.module_scores = None
168174

175+
if use_gpu:
176+
if not is_gpu_available():
177+
raise RuntimeError(
178+
"use_gpu=True but GPU is not available. "
179+
"Ensure CuPy is installed and a CUDA GPU is present."
180+
)
181+
self.use_gpu = use_gpu
182+
169183
@classmethod
170184
def legacy_init(
171185
cls,
@@ -416,6 +430,7 @@ def _compute_hotspot(self, jobs=1):
416430
genes=self.adata.var_names,
417431
centered=True,
418432
jobs=jobs,
433+
use_gpu=self.use_gpu,
419434
)
420435

421436
self.results = results
@@ -486,6 +501,7 @@ def compute_local_correlations(self, genes, jobs=1):
486501
self.umi_counts,
487502
self.model,
488503
jobs=jobs,
504+
use_gpu=self.use_gpu,
489505
)
490506

491507
self.local_correlation_c = lc

hotspot/local_stats.py

Lines changed: 143 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,16 @@ def initializer(neighbors, weights, num_umi, model, centered, Wtot2, D):
182182
g_D = D
183183

184184
def compute_hs(
185-
counts, neighbors, weights, num_umi, model, genes, centered=False, jobs=1
185+
counts, neighbors, weights, num_umi, model, genes, centered=False, jobs=1,
186+
use_gpu=False
186187
):
187188

189+
if use_gpu:
190+
results = _compute_hs_gpu(
191+
counts, neighbors, weights, num_umi, model, genes, centered
192+
)
193+
return _postprocess_results(results)
194+
188195
neighbors = neighbors.values
189196
weights = weights.values
190197
num_umi = num_umi.values
@@ -202,16 +209,16 @@ def data_iter():
202209

203210
if jobs > 1:
204211
with multiprocessing.Pool(
205-
processes=jobs,
206-
initializer=initializer,
212+
processes=jobs,
213+
initializer=initializer,
207214
initargs=[neighbors, weights, num_umi, model, centered, Wtot2, D]
208215
) as pool:
209216
results = list(
210217
tqdm(
211218
pool.imap(
212219
_map_fun_parallel,
213220
data_iter()
214-
),
221+
),
215222
total=counts.shape[0]
216223
)
217224
)
@@ -226,26 +233,23 @@ def _map_fun(vals):
226233

227234
results = pd.DataFrame(results, index=genes, columns=["G", "EG", "stdG", "Z", "C"])
228235

236+
return _postprocess_results(results)
237+
238+
239+
def _postprocess_results(results):
    """Attach p-values/FDR to raw Hotspot stats and tidy the frame.

    Expects columns G, EG, stdG, Z, C; returns a frame sorted by Z
    (descending) with only C, Z, Pval, FDR retained and the index
    named "Gene".
    """
    pvals = norm.sf(results["Z"].values)
    results["Pval"] = pvals
    results["FDR"] = multipletests(pvals, method="fdr_bh")[1]
    results = results.sort_values("Z", ascending=False)
    results.index.name = "Gene"
    # Drop the intermediate moment columns; callers only need these four.
    return results[["C", "Z", "Pval", "FDR"]]
238246

239247

240-
def _compute_hs_inner(vals, neighbors, weights, num_umi, model, centered, Wtot2, D):
241-
"""
242-
Note, since this is an inner function, for parallelization to work well
243-
none of the contents of the function can use MKL or OPENBLAS threads.
244-
Or else we open too many. Because of this, some simple numpy operations
245-
are re-implemented using numba instead as it's difficult to control
246-
the number of threads in numpy after it's imported
247-
"""
248+
def _fit_gene(vals, model, num_umi):
249+
"""Fit a gene model and return (vals, mu, var, x2).
248250
251+
For the bernoulli model, vals is binarized before fitting.
252+
"""
249253
if model == "bernoulli":
250254
vals = (vals > 0).astype("double")
251255
mu, var, x2 = bernoulli_model.fit_gene_model(vals, num_umi)
@@ -256,7 +260,20 @@ def _compute_hs_inner(vals, neighbors, weights, num_umi, model, centered, Wtot2,
256260
elif model == "none":
257261
mu, var, x2 = none_model.fit_gene_model(vals, num_umi)
258262
else:
259-
raise Exception("Invalid Model: {}".format(model))
263+
raise ValueError("Invalid Model: {}".format(model))
264+
return vals, mu, var, x2
265+
266+
267+
def _compute_hs_inner(vals, neighbors, weights, num_umi, model, centered, Wtot2, D):
268+
"""
269+
Note, since this is an inner function, for parallelization to work well
270+
none of the contents of the function can use MKL or OPENBLAS threads.
271+
Or else we open too many. Because of this, some simple numpy operations
272+
are re-implemented using numba instead as it's difficult to control
273+
the number of threads in numpy after it's imported
274+
"""
275+
276+
vals, mu, var, x2 = _fit_gene(vals, model, num_umi)
260277

261278
if centered:
262279
vals = center_values(vals, mu, var)
@@ -289,3 +306,112 @@ def _map_fun_parallel(vals):
289306
return _compute_hs_inner(
290307
vals, g_neighbors, g_weights, g_num_umi, g_model, g_centered, g_Wtot2, g_D
291308
)
309+
310+
311+
def _local_cov_weights_gpu(vals_gpu, W):
312+
"""GPU batch of local_cov_weights: G[g] = vals[g] . (W @ vals[g]) for all genes."""
313+
smoothed_T = W @ vals_gpu.T
314+
return (vals_gpu * smoothed_T.T).sum(axis=1)
315+
316+
317+
def _compute_moments_weights_gpu(mu_gpu, x2_gpu, W, W_sq):
318+
"""GPU batch of compute_moments_weights for all genes at once."""
319+
# EG[g] = mu[g] . (W @ mu[g])
320+
EG = (mu_gpu * (W @ mu_gpu.T).T).sum(axis=1)
321+
322+
# t1[g] = (W + W.T) @ mu[g], t2[g] = (W_sq + W_sq.T) @ mu[g]^2
323+
W_sym = W + W.T
324+
W_sq_sym = W_sq + W_sq.T
325+
mu2_gpu = mu_gpu ** 2
326+
327+
t1_T = W_sym @ mu_gpu.T
328+
t2_T = W_sq_sym @ mu2_gpu.T
329+
330+
# Contribution 1: sum_i (x2[i] - mu[i]^2) * (t1[i]^2 - t2[i])
331+
diff_var = (x2_gpu - mu2_gpu).T
332+
eg2_c1 = (diff_var * (t1_T ** 2 - t2_T)).sum(axis=0)
333+
334+
# Contribution 2: sum_{edges} w^2 * (x2[i]*x2[j] - mu[i]^2*mu[j]^2)
335+
eg2_c2 = (x2_gpu.T * (W_sq @ x2_gpu.T)).sum(axis=0)
336+
eg2_c2 -= (mu2_gpu.T * (W_sq @ mu2_gpu.T)).sum(axis=0)
337+
338+
EG2 = eg2_c1 + eg2_c2 + EG ** 2
339+
return EG, EG2
340+
341+
342+
def _compute_local_cov_max_gpu(D_gpu, vals_gpu):
343+
"""GPU batch of compute_local_cov_max: G_max[g] = sum_i D[i]*vals[g,i]^2 / 2."""
344+
return (D_gpu * vals_gpu ** 2).sum(axis=1) / 2
345+
346+
347+
def _compute_hs_gpu(counts, neighbors, weights, num_umi, model, genes, centered):
    """Batched GPU implementation of the Hotspot autocorrelation statistic.

    Mirrors _compute_hs_inner but evaluates every gene simultaneously via
    sparse matrix products on the device. Returns a DataFrame indexed by
    gene with columns G, EG, stdG, Z, C.
    """
    import cupy as cp
    from .gpu import _require_gpu, _build_sparse_weight_matrix

    _require_gpu()

    neighbors_arr = neighbors.values
    weights_arr = weights.values
    umi_arr = num_umi.values

    n_genes, n_cells = counts.shape

    node_degrees = compute_node_degree(neighbors_arr, weights_arr)
    Wtot2 = (weights_arr ** 2).sum()

    dense_counts = counts.toarray() if issparse(counts) else np.asarray(counts)

    # Fit the per-gene models on the CPU, collecting (centered) values and,
    # when uncentered, the per-cell moments needed for the null distribution.
    gene_vals = np.zeros((n_genes, n_cells), dtype="double")
    if not centered:
        gene_mu = np.zeros((n_genes, n_cells), dtype="double")
        gene_x2 = np.zeros((n_genes, n_cells), dtype="double")

    for g in range(n_genes):
        observed = dense_counts[g].astype("double")
        vals, mu, var, x2 = _fit_gene(observed, model, umi_arr)
        if centered:
            vals = center_values(vals, mu, var)
        else:
            gene_mu[g] = mu
            gene_x2[g] = x2
        gene_vals[g] = vals

    vals_gpu = cp.asarray(gene_vals)
    W = _build_sparse_weight_matrix(
        neighbors_arr, weights_arr, shape=(n_cells, n_cells)
    )

    G_stats = _local_cov_weights_gpu(vals_gpu, W)

    if centered:
        # Under the centered model the null mean is 0 and E[G^2] is Wtot2.
        EG = cp.zeros(n_genes, dtype="double")
        EG2 = cp.full(n_genes, Wtot2, dtype="double")
    else:
        W_sq = _build_sparse_weight_matrix(
            neighbors_arr, weights_arr, shape=(n_cells, n_cells), square=True
        )
        EG, EG2 = _compute_moments_weights_gpu(
            cp.asarray(gene_mu), cp.asarray(gene_x2), W, W_sq
        )

    stdG = (EG2 - EG * EG) ** 0.5
    Z = (G_stats - EG) / stdG

    G_max = _compute_local_cov_max_gpu(cp.asarray(node_degrees), vals_gpu)
    C = (G_stats - EG) / G_max

    columns = {
        "G": cp.asnumpy(G_stats),
        "EG": cp.asnumpy(EG),
        "stdG": cp.asnumpy(stdG),
        "Z": cp.asnumpy(Z),
        "C": cp.asnumpy(C),
    }
    return pd.DataFrame(columns, index=genes)

0 commit comments

Comments
 (0)