Skip to content

Commit 5595196

Browse files
committed
Simplify GPU code after review
- Unify _build_sparse_weight_matrix and _build_sparse_weight_sq_matrix into one function with a `square` parameter
- Remove the `cp` module parameter from GPU helpers (cupy is imported locally where needed)
- Skip the all_mu/all_x2 allocation when centered=True
- Densify sparse counts once up front instead of once per gene
- Extract _postprocess_results so the CPU and GPU paths share it
- Fix a pre-existing bug: in the second-gene (y) dispatch of _compute_hs_pairs_inner, the 'none' model branch fitted vals_x into mu_x/var_x/x2_x instead of vals_y into mu_y/var_y/x2_y
- Replace the inline model dispatch in _compute_hs_pairs_inner with _fit_gene
1 parent 549576f commit 5595196

File tree

3 files changed

+35
-85
lines changed

3 files changed

+35
-85
lines changed

hotspot/gpu.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,31 +44,18 @@ def _require_gpu():
4444
) from e
4545

4646

47-
def _build_sparse_weight_matrix(neighbors, weights, shape):
47+
def _build_sparse_weight_matrix(neighbors, weights, shape, square=False):
4848
"""Build a CuPy sparse CSR matrix from neighbor/weight arrays.
4949
5050
W[i, neighbors[i,k]] = weights[i,k] for all i, k where weights[i,k] != 0.
51+
If square=True, uses weights^2 instead (for moment computations).
5152
"""
5253
N, K = neighbors.shape
5354
rows = np.repeat(np.arange(N, dtype=np.int32), K)
5455
cols = neighbors.ravel().astype(np.int32)
5556
vals = weights.ravel().astype(np.float64)
56-
57-
mask = vals != 0
58-
rows, cols, vals = rows[mask], cols[mask], vals[mask]
59-
60-
return cp_sparse.csr_matrix(
61-
(cp.asarray(vals), (cp.asarray(rows), cp.asarray(cols))),
62-
shape=shape,
63-
)
64-
65-
66-
def _build_sparse_weight_sq_matrix(neighbors, weights, shape):
67-
"""Build sparse matrix with squared weights: W_sq[i,j] = weights[i,k]^2."""
68-
N, K = neighbors.shape
69-
rows = np.repeat(np.arange(N, dtype=np.int32), K)
70-
cols = neighbors.ravel().astype(np.int32)
71-
vals = (weights.ravel().astype(np.float64)) ** 2
57+
if square:
58+
vals = vals ** 2
7259

7360
mask = vals != 0
7461
rows, cols, vals = rows[mask], cols[mask], vals[mask]

hotspot/local_stats.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@ def compute_hs(
187187
):
188188

