Skip to content

Commit cfdec19

Browse files
committed
add streams
1 parent d386000 commit cfdec19

42 files changed

Lines changed: 1358 additions & 824 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/rapids_singlecell/_cuda/aggr/aggr.cu

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,39 @@ namespace nb = nanobind;
1010
template <typename T>
1111
static inline void launch_csr_aggr(std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
1212
std::uintptr_t out, std::uintptr_t cats, std::uintptr_t mask,
13-
std::size_t n_cells, std::size_t n_genes, std::size_t n_groups) {
13+
std::size_t n_cells, std::size_t n_genes, std::size_t n_groups,
14+
cudaStream_t stream) {
1415
dim3 grid((unsigned)n_cells);
1516
dim3 block(64);
16-
csr_aggr_kernel<T>
17-
<<<grid, block>>>(reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
18-
reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
19-
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask),
20-
n_cells, n_genes, n_groups);
17+
csr_aggr_kernel<T><<<grid, block, 0, stream>>>(
18+
reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
19+
reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
20+
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask), n_cells, n_genes,
21+
n_groups);
2122
}
2223

2324
template <typename T>
2425
static inline void launch_csc_aggr(std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
2526
std::uintptr_t out, std::uintptr_t cats, std::uintptr_t mask,
26-
std::size_t n_cells, std::size_t n_genes, std::size_t n_groups) {
27+
std::size_t n_cells, std::size_t n_genes, std::size_t n_groups,
28+
cudaStream_t stream) {
2729
dim3 grid((unsigned)n_genes);
2830
dim3 block(64);
29-
csc_aggr_kernel<T>
30-
<<<grid, block>>>(reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
31-
reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
32-
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask),
33-
n_cells, n_genes, n_groups);
31+
csc_aggr_kernel<T><<<grid, block, 0, stream>>>(
32+
reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
33+
reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
34+
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask), n_cells, n_genes,
35+
n_groups);
3436
}
3537

3638
template <typename T>
3739
static inline void launch_csr_to_coo(std::uintptr_t indptr, std::uintptr_t index,
3840
std::uintptr_t data, std::uintptr_t row, std::uintptr_t col,
3941
std::uintptr_t ndata, std::uintptr_t cats, std::uintptr_t mask,
40-
int n_cells) {
42+
int n_cells, cudaStream_t stream) {
4143
dim3 grid((unsigned)n_cells);
4244
dim3 block(64);
43-
csr_to_coo_kernel<T><<<grid, block>>>(
45+
csr_to_coo_kernel<T><<<grid, block, 0, stream>>>(
4446
reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
4547
reinterpret_cast<const T*>(data), reinterpret_cast<int*>(row), reinterpret_cast<int*>(col),
4648
reinterpret_cast<double*>(ndata), reinterpret_cast<const int*>(cats),
@@ -50,93 +52,120 @@ static inline void launch_csr_to_coo(std::uintptr_t indptr, std::uintptr_t index
5052
template <typename T>
5153
static inline void launch_dense_C(std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
5254
std::uintptr_t mask, std::size_t n_cells, std::size_t n_genes,
53-
std::size_t n_groups) {
55+
std::size_t n_groups, cudaStream_t stream) {
5456
dim3 block(256);
5557
dim3 grid((unsigned)((n_cells * n_genes + block.x - 1) / block.x));
5658
dense_aggr_kernel_C<T>
57-
<<<grid, block>>>(reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
58-
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask),
59-
n_cells, n_genes, n_groups);
59+
<<<grid, block, 0, stream>>>(reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
60+
reinterpret_cast<const int*>(cats),
61+
reinterpret_cast<const bool*>(mask), n_cells, n_genes, n_groups);
6062
}
6163

6264
template <typename T>
6365
static inline void launch_dense_F(std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
6466
std::uintptr_t mask, std::size_t n_cells, std::size_t n_genes,
65-
std::size_t n_groups) {
67+
std::size_t n_groups, cudaStream_t stream) {
6668
dim3 block(256);
6769
dim3 grid((unsigned)((n_cells * n_genes + block.x - 1) / block.x));
6870
dense_aggr_kernel_F<T>
69-
<<<grid, block>>>(reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
70-
reinterpret_cast<const int*>(cats), reinterpret_cast<const bool*>(mask),
71-
n_cells, n_genes, n_groups);
71+
<<<grid, block, 0, stream>>>(reinterpret_cast<const T*>(data), reinterpret_cast<double*>(out),
72+
reinterpret_cast<const int*>(cats),
73+
reinterpret_cast<const bool*>(mask), n_cells, n_genes, n_groups);
7274
}
7375

7476
// Unified dispatchers
7577
static inline void sparse_aggr_dispatch(std::uintptr_t indptr, std::uintptr_t index,
7678
std::uintptr_t data, std::uintptr_t out,
7779
std::uintptr_t cats, std::uintptr_t mask,
7880
std::size_t n_cells, std::size_t n_genes,
79-
std::size_t n_groups, bool is_csc, int dtype_itemsize) {
81+
std::size_t n_groups, bool is_csc, int dtype_itemsize,
82+
std::uintptr_t stream) {
8083
if (is_csc) {
8184
if (dtype_itemsize == 4) {
82-
launch_csc_aggr<float>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
85+
launch_csc_aggr<float>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
86+
(cudaStream_t)stream);
8387
} else {
84-
launch_csc_aggr<double>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
88+
launch_csc_aggr<double>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
89+
(cudaStream_t)stream);
8590
}
8691
} else {
8792
if (dtype_itemsize == 4) {
88-
launch_csr_aggr<float>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
93+
launch_csr_aggr<float>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
94+
(cudaStream_t)stream);
8995
} else {
90-
launch_csr_aggr<double>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
96+
launch_csr_aggr<double>(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
97+
(cudaStream_t)stream);
9198
}
9299
}
93100
}
94101

