-
Notifications
You must be signed in to change notification settings - Fork 278
Expand file tree
/
Copy pathlevel_08.cu
More file actions
158 lines (129 loc) · 5.31 KB
/
level_08.cu
File metadata and controls
158 lines (129 loc) · 5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <chrono>
#include <cuda_runtime.h>
#include "kittens.cuh"
using namespace kittens;
// Tiling configuration: each thread block computes an M_BLOCK x N_BLOCK grid
// of BLOCK_SIZE x BLOCK_SIZE output tiles.
constexpr int BLOCK_SIZE = 64;  // side length (elements) of one shared-memory tile
constexpr int M_BLOCK = 2; // Number of consumer warp groups
constexpr int N_BLOCK = 4; // Number of output tiles per row
// One producer warp group (4 warps) feeds M_BLOCK consumer warp groups.
static constexpr int NUM_PRODUCER_WORKERS = (4);
static constexpr int NUM_CONSUMER_WORKERS = (M_BLOCK*4);
static constexpr int NUM_THREADS = ((NUM_PRODUCER_WORKERS+NUM_CONSUMER_WORKERS)*kittens::WARP_THREADS);
// Kernel parameter bundle, passed by value as a __grid_constant__.
struct matmul_globals {
    // One BLOCK_SIZE x BLOCK_SIZE bf16 shared-memory tile type; presumably
    // also fixes the TMA tile shape for the gl descriptors — verify against
    // the kittens `gl` documentation.
    using sub_tile = st_bf<BLOCK_SIZE,BLOCK_SIZE>;
    // Global-memory layout over bf16 data: two fixed unit dims, then rows and
    // cols (-1, -1) supplied at runtime (set to N x N by the host launcher).
    using tile_gl = gl<bf16, 1, 1, -1, -1, sub_tile>;
    tile_gl A, B, C;  // A, B: input matrices; C: output matrix
    int N;            // matrix dimension (square N x N)
};
// Producer/consumer bf16 matmul kernel (ThunderKittens).
//
// Warp-group layout (1 producer group + M_BLOCK consumer groups, 4 warps each):
//   - warp group 0:           producer; issues TMA loads of A and B tiles
//   - warp groups 1..M_BLOCK: consumers; each owns one row of output tiles
// Input tiles are double-buffered (tic/toc) in shared memory so the producer
// can prefetch the next K slab while consumers multiply the current one.
// Requires the dynamic shared-memory opt-in performed by the host launcher.
__global__ void kernel(const __grid_constant__ matmul_globals g) {
    extern __shared__ alignment_dummy __shm[];
    shared_allocator al((int*)&__shm[0]);
    // Double-buffered input tiles plus output staging tiles, all 64x64 bf16.
    st_bf<BLOCK_SIZE,BLOCK_SIZE> (&As)[2][M_BLOCK] = al.allocate<st_bf<BLOCK_SIZE,BLOCK_SIZE>, 2, M_BLOCK>();
    st_bf<BLOCK_SIZE,BLOCK_SIZE> (&Bs)[2][N_BLOCK] = al.allocate<st_bf<BLOCK_SIZE,BLOCK_SIZE>, 2, N_BLOCK>();
    st_bf<BLOCK_SIZE,BLOCK_SIZE> (&C_tiles)[M_BLOCK][N_BLOCK] = al.allocate<st_bf<BLOCK_SIZE,BLOCK_SIZE>, M_BLOCK, N_BLOCK>();
    int tic = 0;  // index of the buffer being consumed this iteration
    int toc = 1;  // index of the buffer being filled by the producer
    // Accumulator for each consumer warp group
    // Each consumer multiplies its A tile against all N_BLOCK B tiles at once,
    // viewed as one BLOCK_SIZE x (BLOCK_SIZE*N_BLOCK) wide tile.
    using wide_tile = st_bf<BLOCK_SIZE, BLOCK_SIZE*N_BLOCK>;
    rt_fl<16, BLOCK_SIZE*N_BLOCK> C_accum;  // float register-tile accumulator
    // Tile-row/column origin of this thread block's output region.
    int row = blockIdx.y * M_BLOCK;
    int col = blockIdx.x * N_BLOCK;
    const int warpid = kittens::warpid();
    const int warpgroupid = warpid/4;  // 4 warps per warp group
    // Determine type of warp group
    bool is_producer = (warpgroupid == 0);
    bool is_consumer = (warpgroupid > 0 && warpgroupid <= M_BLOCK);
    // Consumer index (0-based) for consumer warp groups
    int consumer_idx = is_consumer ? (warpgroupid - 1) : 0;
    // Single semaphore guarding both buffers; the wait() in the loop keys off
    // the alternating phase bit (tic), matching the ping-pong load pattern.
    __shared__ semaphore bar;
    if (threadIdx.x == 0) {
        init_semaphore(bar, 0, 1);
        // Semaphore trips once a full set of A+B tiles has landed in smem.
        tma::expect_bytes(
            bar,
            M_BLOCK * size_bytes<typeof(As[0][0])> +
            N_BLOCK * size_bytes<typeof(Bs[0][0])>
        );
        // Load initial A tiles (one row per consumer)
        for (int m = 0; m < M_BLOCK; m++) {
            tma::load_async(As[tic][m], g.A, {0, 0, row + m, 0}, bar);
        }
        // Load initial B tiles (all columns for this thread block)
        for (int n = 0; n < N_BLOCK; n++) {
            tma::load_async(Bs[tic][n], g.B, {0, 0, 0, col + n}, bar);
        }
    }
    __syncthreads();
    if (is_consumer) {
        kittens::warp::zero(C_accum);
    }
    // March along the K dimension one BLOCK_SIZE slab at a time, flipping
    // tic/toc each iteration.
    int num_tiles = (g.N + BLOCK_SIZE - 1) / BLOCK_SIZE;
    for (int tile = 0; tile < num_tiles; ++tile, tic^=1, toc^=1) {
        wait(bar, tic);  // block until the tic buffer's TMA loads complete
        __syncthreads();
        if (is_producer) {
            // Producer does no math; donate registers to the consumers.
            warpgroup::decrease_registers<40>();
            // Prefetch slab tile+1 into the toc buffer while consumers work
            // on tic (skipped on the final iteration).
            if (threadIdx.x == 0 && tile+1 < num_tiles) {
                tma::expect_bytes(bar,
                    M_BLOCK * size_bytes<typeof(As[0][0])> +
                    N_BLOCK * size_bytes<typeof(Bs[0][0])>
                );
                for (int m = 0; m < M_BLOCK; m++) {
                    tma::load_async(As[toc][m], g.A, {0, 0, row + m, tile+1}, bar);
                }
                for (int n = 0; n < N_BLOCK; n++) {
                    tma::load_async(Bs[toc][n], g.B, {0, 0, tile+1, col + n}, bar);
                }
            }
        }
        else if (is_consumer) {
            warpgroup::increase_registers<232>();
            // Each consumer processes its assigned row of A against all columns of B
            // NOTE(review): the reinterpret_cast assumes the N_BLOCK B tiles
            // are contiguous in shared memory — confirm against the kittens
            // shared_allocator layout guarantees.
            warpgroup::mma_AB(
                C_accum,
                As[tic][consumer_idx], // Get this consumer's A tile
                reinterpret_cast<wide_tile&>(Bs[tic][0]) // Get all B tiles as a wide tile
            );
            warpgroup::mma_async_wait();
        }
        // All warp groups rendezvous before tic/toc flip, so the buffer just
        // consumed is safe for the producer to overwrite next iteration.
        __syncthreads();
    }
    // Store
    if (is_consumer) {
        // First store the wide result to temporary tiles
        wide_tile& wide_C_temp = reinterpret_cast<wide_tile&>(C_tiles[consumer_idx][0]);
        warpgroup::store(wide_C_temp, C_accum);
        // Named barrier per consumer group; +4 presumably keeps the id
        // distinct from other barrier ids in use — verify against kittens.
        warpgroup::sync(warpgroupid+4);
        // Only first warp in each consumer group stores to global memory
        if (warpid % 4 == 0) {
            for (int n = 0; n < N_BLOCK; n++) {
                tma::store_async(g.C, C_tiles[consumer_idx][n], {0, 0, row + consumer_idx, col + n});
                tma::store_async_read_wait();  // wait until the tile's smem may be reused
            }
        }
    }
}
// Launch kernel
// Host-side launcher: wraps the device pointers in global-layout descriptors,
// sizes the grid, opts the kernel into >48 KB of dynamic shared memory,
// launches, and synchronizes.
//
// A, B, C: device pointers to N x N bf16 matrices (C is written).
// N:       matrix dimension; each thread block covers BLOCK_SIZE*N_BLOCK
//          columns by BLOCK_SIZE*M_BLOCK rows (ceil-division for the edges).
//
// Fix vs. original: cudaFuncSetAttribute and both cudaDeviceSynchronize calls
// were unchecked; a failed shared-memory opt-in or a prior sticky error would
// surface only as a mysterious launch failure. All runtime calls now go
// through CHECK_CUDA_ERROR, consistent with the existing launch check.
void matmul(bf16* A, bf16* B, bf16* C, size_t N) {
    // Global pointers: the two -1 dims of tile_gl are supplied here as N x N.
    using tile_gl = matmul_globals::tile_gl;
    tile_gl a_arg{A, nullptr, nullptr, N, N};
    tile_gl b_arg{B, nullptr, nullptr, N, N};
    tile_gl c_arg{C, nullptr, nullptr, N, N};
    matmul_globals g{a_arg, b_arg, c_arg, (int)N};

    // One thread block produces an M_BLOCK x N_BLOCK grid of output tiles.
    int NEW_ROW_BLOCK_SIZE = BLOCK_SIZE*M_BLOCK;
    int NEW_COL_BLOCK_SIZE = BLOCK_SIZE*N_BLOCK;
    dim3 blocks(
        (N + NEW_COL_BLOCK_SIZE - 1) / (NEW_COL_BLOCK_SIZE),
        (N + NEW_ROW_BLOCK_SIZE - 1) / (NEW_ROW_BLOCK_SIZE)
    );

    // Dynamic shared memory: 2*M_BLOCK A tiles + 2*N_BLOCK B tiles +
    // M_BLOCK*N_BLOCK C staging tiles at 64*64*2 = 8 KB each (~160 KB);
    // 200000 leaves slack for allocator alignment. This exceeds the default
    // 48 KB per-block limit, hence the opt-in below.
    unsigned long mem_size = 200000;

    // Drain prior async work so any earlier sticky error is attributed there,
    // not to this launch.
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());
    CHECK_CUDA_ERROR(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size));
    kernel<<<blocks, NUM_THREADS, mem_size>>>(g);
    CHECK_CUDA_ERROR(cudaGetLastError());       // launch-configuration errors
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());  // surface async kernel faults
}
#include "launch.cu"