@@ -35,10 +35,14 @@ __global__ void
3535 // Each thread is responsible for TM entries, so the "width" is divided by TM
3636 const auto thread_row = (threadIdx .x / (TILE_WIDTH / TM));
3737 const auto thread_col = (threadIdx .x % (TILE_WIDTH / TM));
38+ constexpr int THREAD_COLS = TILE_WIDTH / TM;
3839
3940 // Shared buffer for current tile block of A and B
40- __shared__ T A_block[TILE_HEIGHT * TILE_STRIDE]; // NOLINT(*-c-arrays)
41- __shared__ T B_block[TILE_STRIDE * TILE_WIDTH]; // NOLINT(*-c-arrays)
41+ constexpr int A_STRIDE = TILE_STRIDE + 1 ;
42+ constexpr int B_SKEW = TILE_WIDTH / 32 ;
43+ constexpr int B_STRIDE = TILE_WIDTH + B_SKEW;
44+ __shared__ T A_block[TILE_HEIGHT * A_STRIDE]; // NOLINT(*-c-arrays)
45+ __shared__ T B_block[TILE_STRIDE * B_STRIDE]; // NOLINT(*-c-arrays)
4246
4347 // starting row and column of C we will write into
4448 const auto c_row = blockIdx .y * TILE_HEIGHT;
@@ -85,7 +89,7 @@ __global__ void
8589 const bool in_range = row < N && col < K;
8690 const auto idx_A = static_cast <int >((blockIdx .z * N * K) + row * K + col);
8791 // NOLINTNEXTLINE(*-array-index)
88- A_block[local_row * TILE_STRIDE + innerColA] = in_range ? A[idx_A] : 0 ;
92+ A_block[local_row * A_STRIDE + innerColA] = in_range ? A[idx_A] : 0 ;
8993 }
9095 // We load a block of size [TILE_STRIDE x TILE_WIDTH] for B, stepping by strideB rows
9195 for (int load_offset = 0 ; load_offset < TILE_STRIDE; load_offset += strideB) {
@@ -94,8 +98,9 @@ __global__ void
9498 const auto col = static_cast <int >(b_col + innerColB);
9599 const bool in_range = row < K && col < M;
96100 const auto idx_B = static_cast <int >((blockIdx .z * K * M) + row * M + col);
101+ const auto skewedColB = innerColB + (innerColB / 32 );
97102 // NOLINTNEXTLINE(*-array-index)
98- B_block[local_row * TILE_WIDTH + innerColB ] = in_range ? B[idx_B] : 0 ;
103+ B_block[local_row * B_STRIDE + skewedColB ] = in_range ? B[idx_B] : 0 ;
99104 }
100105
101106 // Wait for all threads to load data into the cache
@@ -111,11 +116,13 @@ __global__ void
111116 // Load TN + TM results into registers first
112117 for (int i = 0 ; i < TN; ++i) {
113118 const auto row_idx = thread_row * TN + i;
114- cached_A[i] = A_block[row_idx * TILE_STRIDE + dot_idx]; // NOLINT(*-array-index)
119+ cached_A[i] = A_block[row_idx * A_STRIDE + dot_idx]; // NOLINT(*-array-index)
115120 }
116121 for (int i = 0 ; i < TM; ++i) {
122+ const auto col_idx = thread_col + i * THREAD_COLS;
123+ const auto skewedColRead = col_idx + (col_idx / 32 );
117124 // NOLINTNEXTLINE(*-array-index)
118- cached_B[i] = B_block[dot_idx * TILE_WIDTH + thread_col * TM + i ];
125+ cached_B[i] = B_block[dot_idx * B_STRIDE + skewedColRead ];
119126 }
120127
121128 // Compute TN * TM results using the cached results
@@ -134,7 +141,7 @@ __global__ void
134141 for (int i = 0 ; i < TN; ++i) {
135142 for (int j = 0 ; j < TM; ++j) {
136143 const auto row = static_cast <int >(c_row + (thread_row * TN + i));
137- const auto col = static_cast <int >(c_col + (thread_col * TM + j));
144+ const auto col = static_cast <int >(c_col + (thread_col + j * THREAD_COLS ));
138145 if (row < N && col < M) {
139146 const auto idx_C = static_cast <int >((blockIdx .z * N * M) + row * M + col);
140147 C[idx_C] = results[i * TM + j]; // NOLINT(*-array-index)
0 commit comments