Skip to content

Commit 8abaab0

Browse files
committed
kw-only for aggr.cu
1 parent 3fdde98 commit 8abaab0

2 files changed

Lines changed: 52 additions & 51 deletions

File tree

src/rapids_singlecell/_cuda/aggr/aggr.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,21 @@ static inline void launch_sparse_var(std::uintptr_t indptr, std::uintptr_t index
151151
}
152152

153153
NB_MODULE(_aggr_cuda, m) {
154-
m.def("sparse_aggr", &sparse_aggr_dispatch, "indptr"_a, "index"_a, "data"_a, "out"_a, "cats"_a,
155-
"mask"_a, "n_cells"_a, "n_genes"_a, "n_groups"_a, "is_csc"_a, "dtype_itemsize"_a,
154+
m.def("sparse_aggr", &sparse_aggr_dispatch, "indptr"_a, "index"_a, "data"_a, nb::kw_only(),
155+
"out"_a, "cats"_a, "mask"_a, "n_cells"_a, "n_genes"_a, "n_groups"_a, "is_csc"_a,
156+
"dtype_itemsize"_a, "stream"_a = 0);
157+
m.def("dense_aggr", &dense_aggr_dispatch, "data"_a, nb::kw_only(), "out"_a, "cats"_a, "mask"_a,
158+
"n_cells"_a, "n_genes"_a, "n_groups"_a, "is_fortran"_a, "dtype_itemsize"_a, "stream"_a = 0);
159+
m.def("csr_to_coo", &csr_to_coo_dispatch, "indptr"_a, "index"_a, "data"_a, nb::kw_only(),
160+
"out_row"_a, "out_col"_a, "out_data"_a, "cats"_a, "mask"_a, "n_cells"_a, "dtype_itemsize"_a,
156161
"stream"_a = 0);
157-
m.def("dense_aggr", &dense_aggr_dispatch, "data"_a, "out"_a, "cats"_a, "mask"_a, "n_cells"_a,
158-
"n_genes"_a, "n_groups"_a, "is_fortran"_a, "dtype_itemsize"_a, "stream"_a = 0);
159-
m.def("csr_to_coo", &csr_to_coo_dispatch, "indptr"_a, "index"_a, "data"_a, "row"_a, "col"_a,
160-
"ndata"_a, "cats"_a, "mask"_a, "n_cells"_a, "dtype_itemsize"_a, "stream"_a = 0);
161162
m.def(
162163
"sparse_var",
163164
[](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data, std::uintptr_t mean_data,
164165
std::uintptr_t n_cells, int dof, int n_groups, std::uintptr_t stream) {
165166
launch_sparse_var(indptr, index, data, mean_data, n_cells, dof, n_groups,
166167
(cudaStream_t)stream);
167168
},
168-
"indptr"_a, "index"_a, "data"_a, "mean_data"_a, "n_cells"_a, "dof"_a, "n_groups"_a,
169+
"indptr"_a, "index"_a, "data"_a, nb::kw_only(), "means"_a, "n_cells"_a, "dof"_a, "n_groups"_a,
169170
"stream"_a = 0);
170171
}

src/rapids_singlecell/get/_aggregated.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -91,26 +91,26 @@ def __aggregate_dask(X_part, mask_part, groupby_part):
9191
X_part.indptr.data.ptr,
9292
X_part.indices.data.ptr,
9393
X_part.data.data.ptr,
94-
out.data.ptr,
95-
gb.data.ptr,
96-
mk.data.ptr,
97-
X_part.shape[0],
98-
X_part.shape[1],
99-
int(n_groups),
100-
bool(0),
101-
int(X_part.data.dtype.itemsize),
94+
out=out.data.ptr,
95+
cats=gb.data.ptr,
96+
mask=mk.data.ptr,
97+
n_cells=X_part.shape[0],
98+
n_genes=X_part.shape[1],
99+
n_groups=n_groups,
100+
is_csc=False,
101+
dtype_itemsize=X_part.data.dtype.itemsize,
102102
)
103103
else:
104104
_aggr_cuda.dense_aggr(
105-
int(X_part.data.ptr),
106-
int(out.data.ptr),
107-
int(gb.data.ptr),
108-
int(mk.data.ptr),
109-
int(X_part.shape[0]),
110-
int(X_part.shape[1]),
111-
int(n_groups),
112-
bool(0 if X_part.flags.c_contiguous else 1),
113-
int(X_part.dtype.itemsize),
105+
X_part.data.ptr,
106+
out=out.data.ptr,
107+
cats=gb.data.ptr,
108+
mask=mk.data.ptr,
109+
n_cells=X_part.shape[0],
110+
n_genes=X_part.shape[1],
111+
n_groups=n_groups,
112+
is_fortran=X_part.flags.f_contiguous,
113+
dtype_itemsize=X_part.dtype.itemsize,
114114
)
115115
return out
116116

