-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathkernels_pr.cuh
More file actions
83 lines (78 loc) · 3.27 KB
/
kernels_pr.cuh
File metadata and controls
83 lines (78 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#pragma once
#include <cuda_runtime.h>
template <typename T>
__global__ void sparse_norm_res_csc_kernel(
const int* __restrict__ indptr, const int* __restrict__ index,
const T* __restrict__ data, const T* __restrict__ sums_cells,
const T* __restrict__ sums_genes, T* __restrict__ residuals,
const T inv_sum_total, const T clip, const T inv_theta, int n_cells,
int n_genes) {
int gene = blockDim.x * blockIdx.x + threadIdx.x;
if (gene >= n_genes) {
return;
}
int start = indptr[gene];
int stop = indptr[gene + 1];
int sparse_idx = start;
for (int cell = 0; cell < n_cells; ++cell) {
T mu = sums_genes[gene] * sums_cells[cell] * inv_sum_total;
long long res_index = static_cast<long long>(cell) * n_genes + gene;
if (sparse_idx < stop && index[sparse_idx] == cell) {
residuals[res_index] += data[sparse_idx];
++sparse_idx;
}
residuals[res_index] -= mu;
residuals[res_index] *= rsqrt(mu + mu * mu * inv_theta);
// clamp to [-clip, clip]
if (residuals[res_index] < -clip) residuals[res_index] = -clip;
if (residuals[res_index] > clip) residuals[res_index] = clip;
}
}
template <typename T>
__global__ void sparse_norm_res_csr_kernel(
const int* __restrict__ indptr, const int* __restrict__ index,
const T* __restrict__ data, const T* __restrict__ sums_cells,
const T* __restrict__ sums_genes, T* __restrict__ residuals,
const T inv_sum_total, const T clip, const T inv_theta, int n_cells,
int n_genes) {
int cell = blockDim.x * blockIdx.x + threadIdx.x;
if (cell >= n_cells) {
return;
}
int start = indptr[cell];
int stop = indptr[cell + 1];
int sparse_idx = start;
for (int gene = 0; gene < n_genes; ++gene) {
long long res_index = static_cast<long long>(cell) * n_genes + gene;
T mu = sums_genes[gene] * sums_cells[cell] * inv_sum_total;
if (sparse_idx < stop && index[sparse_idx] == gene) {
residuals[res_index] += data[sparse_idx];
++sparse_idx;
}
residuals[res_index] -= mu;
residuals[res_index] *= rsqrt(mu + mu * mu * inv_theta);
if (residuals[res_index] < -clip) residuals[res_index] = -clip;
if (residuals[res_index] > clip) residuals[res_index] = clip;
}
}
template <typename T>
__global__ void dense_norm_res_kernel(const T* __restrict__ X,
T* __restrict__ residuals,
const T* __restrict__ sums_cells,
const T* __restrict__ sums_genes,
const T inv_inv_sum_total, const T clip,
const T inv_theta, int n_cells,
int n_genes) {
int cell = blockDim.x * blockIdx.x + threadIdx.x;
int gene = blockDim.y * blockIdx.y + threadIdx.y;
if (cell >= n_cells || gene >= n_genes) {
return;
}
T mu = sums_genes[gene] * sums_cells[cell] * inv_inv_sum_total;
long long res_index = static_cast<long long>(cell) * n_genes + gene;
T r = X[res_index] - mu;
r *= rsqrt(mu + mu * mu * inv_theta);
if (r < -clip) r = -clip;
if (r > clip) r = clip;
residuals[res_index] = r;
}