Skip to content

Commit bcd98fb

Browse files
committed
CUDA matmul: increase perf by 30%
1 parent 30814f4 commit bcd98fb

2 files changed

Lines changed: 15 additions & 8 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cmake_minimum_required(VERSION 3.25)
22

3-
project(tinytensor VERSION 1.0.4 LANGUAGES CXX)
3+
project(tinytensor VERSION 1.0.5 LANGUAGES CXX)
44

55
# Build options
66
option(TT_BUILD_CUDA "Build tinytensor with cuda backend support" OFF)

tinytensor/tensor/backend/cuda/kernel/matmul.cuh

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,14 @@ __global__ void
3535
// Each thread is responsible for TM entries, so the "width" is divided by TM
3636
const auto thread_row = (threadIdx.x / (TILE_WIDTH / TM));
3737
const auto thread_col = (threadIdx.x % (TILE_WIDTH / TM));
38+
constexpr int THREAD_COLS = TILE_WIDTH / TM;
3839

3940
// Shared buffer for current tile block of A and B
40-
__shared__ T A_block[TILE_HEIGHT * TILE_STRIDE]; // NOLINT(*-c-arrays)
41-
__shared__ T B_block[TILE_STRIDE * TILE_WIDTH]; // NOLINT(*-c-arrays)
41+
constexpr int A_STRIDE = TILE_STRIDE + 1;
42+
constexpr int B_SKEW = TILE_WIDTH / 32;
43+
constexpr int B_STRIDE = TILE_WIDTH + B_SKEW;
44+
__shared__ T A_block[TILE_HEIGHT * A_STRIDE]; // NOLINT(*-c-arrays)
45+
__shared__ T B_block[TILE_STRIDE * B_STRIDE]; // NOLINT(*-c-arrays)
4246

4347
// starting row and column of C we will write into
4448
const auto c_row = blockIdx.y * TILE_HEIGHT;
@@ -85,7 +89,7 @@ __global__ void
8589
const bool in_range = row < N && col < K;
8690
const auto idx_A = static_cast<int>((blockIdx.z * N * K) + row * K + col);
8791
// NOLINTNEXTLINE(*-array-index)
88-
A_block[local_row * TILE_STRIDE + innerColA] = in_range ? A[idx_A] : 0;
92+
A_block[local_row * A_STRIDE + innerColA] = in_range ? A[idx_A] : 0;
8993
}
9094
// We load a block of size [stride A x TILE_STRIDE] for A
9195
for (int load_offset = 0; load_offset < TILE_STRIDE; load_offset += strideB) {
@@ -94,8 +98,9 @@ __global__ void
9498
const auto col = static_cast<int>(b_col + innerColB);
9599
const bool in_range = row < K && col < M;
96100
const auto idx_B = static_cast<int>((blockIdx.z * K * M) + row * M + col);
101+
const auto skewedColB = innerColB + (innerColB / 32);
97102
// NOLINTNEXTLINE(*-array-index)
98-
B_block[local_row * TILE_WIDTH + innerColB] = in_range ? B[idx_B] : 0;
103+
B_block[local_row * B_STRIDE + skewedColB] = in_range ? B[idx_B] : 0;
99104
}
100105

101106
// Wait for all threads to load data into the cache
@@ -111,11 +116,13 @@ __global__ void
111116
// Load TN + TM results into registers first
112117
for (int i = 0; i < TN; ++i) {
113118
const auto row_idx = thread_row * TN + i;
114-
cached_A[i] = A_block[row_idx * TILE_STRIDE + dot_idx]; // NOLINT(*-array-index)
119+
cached_A[i] = A_block[row_idx * A_STRIDE + dot_idx]; // NOLINT(*-array-index)
115120
}
116121
for (int i = 0; i < TM; ++i) {
122+
const auto col_idx = thread_col + i * THREAD_COLS;
123+
const auto skewedColRead = col_idx + (col_idx / 32);
117124
// NOLINTNEXTLINE(*-array-index)
118-
cached_B[i] = B_block[dot_idx * TILE_WIDTH + thread_col * TM + i];
125+
cached_B[i] = B_block[dot_idx * B_STRIDE + skewedColRead];
119126
}
120127

121128
// Compute TN * TM results using the cached results
@@ -134,7 +141,7 @@ __global__ void
134141
for (int i = 0; i < TN; ++i) {
135142
for (int j = 0; j < TM; ++j) {
136143
const auto row = static_cast<int>(c_row + (thread_row * TN + i));
137-
const auto col = static_cast<int>(c_col + (thread_col * TM + j));
144+
const auto col = static_cast<int>(c_col + (thread_col + j * THREAD_COLS));
138145
if (row < N && col < M) {
139146
const auto idx_C = static_cast<int>((blockIdx.z * N * M) + row * M + col);
140147
C[idx_C] = results[i * TM + j]; // NOLINT(*-array-index)

0 commit comments

Comments
 (0)