Skip to content

Commit bea1d13

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5a3e07d commit bea1d13

2 files changed

Lines changed: 63 additions & 37 deletions

File tree

src/rapids_singlecell/decoupler_gpu/_method_waggr.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations
2+
23
import inspect
34
from collections.abc import Callable
45

56
import cupy as cp
67
import numpy as np
7-
from numba import cuda
8+
89
from rapids_singlecell.decoupler_gpu._helper._docs import docs
910
from rapids_singlecell.decoupler_gpu._helper._log import _log
1011
from rapids_singlecell.decoupler_gpu._helper._Method import Method, MethodMeta
1112

13+
1214
def _ridx(
1315
times: int,
1416
nvar: int,
@@ -23,15 +25,16 @@ def _ridx(
2325
idx = cp.array(idx)
2426
return idx
2527

28+
2629
_wsum_kernel = cp.RawKernel(
2730
r"""
2831
extern "C" __global__ void matmul_kernel(const float* x, const float* w, float* C, int n_obs, int n_var, int n_src) {
2932
// x is n_obs x n_var, w is n_var x n_src, C is n_obs x n_src
30-
33+
3134
// Get the row and column index of the output matrix C for this thread
3235
const int row = blockIdx.y * blockDim.y + threadIdx.y;
3336
const int src = blockIdx.x * blockDim.x + threadIdx.x;
34-
37+
3538
// Bounds checking
3639
if (row < n_obs && src < n_src) {
3740
float sum = 0.0f; // Use float precision for accumulation
@@ -45,40 +48,48 @@ def _ridx(
4548
"matmul_kernel",
4649
)
4750

51+
4852
def _wsum_raw(x: cp.ndarray, w: cp.ndarray) -> cp.ndarray:
    """Compute the weighted sum of ``x`` and ``w`` with ``_wsum_kernel``.

    Parameters
    ----------
    x
        2-D array whose shape is unpacked as ``(n_obs, n_var)``.
    w
        2-D array whose shape is unpacked as ``(n_var, n_src)``.

    Returns
    -------
    ``float32`` array of shape ``(n_obs, n_src)``.

    Raises
    ------
    ValueError
        If ``x.shape[1]`` and ``w.shape[0]`` differ.
    """
    n_obs, n_var = x.shape
    w_rows, n_src = w.shape
    # The kernel is handed a single ``n_var`` for both operands, so the shared
    # dimension has to agree before anything is launched.
    if w_rows != n_var:
        raise ValueError(
            f"shapes {x.shape} and {w.shape} not aligned: "
            f"{n_var} (dim 1) != {w_rows} (dim 0)"
        )

    es = cp.zeros((n_obs, n_src), dtype=cp.float32)
    # Nothing to compute for an empty output; this also avoids launching the
    # kernel with a zero-sized grid.
    if es.size == 0:
        return es

    # Ensure input matrices are contiguous and of correct type
    if x.flags.c_contiguous and x.dtype == cp.float32:
        x_contig = x
    else:
        x_contig = cp.ascontiguousarray(x, dtype=cp.float32)

    if w.flags.c_contiguous and w.dtype == cp.float32:
        w_contig = w
    else:
        w_contig = cp.ascontiguousarray(w, dtype=cp.float32)

    # Use 2D thread blocks for better performance
    threads_per_block = (16, 16)

    # Ceil-divide so the grid covers every output element; the kernel does its
    # own bounds check for the overhang.
    grid_x = (n_src + threads_per_block[0] - 1) // threads_per_block[0]
    grid_y = (n_obs + threads_per_block[1] - 1) // threads_per_block[1]

    _wsum_kernel(
        (grid_x, grid_y),
        threads_per_block,
        (x_contig, w_contig, es, n_obs, n_var, n_src),
    )
    return es
7381

82+
7483
def _wmean_raw(x: cp.ndarray, w: cp.ndarray) -> cp.ndarray:
    """Return the ``_wsum_raw`` result divided by the column-wise sum of ``|w|``."""
    weighted_total = _wsum_raw(x, w)
    # Normalise each output column by the total absolute weight of that column.
    return weighted_total / cp.abs(w).sum(axis=0)
7887

88+
7989
def _wsum(x: cp.ndarray, w: cp.ndarray) -> cp.ndarray:
    """Return the weighted sum of ``x`` with weights ``w``, computed as ``x.dot(w)``."""
    weighted_total = x.dot(w)
    return weighted_total
8191

92+
8293
def _wmean(x: cp.ndarray, w: cp.ndarray) -> cp.ndarray:
8394
agg = _wsum(x, w)
8495
div = cp.sum(cp.abs(w), axis=0)
@@ -99,6 +110,7 @@ def _f(mat, adj):
99110
m = f"waggr - using {_f.__name__}"
100111
_log(m, level="info", verbose=verbose)
101112

113+
102114
_fun_dict = {
103115
"wsum": _wsum,
104116
"wmean": _wmean,
@@ -117,32 +129,40 @@ def _validate_args(
117129
required_args = ["x", "w"]
118130
for arg in required_args:
119131
if arg not in args:
120-
assert AssertionError(), f"fun={fun.__name__} must contain arguments x and w"
132+
assert AssertionError(), (
133+
f"fun={fun.__name__} must contain arguments x and w"
134+
)
121135
# Check if any additional arguments have default values
122136
for param in args.values():
123137
if param.name not in required_args and param.default == inspect.Parameter.empty:
124-
assert AssertionError(), f"fun={fun.__name__} has an argument {param.name} without a default value"
138+
assert AssertionError(), (
139+
f"fun={fun.__name__} has an argument {param.name} without a default value"
140+
)
125141
return fun
126142

127143

128-
129144
def _validate_func(
    fun: Callable,
    verbose: bool,
) -> None:
    """Check that ``fun`` works on a small test input, then pass it to ``_fun``.

    ``fun`` is first passed through ``_validate_args``, then called as
    ``fun(x=x, w=w)`` with a fixed 2x3 ``x`` and 3x2 ``w``. The result must be
    a ``cp.ndarray`` of shape ``(x.shape[0], w.shape[1])``.

    Parameters
    ----------
    fun
        Callable accepting keyword arguments ``x`` and ``w``.
    verbose
        Forwarded to ``_validate_args``, ``_log`` and ``_fun``.

    Raises
    ------
    ValueError
        If calling ``fun`` on the test data raises, or if its output has the
        wrong type or shape. The underlying error is chained as the cause.
    """
    fun = _validate_args(fun=fun, verbose=verbose)
    x = cp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    w = cp.array([[-1.0, 3.0], [0.0, 4.0], [2.0, 5.0]])
    try:
        res = fun(x=x, w=w)
        # Explicit raises instead of ``assert`` so the checks still run under
        # ``python -O``; both are re-raised as ValueError below.
        if not isinstance(res, cp.ndarray):
            raise TypeError("output of fun must be a cp.ndarray")
        if res.shape != (x.shape[0], w.shape[1]):
            raise ValueError(
                "output of fun must be a cp.ndarray with shape (x.shape[0], w.shape[1])"
            )
    except Exception as err:
        raise ValueError(
            f"fun failed to run with test data: fun(x={x}, w={w})"
        ) from err
    m = f"waggr - using function {fun.__name__}"
    _log(m, level="info", verbose=verbose)
    _fun(f=fun, verbose=verbose)
145164

165+
146166
def _perm(
147167
fun: Callable,
148168
es: np.ndarray,
@@ -165,30 +185,27 @@ def _perm(
165185
mat_perm = mat[:, idx[i]]
166186
# Apply the function
167187
perm_result = fun(mat_perm, adj)
168-
perm_result = perm_result.astype(cp.float64) # Use double precision for accumulation
188+
perm_result = perm_result.astype(
189+
cp.float64
190+
) # Use double precision for accumulation
169191
# Update running statistics
170192
sum_null += perm_result
171193
sum_null_sq += perm_result * perm_result
172194
extreme_count += (cp.abs(perm_result) > es_abs).astype(cp.int32)
173195
# Clean up intermediate results
174196
del mat_perm, perm_result
175-
176-
197+
177198
# Compute final statistics
178199
null_mean = sum_null / times
179200
# Var(X) = E[X²] - (E[X])²
180201
null_var = (sum_null_sq / times) - (null_mean * null_mean)
181202
null_std = cp.sqrt(cp.maximum(null_var, 1e-10))
182-
203+
183204
# Compute NES
184205
nes = cp.where(
185-
null_std > 1e-10,
186-
(
187-
es.astype(cp.float64) - null_mean) / null_std,
188-
cp.where(cp.abs(es) > 1e-10,
189-
cp.sign(es.astype(cp.float64)) * 1e6,
190-
0.0
191-
)
206+
null_std > 1e-10,
207+
(es.astype(cp.float64) - null_mean) / null_std,
208+
cp.where(cp.abs(es) > 1e-10, cp.sign(es.astype(cp.float64)) * 1e6, 0.0),
192209
)
193210

194211
# Compute empirical p-value
@@ -198,9 +215,10 @@ def _perm(
198215
pvals = pvals / times
199216
pvals = cp.where(pvals >= 0.5, 1 - pvals, pvals)
200217
pvals = pvals * 2 # Two-tailed test
201-
218+
202219
return nes.astype(cp.float32), pvals
203220

221+
204222
@docs.dedent
205223
def _func_waggr(
206224
mat: cp.ndarray,
@@ -289,7 +307,9 @@ def _func_waggr(
289307
f_fun = fun
290308
_validate_func(f_fun, verbose=verbose)
291309
vfun = _cfuncs[f_fun.__name__]
292-
assert isinstance(times, int | float) and times >= 0, "times must be numeric and >= 0"
310+
assert isinstance(times, int | float) and times >= 0, (
311+
"times must be numeric and >= 0"
312+
)
293313
assert isinstance(seed, int | float) and seed >= 0, "seed must be numeric and >= 0"
294314
times, seed = int(times), int(seed)
295315
nobs, nvar = mat.shape
@@ -306,6 +326,7 @@ def _func_waggr(
306326
pv = cp.ones(es.shape)
307327
return es.get(), pv.get()
308328

329+
309330
_waggr = MethodMeta(
310331
name="waggr",
311332
desc="Weighted Aggregate (WAGGR)",

tests/decoupler/test_waggr.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import rapids_singlecell.decoupler_gpu as dc
88

9+
910
def test_funcs(rng):
1011
x = cp.array([[1, 2, 3, 4]], dtype=float)
1112
w = cp.array([rng.random(x.size)], dtype=float)
@@ -16,25 +17,28 @@ def test_funcs(rng):
1617

1718

1819
def test_wsum_wmean(mat, adjmat):
19-
2020
print("\n=== Testing with test data ===")
2121
X, obs, var = mat
2222
X = cp.array(X, dtype=cp.float32)
2323
adjmat = cp.array(adjmat, dtype=cp.float32)
24-
24+
2525
print(f"X shape: {X.shape}, adjmat shape: {adjmat.shape}")
2626
print(f"X dtype: {X.dtype}, adjmat dtype: {adjmat.dtype}")
2727
print(f"X min/max: {cp.min(X):.6f} / {cp.max(X):.6f}")
2828
print(f"adjmat min/max: {cp.min(adjmat):.6f} / {cp.max(adjmat):.6f}")
29-
29+
3030
# Test _wsum
3131
result_actual = dc._method_waggr._wsum(X, adjmat)
3232
expected_actual = X @ adjmat
33-
34-
print(f"Expected shape: {expected_actual.shape}, Result shape: {result_actual.shape}")
35-
print(f"Expected min/max: {cp.min(expected_actual):.6f} / {cp.max(expected_actual):.6f}")
33+
34+
print(
35+
f"Expected shape: {expected_actual.shape}, Result shape: {result_actual.shape}"
36+
)
37+
print(
38+
f"Expected min/max: {cp.min(expected_actual):.6f} / {cp.max(expected_actual):.6f}"
39+
)
3640
print(f"Result min/max: {cp.min(result_actual):.6f} / {cp.max(result_actual):.6f}")
37-
41+
3842
is_close_actual = cp.allclose(expected_actual, result_actual, rtol=1e-4)
3943
print(f"wsum results match: {is_close_actual}")
4044

@@ -49,7 +53,6 @@ def test_wsum_wmean(mat, adjmat):
4953
assert is_close_actual, "_wmean test failed"
5054

5155

52-
5356
@pytest.mark.parametrize(
5457
"fun,times,seed",
5558
[
@@ -68,6 +71,8 @@ def test_func_waggr(
6871
X = cp.array(X)
6972
adjmat = cp.array(adjmat)
7073
times = 0
71-
es, pv = dc._method_waggr._func_waggr(mat=X, adj=adjmat, fun=fun, times=times, seed=seed)
74+
es, pv = dc._method_waggr._func_waggr(
75+
mat=X, adj=adjmat, fun=fun, times=times, seed=seed
76+
)
7277
assert np.isfinite(es).all()
7378
assert ((0 <= pv) & (pv <= 1)).all()

0 commit comments

Comments
 (0)