Skip to content

Commit bdffee9

Browse files
committed
Fix tests
1 parent a84bb25 commit bdffee9

2 files changed

Lines changed: 42 additions & 40 deletions

File tree

src/rapids_singlecell/squidpy_gpu/_niche.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _neighborhood_profile(
199199
weights: Sequence[float] | None,
200200
abs_nhood: bool,
201201
key: str,
202-
) -> np.ndarray:
202+
) -> cp.ndarray:
203203
"""Cells x categories matrix of cell-type counts (or relative frequencies) over n-hop neighbors."""
204204
cats = pd.Categorical(adata.obs[groups])
205205
n_cats = len(cats.categories)
@@ -218,6 +218,8 @@ def _neighborhood_profile(
218218
weights = [1.0] * distance
219219
elif len(weights) < distance:
220220
weights = list(weights) + [weights[-1]] * (distance - len(weights))
221+
if not abs_nhood and sum(weights) == 0:
222+
raise ValueError("`n_hop_weights` must not sum to zero.")
221223

222224
profile = cp.zeros((n_obs, n_cats), dtype=cp.float32)
223225
adj_k = adj_bin
@@ -325,7 +327,7 @@ def _cellcharter_features(
325327
fused ``mul_csr`` kernel used for utag, then aggregated as either:
326328
327329
- ``"mean"``: ``Âₖ @ X``
328-
- ``"variance"``: ``Âₖ @ (X·X) (Âₖ @ X)²`` (matches squidpy's path; densifies X)
330+
- ``"variance"``: ``Âₖ @ (X·X) - (Âₖ @ X)²`` (matches squidpy's path; densifies X)
329331
330332
All layers are concatenated horizontally.
331333
"""

tests/test_gmm.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ def _cuda_e_step(
121121
means,
122122
prec_chol,
123123
log_det_half,
124-
log_prob,
125-
responsibilities,
126-
ll_per_cell,
127-
centered,
128-
e_step_y,
124+
log_prob=log_prob,
125+
responsibilities=responsibilities,
126+
ll_per_cell=ll_per_cell,
127+
centered=centered,
128+
e_step_y=e_step_y,
129129
e_step_route=e_step_route,
130130
stream=cp.cuda.get_current_stream().ptr,
131131
handle=cp.cuda.device.get_cublas_handle(),
@@ -147,11 +147,11 @@ def _cuda_m_step(
147147
weights,
148148
means,
149149
covariances,
150-
reg_covar,
151-
cp.ones(X.shape[0], dtype=X.dtype),
152-
cp.empty(K, dtype=X.dtype),
153-
cp.empty((K, X.shape[1]), dtype=X.dtype),
154-
cp.empty_like(X),
150+
reg_covar=reg_covar,
151+
ones=cp.ones(X.shape[0], dtype=X.dtype),
152+
effective_counts=cp.empty(K, dtype=X.dtype),
153+
weighted_sums=cp.empty((K, X.shape[1]), dtype=X.dtype),
154+
centered=cp.empty_like(X),
155155
stream=cp.cuda.get_current_stream().ptr,
156156
handle=cp.cuda.device.get_cublas_handle(),
157157
)
@@ -332,9 +332,9 @@ def test_cuda_matches_reference_steps():
332332
means,
333333
prec_chol,
334334
log_det_half,
335-
log_prob,
336-
resp,
337-
ll_per_cell,
335+
log_prob=log_prob,
336+
responsibilities=resp,
337+
ll_per_cell=ll_per_cell,
338338
stream=cp.cuda.get_current_stream().ptr,
339339
)
340340

@@ -380,11 +380,11 @@ def test_cuda_512_e_step_matches_reference_for_cublas_route():
380380
means,
381381
prec_chol,
382382
log_det_half,
383-
centered,
384-
e_step_y,
385-
log_prob,
386-
resp,
387-
ll_per_cell,
383+
centered=centered,
384+
e_step_y=e_step_y,
385+
log_prob=log_prob,
386+
responsibilities=resp,
387+
ll_per_cell=ll_per_cell,
388388
stream=cp.cuda.get_current_stream().ptr,
389389
handle=cp.cuda.device.get_cublas_handle(),
390390
)
@@ -415,11 +415,11 @@ def test_cuda_768_e_step_uses_cublas_route():
415415
means,
416416
prec_chol,
417417
log_det_half,
418-
log_prob,
419-
resp,
420-
ll_per_cell,
421-
centered,
422-
e_step_y,
418+
log_prob=log_prob,
419+
responsibilities=resp,
420+
ll_per_cell=ll_per_cell,
421+
centered=centered,
422+
e_step_y=e_step_y,
423423
e_step_route="cublas",
424424
stream=stream,
425425
handle=handle,
@@ -433,11 +433,11 @@ def test_cuda_768_e_step_uses_cublas_route():
433433
means,
434434
prec_chol,
435435
log_det_half,
436-
centered_b,
437-
e_step_y_b,
438-
log_prob_b,
439-
resp_b,
440-
ll_per_cell_b,
436+
centered=centered_b,
437+
e_step_y=e_step_y_b,
438+
log_prob=log_prob_b,
439+
responsibilities=resp_b,
440+
ll_per_cell=ll_per_cell_b,
441441
stream=stream,
442442
handle=handle,
443443
)
@@ -468,11 +468,11 @@ def test_cuda_float64_wide_e_step_uses_cublas_route():
468468
means,
469469
prec_chol,
470470
log_det_half,
471-
log_prob,
472-
resp,
473-
ll_per_cell,
474-
centered,
475-
e_step_y,
471+
log_prob=log_prob,
472+
responsibilities=resp,
473+
ll_per_cell=ll_per_cell,
474+
centered=centered,
475+
e_step_y=e_step_y,
476476
e_step_route=route,
477477
stream=cp.cuda.get_current_stream().ptr,
478478
handle=cp.cuda.device.get_cublas_handle(),
@@ -517,11 +517,11 @@ def test_cuda_fused_e_step_matches_reference_for_50_pc_regime():
517517
means,
518518
prec_chol,
519519
log_det_half,
520-
log_prob,
521-
resp,
522-
ll_per_cell,
523-
centered,
524-
e_step_y,
520+
log_prob=log_prob,
521+
responsibilities=resp,
522+
ll_per_cell=ll_per_cell,
523+
centered=centered,
524+
e_step_y=e_step_y,
525525
e_step_route="fused",
526526
stream=cp.cuda.get_current_stream().ptr,
527527
handle=cp.cuda.device.get_cublas_handle(),

0 commit comments

Comments
 (0)