|
#include <cuda_runtime.h>
#include <nanobind/nanobind.h>
#include <cstdint>
#include <stdexcept>
#include <string>

#include "kernels_spca.cuh"
| 6 | + |
| 7 | +namespace nb = nanobind; |
| 8 | + |
/**
 * Launch gram_csr_upper_kernel<T> with nrows blocks of 128 threads, using no
 * explicit stream.
 *
 * Every pointer is received as a raw integer address and reinterpreted here.
 * NOTE(review): the addresses are presumably device pointers supplied by the
 * Python caller -- confirm against the call site.
 *
 * @tparam T          element type used for `data` and `out`
 * @param indptr_ptr  address reinterpreted as `const int*` (indptr array)
 * @param index_ptr   address reinterpreted as `const int*` (index array)
 * @param data_ptr    address reinterpreted as `const T*`
 * @param nrows       number of blocks to launch; forwarded to the kernel
 * @param ncols       forwarded to the kernel unchanged
 * @param out_ptr     address reinterpreted as `T*` (kernel output)
 * @throws std::runtime_error if the CUDA runtime reports a launch error
 */
template <typename T>
static inline void launch_gram_csr_upper(std::uintptr_t indptr_ptr, std::uintptr_t index_ptr,
                                         std::uintptr_t data_ptr, int nrows, int ncols,
                                         std::uintptr_t out_ptr) {
    // A grid dimension of zero is an invalid launch configuration, and a
    // negative count would wrap to a huge unsigned value. With no rows there
    // is nothing to launch, so return before touching the runtime.
    if (nrows <= 0) {
        return;
    }

    const dim3 block(128);
    const dim3 grid(static_cast<unsigned int>(nrows));

    const int* indptr = reinterpret_cast<const int*>(indptr_ptr);
    const int* index = reinterpret_cast<const int*>(index_ptr);
    const T* data = reinterpret_cast<const T*>(data_ptr);
    T* out = reinterpret_cast<T*>(out_ptr);

    gram_csr_upper_kernel<T><<<grid, block>>>(indptr, index, data, nrows, ncols, out);

    // Kernel launches return nothing; configuration errors are only visible
    // through cudaGetLastError(), which also clears the pending error.
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        throw std::runtime_error(std::string("gram_csr_upper_kernel launch failed: ") +
                                 cudaGetErrorString(status));
    }
}
| 21 | + |
/**
 * Launch copy_upper_to_lower_kernel<T> over a 2D grid of 32x32-thread blocks
 * that covers an ncols x ncols index space, using no explicit stream.
 *
 * NOTE(review): `out_ptr` is presumably the device address of a square
 * ncols x ncols matrix -- confirm against the kernel and the caller.
 *
 * @tparam T       element type of the matrix
 * @param out_ptr  address reinterpreted as `T*`
 * @param ncols    extent used to size the grid; forwarded to the kernel
 * @throws std::runtime_error if the CUDA runtime reports a launch error
 */
template <typename T>
static inline void launch_copy_upper_to_lower(std::uintptr_t out_ptr, int ncols) {
    // ncols == 0 would produce a 0x0 grid, which is an invalid launch
    // configuration; an empty extent needs no launch at all.
    if (ncols <= 0) {
        return;
    }

    const dim3 block(32, 32);
    const unsigned int extent = static_cast<unsigned int>(ncols);
    // Ceil-divide so a partially filled tile at the edge still gets a block.
    const dim3 grid((extent + block.x - 1) / block.x, (extent + block.y - 1) / block.y);

    T* out = reinterpret_cast<T*>(out_ptr);
    copy_upper_to_lower_kernel<T><<<grid, block>>>(out, ncols);

    // Kernel launches return nothing; configuration errors are only visible
    // through cudaGetLastError(), which also clears the pending error.
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        throw std::runtime_error(std::string("copy_upper_to_lower_kernel launch failed: ") +
                                 cudaGetErrorString(status));
    }
}
| 29 | + |
/**
 * Launch cov_from_gram_kernel<T> over a 2D grid of 32x32-thread blocks that
 * covers an ncols x ncols index space, using no explicit stream.
 *
 * NOTE(review): the addresses are presumably device pointers, with `cov` and
 * `gram` being ncols x ncols matrices and `meanx`/`meany` length-ncols
 * vectors -- confirm against the kernel and the caller.
 *
 * @tparam T         element type of all four buffers
 * @param cov_ptr    address reinterpreted as `T*` (kernel output)
 * @param gram_ptr   address reinterpreted as `const T*`
 * @param meanx_ptr  address reinterpreted as `const T*`
 * @param meany_ptr  address reinterpreted as `const T*`
 * @param ncols      extent used to size the grid; forwarded to the kernel
 * @throws std::runtime_error if the CUDA runtime reports a launch error
 */
