@@ -8,28 +8,28 @@ using namespace nb::literals;
88constexpr int BLOCK_SIZE_SPARSE = 64 ;
99constexpr int BLOCK_SIZE_DENSE = 256 ;
1010
11- template <typename T>
12- static inline void launch_csr_aggr (const int * indptr, const int * index,
11+ template <typename T, typename IdxT >
12+ static inline void launch_csr_aggr (const IdxT * indptr, const IdxT * index,
1313 const T* data, double * out, const int * cats,
1414 const bool * mask, size_t n_cells,
1515 size_t n_genes, size_t n_groups,
1616 cudaStream_t stream) {
1717 dim3 grid ((unsigned )n_cells);
1818 dim3 block (BLOCK_SIZE_SPARSE);
19- csr_aggr_kernel<T><<<grid, block, 0 , stream>>> (
19+ csr_aggr_kernel<T, IdxT ><<<grid, block, 0 , stream>>> (
2020 indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
2121 CUDA_CHECK_LAST_ERROR (csr_aggr_kernel);
2222}
2323
24- template <typename T>
25- static inline void launch_csc_aggr (const int * indptr, const int * index,
24+ template <typename T, typename IdxT >
25+ static inline void launch_csc_aggr (const IdxT * indptr, const IdxT * index,
2626 const T* data, double * out, const int * cats,
2727 const bool * mask, size_t n_cells,
2828 size_t n_genes, size_t n_groups,
2929 cudaStream_t stream) {
3030 dim3 grid ((unsigned )n_genes);
3131 dim3 block (BLOCK_SIZE_SPARSE);
32- csc_aggr_kernel<T><<<grid, block, 0 , stream>>> (
32+ csc_aggr_kernel<T, IdxT ><<<grid, block, 0 , stream>>> (
3333 indptr, index, data, out, cats, mask, n_cells, n_genes, n_groups);
3434 CUDA_CHECK_LAST_ERROR (csc_aggr_kernel);
3535}
@@ -58,50 +58,51 @@ static inline void launch_dense_aggr_F(const T* data, double* out,
5858 CUDA_CHECK_LAST_ERROR (dense_aggr_kernel_F);
5959}
6060
61- template <typename T>
62- static inline void launch_csr_to_coo (const int * indptr, const int * index,
61+ template <typename T, typename IdxT >
62+ static inline void launch_csr_to_coo (const IdxT * indptr, const IdxT * index,
6363 const T* data, int * row, int * col,
6464 double * ndata, const int * cats,
6565 const bool * mask, int n_cells,
6666 cudaStream_t stream) {
6767 dim3 grid ((unsigned )n_cells);
6868 dim3 block (BLOCK_SIZE_SPARSE);
69- csr_to_coo_kernel<T><<<grid, block, 0 , stream>>> (
69+ csr_to_coo_kernel<T, IdxT ><<<grid, block, 0 , stream>>> (
7070 indptr, index, data, row, col, ndata, cats, mask, n_cells);
7171 CUDA_CHECK_LAST_ERROR (csr_to_coo_kernel);
7272}
7373
74- static inline void launch_sparse_var (const int * indptr, const int * index,
74+ template <typename IdxT>
75+ static inline void launch_sparse_var (const IdxT* indptr, const IdxT* index,
7576 double * data, const double * mean_data,
7677 double * n_cells, int dof, int n_groups,
7778 cudaStream_t stream) {
7879 dim3 grid ((unsigned )n_groups);
7980 dim3 block (BLOCK_SIZE_SPARSE);
80- sparse_var_kernel<<<grid, block, 0 , stream>>> (
81+ sparse_var_kernel<IdxT> < <<grid, block, 0 , stream>>> (
8182 indptr, index, data, mean_data, n_cells, dof, n_groups);
8283 CUDA_CHECK_LAST_ERROR (sparse_var_kernel);
8384}
8485
85- template <typename T, typename Device>
86+ template <typename T, typename IdxT, typename Device>
8687void def_sparse_aggr (nb::module_& m) {
8788 m.def (
8889 " sparse_aggr" ,
89- [](gpu_array_c<const int , Device> indptr,
90- gpu_array_c<const int , Device> index,
90+ [](gpu_array_c<const IdxT , Device> indptr,
91+ gpu_array_c<const IdxT , Device> index,
9192 gpu_array_c<const T, Device> data, gpu_array_c<double , Device> out,
9293 gpu_array_c<const int , Device> cats,
9394 gpu_array_c<const bool , Device> mask, size_t n_cells, size_t n_genes,
9495 size_t n_groups, bool is_csc, std::uintptr_t stream) {
9596 if (is_csc) {
96- launch_csc_aggr<T>(indptr.data (), index. data (), data .data (),
97- out .data (), cats .data (), mask .data (),
98- n_cells, n_genes, n_groups ,
99- (cudaStream_t)stream);
97+ launch_csc_aggr<T, IdxT >(indptr.data (), index.data (),
98+ data .data (), out .data (), cats .data (),
99+ mask. data (), n_cells, n_genes ,
100+ n_groups, (cudaStream_t)stream);
100101 } else {
101- launch_csr_aggr<T>(indptr.data (), index. data (), data .data (),
102- out .data (), cats .data (), mask .data (),
103- n_cells, n_genes, n_groups ,
104- (cudaStream_t)stream);
102+ launch_csr_aggr<T, IdxT >(indptr.data (), index.data (),
103+ data .data (), out .data (), cats .data (),
104+ mask. data (), n_cells, n_genes ,
105+ n_groups, (cudaStream_t)stream);
105106 }
106107 },
107108 " indptr" _a, " index" _a, " data" _a, nb::kw_only (), " out" _a, " cats" _a,
@@ -131,56 +132,66 @@ void def_dense_aggr(nb::module_& m) {
131132 " n_genes" _a, " n_groups" _a, " is_fortran" _a, " stream" _a = 0 );
132133}
133134
134- template <typename T, typename Device>
135+ template <typename T, typename IdxT, typename Device>
135136void def_csr_to_coo (nb::module_& m) {
136137 m.def (
137138 " csr_to_coo" ,
138- [](gpu_array_c<const int , Device> indptr,
139- gpu_array_c<const int , Device> index,
139+ [](gpu_array_c<const IdxT , Device> indptr,
140+ gpu_array_c<const IdxT , Device> index,
140141 gpu_array_c<const T, Device> data, gpu_array_c<int , Device> out_row,
141142 gpu_array_c<int , Device> out_col,
142143 gpu_array_c<double , Device> out_data,
143144 gpu_array_c<const int , Device> cats,
144145 gpu_array_c<const bool , Device> mask, int n_cells,
145146 std::uintptr_t stream) {
146- launch_csr_to_coo<T>(indptr. data (), index. data (), data. data (),
147- out_row .data (), out_col .data (),
148- out_data.data (), cats.data (), mask.data (),
149- n_cells, (cudaStream_t)stream);
147+ launch_csr_to_coo<T, IdxT>(
148+ indptr. data (), index. data (), data .data (), out_row .data (),
149+ out_col. data (), out_data.data (), cats.data (), mask.data (),
150+ n_cells, (cudaStream_t)stream);
150151 },
151152 " indptr" _a, " index" _a, " data" _a, nb::kw_only (), " out_row" _a,
152153 " out_col" _a, " out_data" _a, " cats" _a, " mask" _a, " n_cells" _a,
153154 " stream" _a = 0 );
154155}
155156
157+ template <typename IdxT, typename Device>
158+ void def_sparse_var (nb::module_& m) {
159+ m.def (
160+ " sparse_var" ,
161+ [](gpu_array_c<const IdxT, Device> indptr,
162+ gpu_array_c<const IdxT, Device> index,
163+ gpu_array_c<double , Device> data,
164+ gpu_array_c<const double , Device> means,
165+ gpu_array_c<double , Device> n_cells, int dof, int n_groups,
166+ std::uintptr_t stream) {
167+ launch_sparse_var<IdxT>(indptr.data (), index.data (), data.data (),
168+ means.data (), n_cells.data (), dof, n_groups,
169+ (cudaStream_t)stream);
170+ },
171+ " indptr" _a, " index" _a, " data" _a, nb::kw_only (), " means" _a, " n_cells" _a,
172+ " dof" _a, " n_groups" _a, " stream" _a = 0 );
173+ }
174+
156175template <typename Device>
157176void register_bindings (nb::module_& m) {
158- def_sparse_aggr<float , Device>(m);
159- def_sparse_aggr<double , Device>(m);
177+ def_sparse_aggr<float , int , Device>(m);
178+ def_sparse_aggr<float , long long , Device>(m);
179+ def_sparse_aggr<double , int , Device>(m);
180+ def_sparse_aggr<double , long long , Device>(m);
160181
161182 // F-order must come before C-order for proper dispatch
162183 def_dense_aggr<float , nb::f_contig, Device>(m);
163184 def_dense_aggr<float , nb::c_contig, Device>(m);
164185 def_dense_aggr<double , nb::f_contig, Device>(m);
165186 def_dense_aggr<double , nb::c_contig, Device>(m);
166187
167- def_csr_to_coo<float , Device>(m);
168- def_csr_to_coo<double , Device>(m);
188+ def_csr_to_coo<float , int , Device>(m);
189+ def_csr_to_coo<float , long long , Device>(m);
190+ def_csr_to_coo<double , int , Device>(m);
191+ def_csr_to_coo<double , long long , Device>(m);
169192
170- m.def (
171- " sparse_var" ,
172- [](gpu_array_c<const int , Device> indptr,
173- gpu_array_c<const int , Device> index,
174- gpu_array_c<double , Device> data,
175- gpu_array_c<const double , Device> means,
176- gpu_array_c<double , Device> n_cells, int dof, int n_groups,
177- std::uintptr_t stream) {
178- launch_sparse_var (indptr.data (), index.data (), data.data (),
179- means.data (), n_cells.data (), dof, n_groups,
180- (cudaStream_t)stream);
181- },
182- " indptr" _a, " index" _a, " data" _a, nb::kw_only (), " means" _a, " n_cells" _a,
183- " dof" _a, " n_groups" _a, " stream" _a = 0 );
193+ def_sparse_var<int , Device>(m);
194+ def_sparse_var<long long , Device>(m);
184195}
185196
186197NB_MODULE (_aggr_cuda, m) {
0 commit comments