11from __future__ import annotations
2+
23import inspect
34from collections .abc import Callable
45
56import cupy as cp
67import numpy as np
7- from numba import cuda
8+
89from rapids_singlecell .decoupler_gpu ._helper ._docs import docs
910from rapids_singlecell .decoupler_gpu ._helper ._log import _log
1011from rapids_singlecell .decoupler_gpu ._helper ._Method import Method , MethodMeta
1112
13+
1214def _ridx (
1315 times : int ,
1416 nvar : int ,
@@ -23,15 +25,16 @@ def _ridx(
2325 idx = cp .array (idx )
2426 return idx
2527
28+
2629_wsum_kernel = cp .RawKernel (
2730 r"""
2831extern "C" __global__ void matmul_kernel(const float* x, const float* w, float* C, int n_obs, int n_var, int n_src) {
2932 // x is n_obs x n_var, w is n_var x n_src, C is n_obs x n_src
30-
33+
3134 // Get the row and column index of the output matrix C for this thread
3235 const int row = blockIdx.y * blockDim.y + threadIdx.y;
3336 const int src = blockIdx.x * blockDim.x + threadIdx.x;
34-
37+
3538 // Bounds checking
3639 if (row < n_obs && src < n_src) {
3740 float sum = 0.0f; // Use float precision for accumulation
@@ -45,40 +48,48 @@ def _ridx(
4548 "matmul_kernel" ,
4649)
4750
51+
4852def _wsum_raw (x : cp .ndarray , w : cp .ndarray ) -> cp .ndarray :
4953 n_obs , n_var = x .shape
5054 n_var , n_src = w .shape
5155 es = cp .zeros ((n_obs , n_src ), dtype = cp .float32 )
52-
56+
5357 # Ensure input matrices are contiguous and of correct type
5458 if x .flags .c_contiguous and x .dtype == cp .float32 :
5559 x_contig = x
5660 else :
5761 x_contig = cp .ascontiguousarray (x , dtype = cp .float32 )
58-
62+
5963 if w .flags .c_contiguous and w .dtype == cp .float32 :
6064 w_contig = w
6165 else :
6266 w_contig = cp .ascontiguousarray (w , dtype = cp .float32 )
63-
67+
6468 # Use 2D thread blocks for better performance
6569 threads_per_block = (16 , 16 )
66-
70+
6771 # Calculate grid size to cover all output elements
6872 grid_x = (n_src + threads_per_block [0 ] - 1 ) // threads_per_block [0 ]
6973 grid_y = (n_obs + threads_per_block [1 ] - 1 ) // threads_per_block [1 ]
70-
71- _wsum_kernel ((grid_x , grid_y ), threads_per_block , (x_contig , w_contig , es , n_obs , n_var , n_src ))
74+
75+ _wsum_kernel (
76+ (grid_x , grid_y ),
77+ threads_per_block ,
78+ (x_contig , w_contig , es , n_obs , n_var , n_src ),
79+ )
7280 return es
7381
82+
7483def _wmean_raw (x : cp .ndarray , w : cp .ndarray ) -> cp .ndarray :
7584 agg = _wsum_raw (x , w )
7685 div = cp .sum (cp .abs (w ), axis = 0 )
7786 return agg / div
7887
88+
7989def _wsum (x : cp .ndarray , w : cp .ndarray ) -> cp .ndarray :
8090 return x .dot (w )
8191
92+
8293def _wmean (x : cp .ndarray , w : cp .ndarray ) -> cp .ndarray :
8394 agg = _wsum (x , w )
8495 div = cp .sum (cp .abs (w ), axis = 0 )
@@ -99,6 +110,7 @@ def _f(mat, adj):
99110 m = f"waggr - using { _f .__name__ } "
100111 _log (m , level = "info" , verbose = verbose )
101112
113+
102114_fun_dict = {
103115 "wsum" : _wsum ,
104116 "wmean" : _wmean ,
@@ -117,32 +129,40 @@ def _validate_args(
117129 required_args = ["x" , "w" ]
118130 for arg in required_args :
119131 if arg not in args :
120- assert AssertionError (), f"fun={ fun .__name__ } must contain arguments x and w"
132+ assert AssertionError (), (
133+ f"fun={ fun .__name__ } must contain arguments x and w"
134+ )
121135 # Check if any additional arguments have default values
122136 for param in args .values ():
123137 if param .name not in required_args and param .default == inspect .Parameter .empty :
124- assert AssertionError (), f"fun={ fun .__name__ } has an argument { param .name } without a default value"
138+ assert AssertionError (), (
139+ f"fun={ fun .__name__ } has an argument { param .name } without a default value"
140+ )
125141 return fun
126142
127143
128-
129144def _validate_func (
130145 fun : Callable ,
131146 verbose : bool ,
132147) -> None :
133148 fun = _validate_args (fun = fun , verbose = verbose )
134- x = cp .array ([[1.0 , 2.0 , 3.0 ],[4.0 , 5.0 , 6.0 ]])
149+ x = cp .array ([[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]])
135150 w = cp .array ([[- 1.0 , 3.0 ], [0.0 , 4.0 ], [2.0 , 5.0 ]])
136151 try :
137152 res = fun (x = x , w = w )
138153 assert isinstance (res , cp .ndarray ), "output of fun must be a cp.ndarray"
139- assert res .shape == (x .shape [0 ], w .shape [1 ]), "output of fun must be a cp.ndarray with shape (x.shape[0], w.shape[1])"
154+ assert res .shape == (x .shape [0 ], w .shape [1 ]), (
155+ "output of fun must be a cp.ndarray with shape (x.shape[0], w.shape[1])"
156+ )
140157 except Exception as err :
141- raise ValueError (f"fun failed to run with test data: fun(x={ x } ), w={ w } " ) from err
158+ raise ValueError (
159+ f"fun failed to run with test data: fun(x={ x } ), w={ w } "
160+ ) from err
142161 m = f"waggr - using function { fun .__name__ } "
143162 _log (m , level = "info" , verbose = verbose )
144163 _fun (f = fun , verbose = verbose )
145164
165+
146166def _perm (
147167 fun : Callable ,
148168 es : np .ndarray ,
@@ -165,30 +185,27 @@ def _perm(
165185 mat_perm = mat [:, idx [i ]]
166186 # Apply the function
167187 perm_result = fun (mat_perm , adj )
168- perm_result = perm_result .astype (cp .float64 ) # Use double precision for accumulation
188+ perm_result = perm_result .astype (
189+ cp .float64
190+ ) # Use double precision for accumulation
169191 # Update running statistics
170192 sum_null += perm_result
171193 sum_null_sq += perm_result * perm_result
172194 extreme_count += (cp .abs (perm_result ) > es_abs ).astype (cp .int32 )
173195 # Clean up intermediate results
174196 del mat_perm , perm_result
175-
176-
197+
177198 # Compute final statistics
178199 null_mean = sum_null / times
179200 # Var(X) = E[X²] - (E[X])²
180201 null_var = (sum_null_sq / times ) - (null_mean * null_mean )
181202 null_std = cp .sqrt (cp .maximum (null_var , 1e-10 ))
182-
203+
183204 # Compute NES
184205 nes = cp .where (
185- null_std > 1e-10 ,
186- (
187- es .astype (cp .float64 ) - null_mean ) / null_std ,
188- cp .where (cp .abs (es ) > 1e-10 ,
189- cp .sign (es .astype (cp .float64 )) * 1e6 ,
190- 0.0
191- )
206+ null_std > 1e-10 ,
207+ (es .astype (cp .float64 ) - null_mean ) / null_std ,
208+ cp .where (cp .abs (es ) > 1e-10 , cp .sign (es .astype (cp .float64 )) * 1e6 , 0.0 ),
192209 )
193210
194211 # Compute empirical p-value
@@ -198,9 +215,10 @@ def _perm(
198215 pvals = pvals / times
199216 pvals = cp .where (pvals >= 0.5 , 1 - pvals , pvals )
200217 pvals = pvals * 2 # Two-tailed test
201-
218+
202219 return nes .astype (cp .float32 ), pvals
203220
221+
204222@docs .dedent
205223def _func_waggr (
206224 mat : cp .ndarray ,
@@ -289,7 +307,9 @@ def _func_waggr(
289307 f_fun = fun
290308 _validate_func (f_fun , verbose = verbose )
291309 vfun = _cfuncs [f_fun .__name__ ]
292- assert isinstance (times , int | float ) and times >= 0 , "times must be numeric and >= 0"
310+ assert isinstance (times , int | float ) and times >= 0 , (
311+ "times must be numeric and >= 0"
312+ )
293313 assert isinstance (seed , int | float ) and seed >= 0 , "seed must be numeric and >= 0"
294314 times , seed = int (times ), int (seed )
295315 nobs , nvar = mat .shape
@@ -306,6 +326,7 @@ def _func_waggr(
306326 pv = cp .ones (es .shape )
307327 return es .get (), pv .get ()
308328
329+
309330_waggr = MethodMeta (
310331 name = "waggr" ,
311332 desc = "Weighted Aggregate (WAGGR)" ,
0 commit comments