-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path02_tiled.cuh
More file actions
62 lines (47 loc) · 1.78 KB
/
02_tiled.cuh
File metadata and controls
62 lines (47 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#pragma once
namespace k2 {
/**
 * Shared-memory tiled SGEMM kernel: C = alpha * (A * B) + beta * C.
 *
 * The indexing below treats A as M x K, B as K x N and C as M x N, each stored
 * row-major with no padding between rows. Every thread produces at most one
 * element of C, at (row, col).
 *
 * Launch contract:
 *   - blockDim must be (TILE_DIM, TILE_DIM): threadIdx.x / threadIdx.y are
 *     used directly as tile indices.
 *   - blockIdx.x selects a TILE_DIM-wide band of columns of C, blockIdx.y a
 *     TILE_DIM-high band of rows, so the grid needs
 *     ceil(N / TILE_DIM) x ceil(M / TILE_DIM) blocks to cover C.
 *   - Uses 2 * TILE_DIM * TILE_DIM floats of static shared memory.
 *
 * Threads whose (row, col) fall outside C still run the whole tile loop so
 * that every thread in the block reaches both barriers; they stage zeros and
 * skip the final store.
 *
 * @tparam TILE_DIM  edge length of the square shared-memory tiles
 * @param M, N, K    matrix dimensions as described above
 * @param alpha      scale applied to the A * B product
 * @param d_A        device pointer to A (read only)
 * @param d_B        device pointer to B (read only)
 * @param beta       scale applied to the existing contents of C
 * @param d_C        device pointer to C (read, then overwritten)
 */
template <const int TILE_DIM>
__global__ void sgemm_tiled_kernel(int M, int N, int K, float alpha,
                                   const float *d_A, const float *d_B,
                                   float beta, float *d_C) {
  const int col = blockIdx.x * TILE_DIM + threadIdx.x;
  const int row = blockIdx.y * TILE_DIM + threadIdx.y;
  // Number of TILE_DIM-wide steps needed to walk the K dimension (ceil-div).
  const int num_tiles = (K + TILE_DIM - 1) / TILE_DIM;

  __shared__ float A_tile[TILE_DIM][TILE_DIM];
  __shared__ float B_tile[TILE_DIM][TILE_DIM];

  float sum = 0.0f;
  for (int i = 0; i < num_tiles; i++) {
    // Stage one element of the current A tile; out-of-range slots become 0 so
    // they contribute nothing to the dot product.
    const int A_read_row = row;
    const int A_read_col = TILE_DIM * i + threadIdx.x;
    if (A_read_row < M && A_read_col < K) {
      // Flat offsets are formed in 64-bit: row * K can exceed INT_MAX for
      // large matrices, and a wrapped int offset would index out of bounds.
      const long long a_idx =
          static_cast<long long>(A_read_row) * K + A_read_col;
      A_tile[threadIdx.y][threadIdx.x] = d_A[a_idx];
    } else {
      A_tile[threadIdx.y][threadIdx.x] = 0.0f;
    }

    // Stage one element of the current B tile, zero-filled the same way.
    const int B_read_row = TILE_DIM * i + threadIdx.y;
    const int B_read_col = col;
    if (B_read_row < K && B_read_col < N) {
      const long long b_idx =
          static_cast<long long>(B_read_row) * N + B_read_col;
      B_tile[threadIdx.y][threadIdx.x] = d_B[b_idx];
    } else {
      B_tile[threadIdx.y][threadIdx.x] = 0.0f;
    }

    // Both tiles must be fully written before any thread reads from them.
    __syncthreads();

    for (int k = 0; k < TILE_DIM; k++) {
      sum += A_tile[threadIdx.y][k] * B_tile[k][threadIdx.x];
    }

    // Keep the tiles intact until every thread has finished reading them;
    // the next iteration overwrites both.
    __syncthreads();
  }

  if (row < M && col < N) {
    const long long c_idx = static_cast<long long>(row) * N + col;
    d_C[c_idx] = alpha * sum + beta * d_C[c_idx];
  }
}
/**
 * Host-side launcher for sgemm_tiled_kernel using 16 x 16 tiles.
 *
 * Launches one 16 x 16 thread block per 16 x 16 patch of C, with enough
 * blocks to cover N columns (grid x) and M rows (grid y). The launch uses the
 * default stream and is not followed by any synchronization or error check
 * here; callers that need to detect launch failures should query
 * cudaGetLastError() after this returns.
 *
 * Returns without launching when M <= 0 or N <= 0: C then has no elements,
 * and a grid dimension of zero is not a valid launch configuration.
 *
 * Declared inline because it is defined in a header; without it, including
 * this file from more than one translation unit yields duplicate definitions.
 *
 * @param M, N, K  matrix dimensions passed through to the kernel
 * @param alpha    scale applied to the A * B product
 * @param d_A      device pointer to A
 * @param d_B      device pointer to B
 * @param beta     scale applied to the existing contents of C
 * @param d_C      device pointer to C, updated in place
 */
inline void run_tiled_kernel(int M, int N, int K, float alpha, const float *d_A,
                             const float *d_B, float beta, float *d_C) {
  if (M <= 0 || N <= 0) {
    return;  // empty output: nothing to compute
  }

  const int TILE_DIM = 16;
  dim3 threads_per_block(TILE_DIM, TILE_DIM);
  // Ceil-div so partially filled edge tiles still get a block.
  dim3 grid_dim((N + threads_per_block.x - 1) / threads_per_block.x,
                (M + threads_per_block.y - 1) / threads_per_block.y);

  sgemm_tiled_kernel<TILE_DIM>
      <<<grid_dim, threads_per_block>>>(M, N, K, alpha, d_A, d_B, beta, d_C);
}
} // namespace k2