-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtensorCuda.cu
132 lines (115 loc) · 4.79 KB
/
tensorCuda.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <cuda_runtime.h>
#include <stdlib.h>
#include <stdio.h>
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
static __device__ float E = 2.718281828;
static __device__ int getIndex(int *ids, int ndim, int *dims)
{
int i, id;
for (i = 0, id = ids[0]; i < ndim-1; i++)
id = dims[i+1] * id + ids[i+1];
return id;
}
static __device__ void getIndexes(int id, int *ids, int ndim, int *dims)
{
for (int i = ndim-1; i >=0; i--) {
ids[i] = id % dims[i];
id = id / dims[i];
}
}
/* __global__ void sliceTensorKernel(float *src, float *dst, int sdim, int ddim, int start, int block_size) */
/* { */
/* int di = blockIdx.x * block_size + threadIdx.x; */
/* /\* si is the index of src elements to be copied. */
/* The "block index" of src[si] is (blockIdx.x / ddim * sdim + blockIdx.x % ddim + start) *\/ */
/* int si = (blockIdx.x / ddim * sdim + blockIdx.x % ddim + start) * block_size + threadIdx.x; */
/* dst[di] = src[si]; */
/* } */
__global__ void sliceTensorKernel(float *src, float *dst, int start, int s_vol, int d_vol, int vol, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
int si = di / d_vol * s_vol + di % d_vol + start * vol;
dst[di] = src[si];
}
__global__ void reduceArgMaxKernel(float *src, float *dst, float *arg, int dim_size, int reduce_vol, int batch_vol, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
/* src[si] is the first element in this thread to be compared, then
si = batch_vol * batch + (di - reduce_vol * batch),
where batch = di / reduce_vol,
which is the same as the following code: */
int si = (batch_vol - reduce_vol) * (di / reduce_vol) + di;
float now = src[si], max = now;
int maxi = 0;
for (int i = 1; i < dim_size; i++) {
now = src[si+i*reduce_vol];
if (now > max) {
max = now;
maxi = i;
}
}
dst[di] = max;
arg[di] = maxi;
}
__global__ void multiplyElementKernel(float *src1, float *src2, float *dst, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
dst[di] = src1[di] * src2[di];
}
__global__ void transposeTensorKernel(float *src, float *dst, int ndim, int *s_dims, int *d_dims, int *s_ids, int *d_ids, int *axes, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
int *t_s_ids = s_ids + di * ndim;
int *t_d_ids = d_ids + di * ndim;
getIndexes(di, t_d_ids, ndim, d_dims);
for (int i = 0; i < ndim; i++)
t_s_ids[axes[i]] = t_d_ids[i];
int si = getIndex(t_s_ids, ndim, s_dims);
dst[di] = src[si];
}
__global__ void transformBboxSQDKernel(float *delta, float *anchor, float *res, float width, float height, float img_width, float img_height, int x_shift, int y_shift, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
/* int batch_idx = di / anchor_num; */
/* now only support batch_size = 1 */
float x_scale = 1.0 * width / img_width;
float y_scale = 1.0 * height / img_height;
/* (not used) si is the index of the first elements to be computed in the thread, then
si = 4 * anchor_num * batch_idx + (di - anchor_num * batch_idx),
which is the same as the following code: */
/* int si = 3 * anchor_num * batch_idx + di; */
/* take 4 elements from each of delta and anchor */
int si = di * 4;
float d[4] = {delta[si], delta[si+1], delta[si+2], delta[si+3]};
float a[4] = {anchor[si], anchor[si+1], anchor[si+2], anchor[si+3]};
/* compute and put 4 result elements to res, according to SqueezeDet's source code */
/* TODO: don't know why (maybe the resize), always has some shift compared to groundtruth*/
float cx = (a[0] + d[0] * a[2]) / x_scale + x_shift;
float cy = (a[1] + d[1] * a[3]) / y_scale + y_shift;
float w = (a[2] * (d[2] < 1 ? expf(d[2]) : d[2] * E)) / x_scale;
float h = (a[3] * (d[3] < 1 ? expf(d[3]) : d[3] * E)) / y_scale;
res[si] = min(max(cx - w * 0.5, 0), img_width - 1);
res[si+1] = min(max(cy - h * 0.5, 0), img_height - 1);
res[si+2] = max(min(cx + w * 0.5, img_width - 1), 0);
res[si+3] = max(min(cy + h * 0.5, img_height - 1), 0);
}
__global__ void pickElementsKernel(float *src, float *dst, int *idx, int stride, int block_size, int total)
{
int di = blockIdx.x * block_size + threadIdx.x;
if (di >= total)
return;
int si = idx[di];
for (int i = 0; i < stride; i++)
dst[di*stride+i] = src[si*stride+i];
}