95102
static inline void dense_aggr_dispatch(std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
96103
std::uintptr_t mask, std::size_t n_cells,
97104
std::size_t n_genes, std::size_t n_groups, bool is_fortran,
98-
int dtype_itemsize) {
105+
int dtype_itemsize, std::uintptr_t stream) {
99106
if (is_fortran) {
100107
if (dtype_itemsize == 4) {
101-
launch_dense_F<float>(data, out, cats, mask, n_cells, n_genes, n_groups);
108+
launch_dense_F<float>(data, out, cats, mask, n_cells, n_genes, n_groups,
109+
(cudaStream_t)stream);
102110
} else {
103-
launch_dense_F<double>(data, out, cats, mask, n_cells, n_genes, n_groups);
111+
launch_dense_F<double>(data, out, cats, mask, n_cells, n_genes, n_groups,
112+
(cudaStream_t)stream);
104113
}
105114
} else {
106115
if (dtype_itemsize == 4) {
107-
launch_dense_C<float>(data, out, cats, mask, n_cells, n_genes, n_groups);
116+
launch_dense_C<float>(data, out, cats, mask, n_cells, n_genes, n_groups,
117+
(cudaStream_t)stream);
108118
} else {
109-
launch_dense_C<double>(data, out, cats, mask, n_cells, n_genes, n_groups);
119+
launch_dense_C<double>(data, out, cats, mask, n_cells, n_genes, n_groups,
120+
(cudaStream_t)stream);
110121
}
111122
}
112123
}
113124

114125
static inline void csr_to_coo_dispatch(std::uintptr_t indptr, std::uintptr_t index,
115126
std::uintptr_t data, std::uintptr_t row, std::uintptr_t col,
116127
std::uintptr_t ndata, std::uintptr_t cats,
117-
std::uintptr_t mask, int n_cells, int dtype_itemsize) {
128+
std::uintptr_t mask, int n_cells, int dtype_itemsize,
129+
std::uintptr_t stream) {
118130
if (dtype_itemsize == 4) {
119-
launch_csr_to_coo<float>(indptr, index, data, row, col, ndata, cats, mask, n_cells);
131+
launch_csr_to_coo<float>(indptr, index, data, row, col, ndata, cats, mask, n_cells,
132+
(cudaStream_t)stream);
120133
} else {
121-
launch_csr_to_coo<double>(indptr, index, data, row, col, ndata, cats, mask, n_cells);
134+
launch_csr_to_coo<double>(indptr, index, data, row, col, ndata, cats, mask, n_cells,
135+
(cudaStream_t)stream);
122136
}
123137
}
124138