template <typename T>
static inline void launch_cov_from_gram(std::uintptr_t cov_ptr, std::uintptr_t gram_ptr,
                                        std::uintptr_t meanx_ptr, std::uintptr_t meany_ptr,
                                        int ncols) {
    // ncols == 0 would produce a 0x0 grid, which is an invalid launch
    // configuration; an empty extent needs no launch at all.
    if (ncols <= 0) {
        return;
    }

    const dim3 block(32, 32);
    const unsigned int extent = static_cast<unsigned int>(ncols);
    // Derive the grid from the block dimensions rather than repeating the
    // literal 32, so the two cannot drift apart.
    const dim3 grid((extent + block.x - 1) / block.x, (extent + block.y - 1) / block.y);

    T* cov = reinterpret_cast<T*>(cov_ptr);
    const T* gram = reinterpret_cast<const T*>(gram_ptr);
    const T* meanx = reinterpret_cast<const T*>(meanx_ptr);
    const T* meany = reinterpret_cast<const T*>(meany_ptr);

    cov_from_gram_kernel<T><<<grid, block>>>(cov, gram, meanx, meany, ncols);

    // Kernel launches return nothing; configuration errors are only visible
    // through cudaGetLastError(), which also clears the pending error.
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        throw std::runtime_error(std::string("cov_from_gram_kernel launch failed: ") +
                                 cudaGetErrorString(status));
    }
}
| 42 | + |
/**
 * Launch check_zero_genes_kernel over a 1D grid of 32-thread blocks sized to
 * cover nnz elements, using no explicit stream.
 *
 * NOTE(review): the addresses are presumably device pointers; `genes` is the
 * only non-const buffer passed, so it is presumably where the kernel writes
 * -- confirm against the kernel.
 *
 * @param indices_ptr  address reinterpreted as `const int*`
 * @param genes_ptr    address reinterpreted as `int*`
 * @param nnz          element count used to size the grid; forwarded to the kernel
 * @throws std::runtime_error if the CUDA runtime reports a launch error
 */
static inline void launch_check_zero_genes(std::uintptr_t indices_ptr, std::uintptr_t genes_ptr,
                                           int nnz) {
    // nnz == 0 would produce a zero-sized grid, which is an invalid launch
    // configuration; with no elements there is nothing to scan.
    if (nnz <= 0) {
        return;
    }

    const dim3 block(32);
    // Ceil-divide so the final, partially filled block is still launched.
    const dim3 grid((static_cast<unsigned int>(nnz) + block.x - 1) / block.x);

    const int* indices = reinterpret_cast<const int*>(indices_ptr);
    int* genes = reinterpret_cast<int*>(genes_ptr);

    check_zero_genes_kernel<<<grid, block>>>(indices, genes, nnz);

    // Kernel launches return nothing; configuration errors are only visible
    // through cudaGetLastError(), which also clears the pending error.
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        throw std::runtime_error(std::string("check_zero_genes_kernel launch failed: ") +
                                 cudaGetErrorString(status));
    }
}
| 51 | + |
// Python bindings for the `_spca_cuda` extension module.
//
// Buffers are passed in as raw integer addresses. The templated launchers are
// selected at runtime from `itemsize`: 4 selects the float instantiation and
// 8 selects the double instantiation. Any other value raises a value error.
NB_MODULE(_spca_cuda, m) {
    // gram_csr_upper(indptr, index, data, nrows, ncols, out, itemsize)
    m.def("gram_csr_upper",
          [](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data, int nrows,
             int ncols, std::uintptr_t out, int itemsize) {
              switch (itemsize) {
                  case 4:
                      launch_gram_csr_upper<float>(indptr, index, data, nrows, ncols, out);
                      break;
                  case 8:
                      launch_gram_csr_upper<double>(indptr, index, data, nrows, ncols, out);
                      break;
                  default:
                      throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
              }
          });

    // copy_upper_to_lower(out, ncols, itemsize)
    m.def("copy_upper_to_lower", [](std::uintptr_t out, int ncols, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_copy_upper_to_lower<float>(out, ncols);
                break;
            case 8:
                launch_copy_upper_to_lower<double>(out, ncols);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    // cov_from_gram(cov, gram, meanx, meany, ncols, itemsize)
    m.def("cov_from_gram",
          [](std::uintptr_t cov, std::uintptr_t gram, std::uintptr_t meanx,
             std::uintptr_t meany, int ncols, int itemsize) {
              switch (itemsize) {
                  case 4:
                      launch_cov_from_gram<float>(cov, gram, meanx, meany, ncols);
                      break;
                  case 8:
                      launch_cov_from_gram<double>(cov, gram, meanx, meany, ncols);
                      break;
                  default:
                      throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
              }
          });

    // check_zero_genes(indices, genes, nnz) -- not templated, so no itemsize.
    m.def("check_zero_genes", [](std::uintptr_t indices, std::uintptr_t genes, int nnz) {
        launch_check_zero_genes(indices, genes, nnz);
    });
}
0 commit comments