-
Notifications
You must be signed in to change notification settings - Fork 278
Expand file tree
/
Copy pathlevel_06.cu
More file actions
103 lines (84 loc) · 3.13 KB
/
level_06.cu
File metadata and controls
103 lines (84 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <chrono>
#include <cuda_runtime.h>
#include "kittens.cuh"
using namespace kittens;
// Edge length of every square tile (shared-memory tiles and output tiles).
static constexpr int BLOCK_SIZE{64};
// Warps per thread block; 4 warps together form one warpgroup.
static constexpr int NUM_WORKERS{4};
// Threads per block passed at launch.
static constexpr int NUM_THREADS{NUM_WORKERS * kittens::WARP_THREADS};
// Kernel argument bundle, passed by value as a __grid_constant__ parameter.
// A, B and C share one global-layout type. The two leading dims are fixed at 1;
// the trailing two are -1 (runtime-sized) and matmul() sets both to N.
struct matmul_globals {
    typedef st_bf<BLOCK_SIZE, BLOCK_SIZE> sub_tile;    // BLOCK_SIZE x BLOCK_SIZE bf16 tile
    typedef gl<bf16, 1, 1, -1, -1, sub_tile> tile_gl;  // global view tiled by sub_tile

    tile_gl A;  // left operand
    tile_gl B;  // right operand
    tile_gl C;  // output
    int N;      // matrix dimension; matmul() uses square N x N operands
};
// Tiled bf16 GEMM, C = A * B, with TMA double buffering.
//
// Expected launch (as configured by matmul()):
//   grid   : (ceil(N/BLOCK_SIZE), ceil(N/BLOCK_SIZE)). blockIdx.y selects the
//            output tile row and blockIdx.x the output tile column.
//   block  : NUM_THREADS threads (NUM_WORKERS warps = one warpgroup).
//   dynamic shared memory: room for four BLOCK_SIZE x BLOCK_SIZE bf16 tiles
//            (two A buffers + two B buffers).
//
// Pipeline: thread 0 issues the TMA loads for k-tile `tile+1` into the `toc`
// buffers while the warpgroup runs wgmma on the `tic` buffers. One semaphore
// (`bar`) tracks each load batch. Only one batch is ever outstanding, so the
// phase parity to wait on equals `tic`.
// NOTE(review): the store has no partial-tile handling, so N is presumably
// expected to be a multiple of BLOCK_SIZE. Confirm with the caller.
__global__ void kernel(const __grid_constant__ matmul_globals g) {
    using tile_t = st_bf<BLOCK_SIZE, BLOCK_SIZE>;

    extern __shared__ alignment_dummy __shm[];
    shared_allocator al((int*)&__shm[0]);
    tile_t (&As)[2] = al.allocate<tile_t, 2>();
    tile_t (&Bs)[2] = al.allocate<tile_t, 2>();

    // tic = buffer being consumed (and expected semaphore phase); toc = buffer being filled.
    int tic = 0;
    int toc = 1;

    // C_accum receives each k-tile's wgmma product; C_accum_cpy keeps the running sum.
    // NOTE(review): if warpgroup::mma_AB accumulates into its destination (as opposed
    // to mm_AB), C_accum_cpy is redundant and C_accum could be summed in place.
    rt_fl<16, BLOCK_SIZE> C_accum;
    rt_fl<16, BLOCK_SIZE> C_accum_cpy;

    int row = blockIdx.y;
    int col = blockIdx.x;

    __shared__ semaphore bar;
    if (threadIdx.x == 0) { // must be a single thread, not a whole warp (SA: note)
        init_semaphore(bar, 0, 1);
        tma::expect_bytes(bar, size_bytes<tile_t> + size_bytes<tile_t>);
        tma::load_async(As[tic], g.A, {0, 0, row, 0}, bar);
        tma::load_async(Bs[tic], g.B, {0, 0, 0, col}, bar);
    }
    __syncthreads();

    // Both accumulators start at zero. If mma_AB accumulates, an uninitialised
    // C_accum would fold register garbage into the first k-tile's product.
    kittens::warp::zero(C_accum);
    kittens::warp::zero(C_accum_cpy);

    int num_tiles = (g.N + BLOCK_SIZE - 1) / BLOCK_SIZE;
    for (int tile = 0; tile < num_tiles; ++tile, tic ^= 1, toc ^= 1) {
        // Wait until the tic buffers have landed in shared memory.
        wait(bar, tic);
        // Every thread must be past the wait before thread 0 re-arms the semaphore.
        __syncthreads();

        // Prefetch the next k-tile into the idle buffers.
        if (threadIdx.x == 0 && tile + 1 < num_tiles) {
            tma::expect_bytes(bar, size_bytes<tile_t> + size_bytes<tile_t>);
            tma::load_async(As[toc], g.A, {0, 0, row, tile + 1}, bar);
            tma::load_async(Bs[toc], g.B, {0, 0, tile + 1, col}, bar);
        }

        warpgroup::mma_AB(C_accum, As[tic], Bs[tic]);
        warpgroup::mma_async_wait();
        kittens::warp::add(C_accum_cpy, C_accum_cpy, C_accum);
        kittens::warp::zero(C_accum);
        // The tic buffers become next iteration's toc (TMA target); finish reading them first.
        __syncthreads();
    }
    warpgroup::store(g.C, C_accum_cpy, {0, 0, row, col});
}
// launch kernel
// Host launcher: computes C = A * B for N x N bf16 matrices. A, B and C are
// passed directly into the kernel's global-layout descriptors, so they are
// presumably device pointers (confirm with the caller).
// Launches one NUM_THREADS-thread block per BLOCK_SIZE x BLOCK_SIZE output tile,
// then blocks until the kernel finishes. CUDA API and launch errors are routed
// through CHECK_CUDA_ERROR.
// NOTE(review): the kernel does not handle partial edge tiles on its store, so
// N is presumably expected to be a multiple of BLOCK_SIZE. Confirm.
void matmul(bf16* A, bf16* B, bf16* C, size_t N) {
    // Nothing to compute. Launching a zero-sized grid is itself an error.
    if (N == 0) {
        return;
    }

    // global pointers
    using a_gl = matmul_globals::tile_gl;
    using b_gl = matmul_globals::tile_gl;
    using c_gl = matmul_globals::tile_gl;
    a_gl a_arg{A, nullptr, nullptr, N, N};
    b_gl b_arg{B, nullptr, nullptr, N, N};
    c_gl c_arg{C, nullptr, nullptr, N, N};
    matmul_globals g{a_arg, b_arg, c_arg, (int)N};

    // launch
    dim3 blocks((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (N + BLOCK_SIZE - 1) / BLOCK_SIZE);
    // The kernel allocates four 64x64 bf16 tiles (8 KiB each, 32 KiB total).
    // 100000 bytes leaves generous headroom (e.g. for allocator alignment).
    // Anything above 48 KiB needs the explicit opt-in below.
    unsigned long mem_size = 100000;
    // Surfaces any sticky error from earlier work before this launch.
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());
    CHECK_CUDA_ERROR(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size));
    kernel<<<blocks, NUM_THREADS, mem_size>>>(g);
    CHECK_CUDA_ERROR(cudaGetLastError());       // launch-configuration errors
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());  // asynchronous in-kernel faults
}
#include "launch.cu"