125139
// variance launcher
126140
static inline void launch_sparse_var(std::uintptr_t indptr, std::uintptr_t index,
127141
std::uintptr_t data, std::uintptr_t mean_data,
128-
std::uintptr_t n_cells, int dof, int n_groups) {
142+
std::uintptr_t n_cells, int dof, int n_groups,
143+
cudaStream_t stream) {
129144
dim3 grid((unsigned)n_groups);
130145
dim3 block(64);
131-
sparse_var_kernel<<<grid, block>>>(
146+
sparse_var_kernel<<<grid, block, 0, stream>>>(
132147
reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
133148
reinterpret_cast<double*>(data), reinterpret_cast<const double*>(mean_data),
134149
reinterpret_cast<double*>(n_cells), dof, n_groups);
135150
}
136151

137152
NB_MODULE(_aggr_cuda, m) {
138-
m.def("sparse_aggr", &sparse_aggr_dispatch);
139-
m.def("dense_aggr", &dense_aggr_dispatch);
140-
m.def("csr_to_coo", &csr_to_coo_dispatch);
141-
m.def("sparse_var", &launch_sparse_var);
153+
m.def("sparse_aggr", &sparse_aggr_dispatch, nb::arg("indptr"), nb::arg("index"), nb::arg("data"),
154+
nb::arg("out"), nb::arg("cats"), nb::arg("mask"), nb::arg("n_cells"), nb::arg("n_genes"),
155+
nb::arg("n_groups"), nb::arg("is_csc"), nb::arg("dtype_itemsize"), nb::arg("stream") = 0);
156+
m.def("dense_aggr", &dense_aggr_dispatch, nb::arg("data"), nb::arg("out"), nb::arg("cats"),
157+
nb::arg("mask"), nb::arg("n_cells"), nb::arg("n_genes"), nb::arg("n_groups"),
158+
nb::arg("is_fortran"), nb::arg("dtype_itemsize"), nb::arg("stream") = 0);
159+
m.def("csr_to_coo", &csr_to_coo_dispatch, nb::arg("indptr"), nb::arg("index"), nb::arg("data"),
160+
nb::arg("row"), nb::arg("col"), nb::arg("ndata"), nb::arg("cats"), nb::arg("mask"),
161+
nb::arg("n_cells"), nb::arg("dtype_itemsize"), nb::arg("stream") = 0);
162+
m.def(
163+
"sparse_var",
164+
[](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data, std::uintptr_t mean_data,
165+
std::uintptr_t n_cells, int dof, int n_groups, std::uintptr_t stream) {
166+
launch_sparse_var(indptr, index, data, mean_data, n_cells, dof, n_groups,
167+
(cudaStream_t)stream);
168+
},
169+
nb::arg("indptr"), nb::arg("index"), nb::arg("data"), nb::arg("mean_data"),
170+
nb::arg("n_cells"), nb::arg("dof"), nb::arg("n_groups"), nb::arg("stream") = 0);
142171
}

src/rapids_singlecell/_cuda/aucell/aucell.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,25 @@ __global__ void auc_kernel(const int* __restrict__ ranks, int R, int C,
3232

3333
static inline void launch_auc(std::uintptr_t ranks, int R, int C, std::uintptr_t cnct,
3434
std::uintptr_t starts, std::uintptr_t lens, int n_sets, int n_up,
35-
std::uintptr_t max_aucs, std::uintptr_t es) {
35+
std::uintptr_t max_aucs, std::uintptr_t es, cudaStream_t stream) {
3636
dim3 block(32);
3737
dim3 grid((unsigned)n_sets, (unsigned)((R + block.x - 1) / block.x));
38-
auc_kernel<<<grid, block>>>(
38+
auc_kernel<<<grid, block, 0, stream>>>(
3939
reinterpret_cast<const int*>(ranks), R, C, reinterpret_cast<const int*>(cnct),
4040
reinterpret_cast<const int*>(starts), reinterpret_cast<const int*>(lens), n_sets, n_up,
4141
reinterpret_cast<const float*>(max_aucs), reinterpret_cast<float*>(es));
4242
}
4343

4444
NB_MODULE(_aucell_cuda, m) {
45-
m.def("auc", &launch_auc);
45+
m.def(
46+
"auc",
47+
[](std::uintptr_t ranks, int R, int C, std::uintptr_t cnct, std::uintptr_t starts,
48+
std::uintptr_t lens, int n_sets, int n_up, std::uintptr_t max_aucs, std::uintptr_t es,
49+
std::uintptr_t stream) {
50+
launch_auc(ranks, R, C, cnct, starts, lens, n_sets, n_up, max_aucs, es,
51+
(cudaStream_t)stream);
52+
},
53+
nb::arg("ranks"), nb::arg("R"), nb::arg("C"), nb::arg("cnct"), nb::arg("starts"),
54+
nb::arg("lens"), nb::arg("n_sets"), nb::arg("n_up"), nb::arg("max_aucs"), nb::arg("es"),
55+
nb::arg("stream") = 0);
4656
}

0 commit comments

Comments
 (0)