@@ -170,14 +170,14 @@ def count_mean_var_sparse(self, dof: int = 1):
170170
self.data.indptr.data.ptr,
171171
self.data.indices.data.ptr,
172172
self.data.data.data.ptr,
173-
out.data.ptr,
174-
self.groupby.data.ptr,
175-
mask.data.ptr,
176-
int(self.data.shape[0]),
177-
int(self.data.shape[1]),
178-
int(self.n_cells.shape[0]),
179-
self.data.format == "csc",
180-
int(self.data.data.dtype.itemsize),
173+
out=out.data.ptr,
174+
cats=self.groupby.data.ptr,
175+
mask=mask.data.ptr,
176+
n_cells=int(self.data.shape[0]),
177+
n_genes=int(self.data.shape[1]),
178+
n_groups=int(self.n_cells.shape[0]),
179+
is_csc=self.data.format == "csc",
180+
dtype_itemsize=int(self.data.data.dtype.itemsize),
181181
)
182182
sums, counts, sq_sums = out[0, :], out[1, :], out[2, :]
183183
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])
@@ -211,13 +211,13 @@ def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
211211
self.data.indptr.data.ptr,
212212
self.data.indices.data.ptr,
213213
self.data.data.data.ptr,
214-
src_row.data.ptr,
215-
src_col.data.ptr,
216-
src_data.data.ptr,
217-
self.groupby.data.ptr,
218-
mask.data.ptr,
219-
int(self.data.shape[0]),
220-
int(self.data.data.dtype.itemsize),
214+
out_row=src_row.data.ptr,
215+
out_col=src_col.data.ptr,
216+
out_data=src_data.data.ptr,
217+
cats=self.groupby.data.ptr,
218+
mask=mask.data.ptr,
219+
n_cells=self.data.shape[0],
220+
dtype_itemsize=self.data.data.dtype.itemsize,
221221
)
222222

223223
keys = cp.stack([src_col, src_row])
@@ -309,10 +309,10 @@ def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
309309
var.indptr.data.ptr,
310310
var.indices.data.ptr,
311311
var.data.data.ptr,
312-
means.data.ptr,
313-
self.n_cells.data.ptr,
314-
int(dof),
315-
int(var.shape[0]),
312+
means=means.data.ptr,
313+
n_cells=self.n_cells.data.ptr,
314+
dof=int(dof),
315+
n_groups=var.shape[0],
316316
)
317317
results["var"] = var
318318
if "count_nonzero" in funcs:
@@ -347,14 +347,14 @@ def count_mean_var_dense(self, dof: int = 1):
347347

348348
_aggr_cuda.dense_aggr(
349349
self.data.data.ptr,
350-
out.data.ptr,
351-
self.groupby.data.ptr,
352-
mask.data.ptr,
353-
self.data.shape[0],
354-
int(self.data.shape[1]),
355-
int(self.n_cells.shape[0]),
356-
bool(0 if self.data.flags.c_contiguous else 1),
357-
int(self.data.dtype.itemsize),
350+
out=out.data.ptr,
351+
cats=self.groupby.data.ptr,
352+
mask=mask.data.ptr,
353+
n_cells=self.data.shape[0],
354+
n_genes=self.data.shape[1],
355+
n_groups=self.n_cells.shape[0],
356+
is_fortran=self.data.flags.f_contiguous,
357+
dtype_itemsize=self.data.dtype.itemsize,
358358
)
359359
sums, counts, sq_sums = out[0], out[1], out[2]
360360
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])

0 commit comments

Comments
 (0)