189189
if use_gpu:
190-
return _compute_hs_gpu(
190+
results = _compute_hs_gpu(
191191
counts, neighbors, weights, num_umi, model, genes, centered
192192
)
193+
return _postprocess_results(results)
193194

194195
neighbors = neighbors.values
195196
weights = weights.values
@@ -232,14 +233,15 @@ def _map_fun(vals):
232233

233234
results = pd.DataFrame(results, index=genes, columns=["G", "EG", "stdG", "Z", "C"])
234235

236+
return _postprocess_results(results)
237+
238+
239+
def _postprocess_results(results):
235240
results["Pval"] = norm.sf(results["Z"].values)
236241
results["FDR"] = multipletests(results["Pval"], method="fdr_bh")[1]
237-
238242
results = results.sort_values("Z", ascending=False)
239243
results.index.name = "Gene"
240-
241-
results = results[["C", "Z", "Pval", "FDR"]] # Remove other columns
242-
244+
results = results[["C", "Z", "Pval", "FDR"]]
243245
return results
244246

245247

@@ -312,7 +314,7 @@ def _local_cov_weights_gpu(vals_gpu, W):
312314
return (vals_gpu * smoothed_T.T).sum(axis=1)
313315

314316

315-
def _compute_moments_weights_gpu(cp, mu_gpu, x2_gpu, W, W_sq):
317+
def _compute_moments_weights_gpu(mu_gpu, x2_gpu, W, W_sq):
316318
"""GPU batch of compute_moments_weights for all genes at once."""
317319
# EG[g] = mu[g] . (W @ mu[g])
318320
EG = (mu_gpu * (W @ mu_gpu.T).T).sum(axis=1)
@@ -348,7 +350,7 @@ def _compute_hs_gpu(counts, neighbors, weights, num_umi, model, genes, centered)
348350
All genes are processed in parallel via sparse matrix multiplication.
349351
"""
350352
import cupy as cp
351-
from .gpu import _require_gpu, _build_sparse_weight_matrix, _build_sparse_weight_sq_matrix
353+
from .gpu import _require_gpu, _build_sparse_weight_matrix
352354

353355
_require_gpu()
354356

@@ -362,22 +364,26 @@ def _compute_hs_gpu(counts, neighbors, weights, num_umi, model, genes, centered)
362364
D = compute_node_degree(neighbors_np, weights_np)
363365
Wtot2 = (weights_np ** 2).sum()
364366

367+
if issparse(counts):
368+
counts_dense = counts.toarray()
369+
else:
370+
counts_dense = np.asarray(counts)
371+
365372
all_vals = np.zeros((N_genes, N_cells), dtype="double")
366-
all_mu = np.zeros((N_genes, N_cells), dtype="double")
367-
all_x2 = np.zeros((N_genes, N_cells), dtype="double")
373+
if not centered:
374+
all_mu = np.zeros((N_genes, N_cells), dtype="double")
375+
all_x2 = np.zeros((N_genes, N_cells), dtype="double")
368376

369377
for i in range(N_genes):
370-
raw = counts[i]
371-
if issparse(raw):
372-
raw = raw.toarray().ravel()
373-
raw = np.asarray(raw).ravel().astype("double")
378+
raw = counts_dense[i].astype("double")
374379

375380
vals, mu, var, x2 = _fit_gene(raw, model, num_umi_np)
376381
if centered:
377382
vals = center_values(vals, mu, var)
383+
else:
384+
all_mu[i] = mu
385+
all_x2[i] = x2
378386
all_vals[i] = vals
379-
all_mu[i] = mu
380-
all_x2[i] = x2
381387

382388
vals_gpu = cp.asarray(all_vals)
383389
D_gpu = cp.asarray(D)
@@ -391,29 +397,21 @@ def _compute_hs_gpu(counts, neighbors, weights, num_umi, model, genes, centered)
391397
else:
392398
mu_gpu = cp.asarray(all_mu)
393399
x2_gpu = cp.asarray(all_x2)
394-
W_sq = _build_sparse_weight_sq_matrix(
395-
neighbors_np, weights_np, shape=(N_cells, N_cells)
400+
W_sq = _build_sparse_weight_matrix(
401+
neighbors_np, weights_np, shape=(N_cells, N_cells), square=True
396402
)
397-
EG, EG2 = _compute_moments_weights_gpu(cp, mu_gpu, x2_gpu, W, W_sq)
403+
EG, EG2 = _compute_moments_weights_gpu(mu_gpu, x2_gpu, W, W_sq)
398404

399405
stdG = (EG2 - EG * EG) ** 0.5
400406
Z = (G_stats - EG) / stdG
401407

402408
G_max = _compute_local_cov_max_gpu(D_gpu, vals_gpu)
403409
C = (G_stats - EG) / G_max
404410

405-
results = pd.DataFrame(
411+
return pd.DataFrame(
406412
{
407413
"G": cp.asnumpy(G_stats), "EG": cp.asnumpy(EG),
408414
"stdG": cp.asnumpy(stdG), "Z": cp.asnumpy(Z), "C": cp.asnumpy(C),
409415
},
410416
index=genes,
411417
)
412-
413-
results["Pval"] = norm.sf(results["Z"].values)
414-
results["FDR"] = multipletests(results["Pval"], method="fdr_bh")[1]
415-
results = results.sort_values("Z", ascending=False)
416-
results.index.name = "Gene"
417-
results = results[["C", "Z", "Pval", "FDR"]]
418-
419-
return results

hotspot/local_stats_pairs.py

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from . import bernoulli_model
1010
from . import normal_model
1111
from . import none_model
12-
from .local_stats import compute_local_cov_max
12+
from .local_stats import compute_local_cov_max, _fit_gene
1313
from .knn import compute_node_degree
1414
from .utils import center_values
1515

@@ -449,25 +449,7 @@ def _compute_hs_pairs_inner(row_i, counts, neighbors, weights, num_umi,
449449
lc_out = np.zeros(counts.shape[0])
450450
lc_z_out = np.zeros(counts.shape[0])
451451

452-
if model == 'bernoulli':
453-
vals_x = (vals_x > 0).astype('double')
454-
mu_x, var_x, x2_x = bernoulli_model.fit_gene_model(
455-
vals_x, num_umi)
456-
457-
elif model == 'danb':
458-
mu_x, var_x, x2_x = danb_model.fit_gene_model(
459-
vals_x, num_umi)
460-
461-
elif model == 'normal':
462-
mu_x, var_x, x2_x = normal_model.fit_gene_model(
463-
vals_x, num_umi)
464-
465-
elif model == 'none':
466-
mu_x, var_x, x2_x = none_model.fit_gene_model(
467-
vals_x, num_umi)
468-
469-
else:
470-
raise Exception("Invalid Model: {}".format(model))
452+
vals_x, mu_x, var_x, x2_x = _fit_gene(vals_x, model, num_umi)
471453

472454
if centered:
473455
vals_x = center_values(vals_x, mu_x, var_x)
@@ -479,25 +461,7 @@ def _compute_hs_pairs_inner(row_i, counts, neighbors, weights, num_umi,
479461

480462
vals_y = counts[row_j]
481463

482-
if model == 'bernoulli':
483-
vals_y = (vals_y > 0).astype('double')
484-
mu_y, var_y, x2_y = bernoulli_model.fit_gene_model(
485-
vals_y, num_umi)
486-
487-
elif model == 'danb':
488-
mu_y, var_y, x2_y = danb_model.fit_gene_model(
489-
vals_y, num_umi)
490-
491-
elif model == 'normal':
492-
mu_y, var_y, x2_y = normal_model.fit_gene_model(
493-
vals_y, num_umi)
494-
495-
elif model == 'none':
496-
mu_x, var_x, x2_x = none_model.fit_gene_model(
497-
vals_x, num_umi)
498-
499-
else:
500-
raise Exception("Invalid Model: {}".format(model))
464+
vals_y, mu_y, var_y, x2_y = _fit_gene(vals_y, model, num_umi)
501465

502466
if centered:
503467
vals_y = center_values(vals_y, mu_y, var_y)
@@ -889,12 +853,13 @@ def _conditional_eg2_gpu(X, W_sym):
889853
return (t1x_T ** 2).sum(axis=0)
890854

891855

892-
def _local_cov_pair_all_gpu(cp, X, W):
856+
def _local_cov_pair_all_gpu(X, W):
893857
"""GPU batch of local_cov_pair for ALL gene pairs via dense matmul.
894858
895859
Returns the full G x G matrix of lc values (= local_cov_pair * 2).
896860
Diagonal is zeroed (no self-pairs).
897861
"""
862+
import cupy as cp
898863
smoothed_T = W @ X.T # (N, G)
899864
M = X @ smoothed_T # (G, G): M[a,b] = x_a . (W @ x_b)
900865
lc_matrix = M + M.T # symmetrize: lc[a,b] = x_a.(Wx_b) + x_b.(Wx_a)
@@ -930,7 +895,7 @@ def _compute_hs_pairs_centered_cond_gpu(counts, neighbors, weights, num_umi, mod
930895

931896
eg2s = _conditional_eg2_gpu(X, W_sym)
932897

933-
lc_matrix = _local_cov_pair_all_gpu(cp, X, W)
898+
lc_matrix = _local_cov_pair_all_gpu(X, W)
934899

935900
std_genes = eg2s ** 0.5
936901
Z_xy = lc_matrix / std_genes[:, None]

0 commit comments

Comments
 (0)