@@ -10,37 +10,39 @@ namespace nb = nanobind;
1010template <typename T>
1111static inline void launch_csr_aggr (std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
1212 std::uintptr_t out, std::uintptr_t cats, std::uintptr_t mask,
13- std::size_t n_cells, std::size_t n_genes, std::size_t n_groups) {
13+ std::size_t n_cells, std::size_t n_genes, std::size_t n_groups,
14+ cudaStream_t stream) {
1415 dim3 grid ((unsigned )n_cells);
1516 dim3 block (64 );
16- csr_aggr_kernel<T>
17- <<<grid, block>>> ( reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
18- reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
19- reinterpret_cast <const int *>(cats), reinterpret_cast <const bool *>(mask),
20- n_cells, n_genes, n_groups);
17+ csr_aggr_kernel<T><<<grid, block, 0 , stream>>> (
18+ reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
19+ reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
20+ reinterpret_cast <const int *>(cats), reinterpret_cast <const bool *>(mask), n_cells, n_genes ,
21+ n_groups);
2122}
2223
2324template <typename T>
2425static inline void launch_csc_aggr (std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
2526 std::uintptr_t out, std::uintptr_t cats, std::uintptr_t mask,
26- std::size_t n_cells, std::size_t n_genes, std::size_t n_groups) {
27+ std::size_t n_cells, std::size_t n_genes, std::size_t n_groups,
28+ cudaStream_t stream) {
2729 dim3 grid ((unsigned )n_genes);
2830 dim3 block (64 );
29- csc_aggr_kernel<T>
30- <<<grid, block>>> ( reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
31- reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
32- reinterpret_cast <const int *>(cats), reinterpret_cast <const bool *>(mask),
33- n_cells, n_genes, n_groups);
31+ csc_aggr_kernel<T><<<grid, block, 0 , stream>>> (
32+ reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
33+ reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
34+ reinterpret_cast <const int *>(cats), reinterpret_cast <const bool *>(mask), n_cells, n_genes ,
35+ n_groups);
3436}
3537
3638template <typename T>
3739static inline void launch_csr_to_coo (std::uintptr_t indptr, std::uintptr_t index,
3840 std::uintptr_t data, std::uintptr_t row, std::uintptr_t col,
3941 std::uintptr_t ndata, std::uintptr_t cats, std::uintptr_t mask,
40- int n_cells) {
42+ int n_cells, cudaStream_t stream ) {
4143 dim3 grid ((unsigned )n_cells);
4244 dim3 block (64 );
43- csr_to_coo_kernel<T><<<grid, block>>> (
45+ csr_to_coo_kernel<T><<<grid, block, 0 , stream >>> (
4446 reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
4547 reinterpret_cast <const T*>(data), reinterpret_cast <int *>(row), reinterpret_cast <int *>(col),
4648 reinterpret_cast <double *>(ndata), reinterpret_cast <const int *>(cats),
@@ -50,93 +52,120 @@ static inline void launch_csr_to_coo(std::uintptr_t indptr, std::uintptr_t index
5052template <typename T>
5153static inline void launch_dense_C (std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
5254 std::uintptr_t mask, std::size_t n_cells, std::size_t n_genes,
53- std::size_t n_groups) {
55+ std::size_t n_groups, cudaStream_t stream ) {
5456 dim3 block (256 );
5557 dim3 grid ((unsigned )((n_cells * n_genes + block.x - 1 ) / block.x ));
5658 dense_aggr_kernel_C<T>
57- <<<grid, block>>> (reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
58- reinterpret_cast < const int *>(cats), reinterpret_cast <const bool *>(mask ),
59- n_cells, n_genes, n_groups);
59+ <<<grid, block, 0 , stream >>> (reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
60+ reinterpret_cast <const int *>(cats ),
61+ reinterpret_cast < const bool *>(mask), n_cells, n_genes, n_groups);
6062}
6163
6264template <typename T>
6365static inline void launch_dense_F (std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
6466 std::uintptr_t mask, std::size_t n_cells, std::size_t n_genes,
65- std::size_t n_groups) {
67+ std::size_t n_groups, cudaStream_t stream ) {
6668 dim3 block (256 );
6769 dim3 grid ((unsigned )((n_cells * n_genes + block.x - 1 ) / block.x ));
6870 dense_aggr_kernel_F<T>
69- <<<grid, block>>> (reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
70- reinterpret_cast < const int *>(cats), reinterpret_cast <const bool *>(mask ),
71- n_cells, n_genes, n_groups);
71+ <<<grid, block, 0 , stream >>> (reinterpret_cast <const T*>(data), reinterpret_cast <double *>(out),
72+ reinterpret_cast <const int *>(cats ),
73+ reinterpret_cast < const bool *>(mask), n_cells, n_genes, n_groups);
7274}
7375
7476// Unified dispatchers
7577static inline void sparse_aggr_dispatch (std::uintptr_t indptr, std::uintptr_t index,
7678 std::uintptr_t data, std::uintptr_t out,
7779 std::uintptr_t cats, std::uintptr_t mask,
7880 std::size_t n_cells, std::size_t n_genes,
79- std::size_t n_groups, bool is_csc, int dtype_itemsize) {
81+ std::size_t n_groups, bool is_csc, int dtype_itemsize,
82+ std::uintptr_t stream) {
8083 if (is_csc) {
8184 if (dtype_itemsize == 4 ) {
82- launch_csc_aggr<float >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
85+ launch_csc_aggr<float >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
86+ (cudaStream_t)stream);
8387 } else {
84- launch_csc_aggr<double >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
88+ launch_csc_aggr<double >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
89+ (cudaStream_t)stream);
8590 }
8691 } else {
8792 if (dtype_itemsize == 4 ) {
88- launch_csr_aggr<float >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
93+ launch_csr_aggr<float >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
94+ (cudaStream_t)stream);
8995 } else {
90- launch_csr_aggr<double >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
96+ launch_csr_aggr<double >(indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups,
97+ (cudaStream_t)stream);
9198 }
9299 }
93100}
94101
95102static inline void dense_aggr_dispatch (std::uintptr_t data, std::uintptr_t out, std::uintptr_t cats,
96103 std::uintptr_t mask, std::size_t n_cells,
97104 std::size_t n_genes, std::size_t n_groups, bool is_fortran,
98- int dtype_itemsize) {
105+ int dtype_itemsize, std:: uintptr_t stream ) {
99106 if (is_fortran) {
100107 if (dtype_itemsize == 4 ) {
101- launch_dense_F<float >(data, out, cats, mask, n_cells, n_genes, n_groups);
108+ launch_dense_F<float >(data, out, cats, mask, n_cells, n_genes, n_groups,
109+ (cudaStream_t)stream);
102110 } else {
103- launch_dense_F<double >(data, out, cats, mask, n_cells, n_genes, n_groups);
111+ launch_dense_F<double >(data, out, cats, mask, n_cells, n_genes, n_groups,
112+ (cudaStream_t)stream);
104113 }
105114 } else {
106115 if (dtype_itemsize == 4 ) {
107- launch_dense_C<float >(data, out, cats, mask, n_cells, n_genes, n_groups);
116+ launch_dense_C<float >(data, out, cats, mask, n_cells, n_genes, n_groups,
117+ (cudaStream_t)stream);
108118 } else {
109- launch_dense_C<double >(data, out, cats, mask, n_cells, n_genes, n_groups);
119+ launch_dense_C<double >(data, out, cats, mask, n_cells, n_genes, n_groups,
120+ (cudaStream_t)stream);
110121 }
111122 }
112123}
113124
114125static inline void csr_to_coo_dispatch (std::uintptr_t indptr, std::uintptr_t index,
115126 std::uintptr_t data, std::uintptr_t row, std::uintptr_t col,
116127 std::uintptr_t ndata, std::uintptr_t cats,
117- std::uintptr_t mask, int n_cells, int dtype_itemsize) {
128+ std::uintptr_t mask, int n_cells, int dtype_itemsize,
129+ std::uintptr_t stream) {
118130 if (dtype_itemsize == 4 ) {
119- launch_csr_to_coo<float >(indptr, index, data, row, col, ndata, cats, mask, n_cells);
131+ launch_csr_to_coo<float >(indptr, index, data, row, col, ndata, cats, mask, n_cells,
132+ (cudaStream_t)stream);
120133 } else {
121- launch_csr_to_coo<double >(indptr, index, data, row, col, ndata, cats, mask, n_cells);
134+ launch_csr_to_coo<double >(indptr, index, data, row, col, ndata, cats, mask, n_cells,
135+ (cudaStream_t)stream);
122136 }
123137}
124138
125139// variance launcher
126140static inline void launch_sparse_var (std::uintptr_t indptr, std::uintptr_t index,
127141 std::uintptr_t data, std::uintptr_t mean_data,
128- std::uintptr_t n_cells, int dof, int n_groups) {
142+ std::uintptr_t n_cells, int dof, int n_groups,
143+ cudaStream_t stream) {
129144 dim3 grid ((unsigned )n_groups);
130145 dim3 block (64 );
131- sparse_var_kernel<<<grid, block>>> (
146+ sparse_var_kernel<<<grid, block, 0 , stream >>> (
132147 reinterpret_cast <const int *>(indptr), reinterpret_cast <const int *>(index),
133148 reinterpret_cast <double *>(data), reinterpret_cast <const double *>(mean_data),
134149 reinterpret_cast <double *>(n_cells), dof, n_groups);
135150}
136151
137152NB_MODULE (_aggr_cuda, m) {
138- m.def (" sparse_aggr" , &sparse_aggr_dispatch);
139- m.def (" dense_aggr" , &dense_aggr_dispatch);
140- m.def (" csr_to_coo" , &csr_to_coo_dispatch);
141- m.def (" sparse_var" , &launch_sparse_var);
153+ m.def (" sparse_aggr" , &sparse_aggr_dispatch, nb::arg (" indptr" ), nb::arg (" index" ), nb::arg (" data" ),
154+ nb::arg (" out" ), nb::arg (" cats" ), nb::arg (" mask" ), nb::arg (" n_cells" ), nb::arg (" n_genes" ),
155+ nb::arg (" n_groups" ), nb::arg (" is_csc" ), nb::arg (" dtype_itemsize" ), nb::arg (" stream" ) = 0 );
156+ m.def (" dense_aggr" , &dense_aggr_dispatch, nb::arg (" data" ), nb::arg (" out" ), nb::arg (" cats" ),
157+ nb::arg (" mask" ), nb::arg (" n_cells" ), nb::arg (" n_genes" ), nb::arg (" n_groups" ),
158+ nb::arg (" is_fortran" ), nb::arg (" dtype_itemsize" ), nb::arg (" stream" ) = 0 );
159+ m.def (" csr_to_coo" , &csr_to_coo_dispatch, nb::arg (" indptr" ), nb::arg (" index" ), nb::arg (" data" ),
160+ nb::arg (" row" ), nb::arg (" col" ), nb::arg (" ndata" ), nb::arg (" cats" ), nb::arg (" mask" ),
161+ nb::arg (" n_cells" ), nb::arg (" dtype_itemsize" ), nb::arg (" stream" ) = 0 );
162+ m.def (
163+ " sparse_var" ,
164+ [](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data, std::uintptr_t mean_data,
165+ std::uintptr_t n_cells, int dof, int n_groups, std::uintptr_t stream) {
166+ launch_sparse_var (indptr, index, data, mean_data, n_cells, dof, n_groups,
167+ (cudaStream_t)stream);
168+ },
169+ nb::arg (" indptr" ), nb::arg (" index" ), nb::arg (" data" ), nb::arg (" mean_data" ),
170+ nb::arg (" n_cells" ), nb::arg (" dof" ), nb::arg (" n_groups" ), nb::arg (" stream" ) = 0 );
142171}
0 commit comments