forked from xlite-dev/LeetCUDA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmat_transpose_cute.cu
More file actions
536 lines (474 loc) · 21.5 KB
/
mat_transpose_cute.cu
File metadata and controls
536 lines (474 loc) · 21.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
#include <cuda_runtime.h>
#include <stdio.h>
#include <torch/extension.h>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>
using namespace cute;
// Base tile edge length. Most host wrappers in this file derive their BM/BN
// tile sizes from it (either directly or multiplied by 4).
#define UNIT_BLK_SIZE 16
// Evaluates a CUDA runtime expression exactly once. If the resulting status is
// not cudaSuccess, prints the file, line and the runtime's error string to
// stderr and terminates the process with EXIT_FAILURE.
//
// The argument is parenthesised and the status lives in `cuda_check_err_` so
// that compound expressions expand correctly and a caller variable named `err`
// appearing inside the checked expression is not shadowed by the macro's own
// local.
#define CUDA_CHECK(call)                                                       \
  do {                                                                         \
    cudaError_t cuda_check_err_ = (call);                                      \
    if (cuda_check_err_ != cudaSuccess) {                                      \
      fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__,         \
              cudaGetErrorString(cuda_check_err_));                            \
      /* Optionally, you could also call cudaDeviceReset here */               \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)
// Tiled transpose that moves elements straight from A to B, with no
// shared-memory staging.
//
// pA is viewed as a row-major (M, N) matrix and pB as a row-major (N, M)
// matrix. Block (blockIdx.x, blockIdx.y) = (bm, bn) handles the
// (BLK_M, BLK_N) tile of A at tile coordinate (bm, bn) and the (BLK_N, BLK_M)
// tile of B at tile coordinate (bn, bm).
//
// tA / tB are the thread layouts used to split the A tile and the B tile among
// the threads of a 1-D block (indexed by threadIdx.x). They must pair each
// thread's A elements with the transposed positions in B; the kernel does not
// verify this.
//
// An element is copied only when its global A coordinate lies inside (M, N),
// which protects partial tiles at the matrix edges.
template <typename T, int BLK_M, int BLK_N, typename ThreadLayoutA,
          typename ThreadLayoutB>
__global__ void mat_transpose_cute_reg_kernel(const T *pA, T *pB, int M, int N,
                                              ThreadLayoutA tA,
                                              ThreadLayoutB tB) {
  const int tid = threadIdx.x;
  const int blk_m = blockIdx.x;
  const int blk_n = blockIdx.y;

  // Whole-matrix views over global memory.
  auto A_full = make_tensor(make_gmem_ptr(pA),
                            make_layout(make_shape(M, N), GenRowMajor{})); // (M, N)
  auto B_full = make_tensor(make_gmem_ptr(pB),
                            make_layout(make_shape(N, M), GenRowMajor{})); // (N, M)

  // The tiles owned by this block, plus a coordinate tile mirroring the A tile.
  auto A_tile = local_tile(A_full, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                           make_coord(blk_m, blk_n)); // (BM, BN)
  auto B_tile = local_tile(B_full, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                           make_coord(blk_n, blk_m)); // (BN, BM)
  auto A_coords = local_tile(make_identity_tensor(A_full.shape()),
                             make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                             make_coord(blk_m, blk_n)); // (BM, BN)

  // Per-thread slices.
  Tensor thr_A = local_partition(A_tile, tA, tid);
  Tensor thr_B = local_partition(B_tile, tB, tid);
  Tensor thr_A_coords = local_partition(A_coords, tA, tid);

  // Predicate tensor: true where the element's global coordinate is in range.
  Tensor in_range =
      make_tensor<bool>(thr_A_coords.shape(), thr_A_coords.stride());
  CUTE_UNROLL
  for (int r = 0; r < size<0>(in_range); ++r) {
    CUTE_UNROLL
    for (int c = 0; c < size<1>(in_range); ++c) {
      const auto coord = thr_A_coords(r, c);
      in_range(r, c) = (get<0>(coord) < M) && (get<1>(coord) < N);
    }
  }

  copy_if(in_range, thr_A, thr_B);
}
// Launches mat_transpose_cute_reg_kernel on float data with a
// UNIT_BLK_SIZE x UNIT_BLK_SIZE tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Thread layouts: column-major over the A tile and row-major over the B tile,
// i.e. consecutive thread ids walk down a column of the A tile and along a row
// of the B tile.
void mat_transpose_cute_row2col_reg(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});
  static_assert(size(tA) == size(tB));

  // One thread per tile element; enough blocks to cover the matrix, rounding
  // up so partial edge tiles are included.
  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB);
  CUDA_CHECK(cudaGetLastError());
}
// Launches mat_transpose_cute_reg_kernel on float data with a
// UNIT_BLK_SIZE x UNIT_BLK_SIZE tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Thread layouts: row-major over the A tile and column-major over the B tile,
// i.e. consecutive thread ids walk along a row of the A tile and down a column
// of the B tile.
void mat_transpose_cute_col2row_reg(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tA) == size(tB));

  // One thread per tile element; enough blocks to cover the matrix, rounding
  // up so partial edge tiles are included.
  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB);
  CUDA_CHECK(cudaGetLastError());
}
// Tiled transpose staged through a statically sized shared-memory buffer of
// BLK_M * BLK_N elements.
//
// pA is viewed as a row-major (M, N) matrix and pB as a row-major (N, M)
// matrix. Block (blockIdx.x, blockIdx.y) = (bm, bn) handles the
// (BLK_M, BLK_N) tile of A at (bm, bn) and the (BLK_N, BLK_M) tile of B at
// (bn, bm).
//
// The same shared buffer is viewed twice: through sA_layout as a (BM, BN) tile
// for the load phase and through sB_layout as a (BN, BM) tile for the store
// phase. tA / tB are the thread layouts that partition the A-side and B-side
// tensors among the threads of a 1-D block.
//
// Both phases are predicated on global coordinates so partial edge tiles do
// not touch memory outside the matrices.
template <typename T, int BLK_M, int BLK_N, typename ThreadLayoutA,
          typename ThreadLayoutB, typename SmemLayoutA, typename SmemLayoutB>
__global__ void
mat_transpose_cute_smem_kernel(const T *pA, T *pB, int M, int N,
                               ThreadLayoutA tA, ThreadLayoutB tB,
                               SmemLayoutA sA_layout, SmemLayoutB sB_layout) {
  const int tid = threadIdx.x;
  const int blk_m = blockIdx.x;
  const int blk_n = blockIdx.y;

  // Whole-matrix views over global memory.
  auto A_full = make_tensor(make_gmem_ptr(pA),
                            make_layout(make_shape(M, N), GenRowMajor{})); // (M, N)
  auto B_full = make_tensor(make_gmem_ptr(pB),
                            make_layout(make_shape(N, M), GenRowMajor{})); // (N, M)

  // Data tiles and matching coordinate tiles for this block.
  auto A_tile = local_tile(A_full, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                           make_coord(blk_m, blk_n)); // (BM, BN)
  auto B_tile = local_tile(B_full, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                           make_coord(blk_n, blk_m)); // (BN, BM)
  auto A_coords = local_tile(make_identity_tensor(A_full.shape()),
                             make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                             make_coord(blk_m, blk_n)); // (BM, BN)
  auto B_coords = local_tile(make_identity_tensor(B_full.shape()),
                             make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                             make_coord(blk_n, blk_m)); // (BN, BM)

  // One buffer, two views.
  __shared__ T staging[BLK_M * BLK_N];
  auto smem_as_A = make_tensor(make_smem_ptr(staging), sA_layout); // (BM, BN)
  auto smem_as_B = make_tensor(make_smem_ptr(staging), sB_layout); // (BN, BM)

  // Per-thread slices of every tensor.
  Tensor thr_gA = local_partition(A_tile, tA, tid);
  Tensor thr_gB = local_partition(B_tile, tB, tid);
  Tensor thr_sA = local_partition(smem_as_A, tA, tid);
  Tensor thr_sB = local_partition(smem_as_B, tB, tid);
  Tensor thr_cA = local_partition(A_coords, tA, tid);
  Tensor thr_cB = local_partition(B_coords, tB, tid);

  Tensor load_ok = make_tensor<bool>(thr_cA.shape(), thr_cA.stride());
  Tensor store_ok = make_tensor<bool>(thr_cB.shape(), thr_cB.stride());

  // Load predicate: A coordinate must lie inside (M, N).
  CUTE_UNROLL
  for (int r = 0; r < size<0>(load_ok); ++r) {
    CUTE_UNROLL
    for (int c = 0; c < size<1>(load_ok); ++c) {
      const auto coord = thr_cA(r, c);
      load_ok(r, c) = (get<0>(coord) < M) && (get<1>(coord) < N);
    }
  }

  // Store predicate: B coordinate must lie inside (N, M).
  CUTE_UNROLL
  for (int r = 0; r < size<0>(store_ok); ++r) {
    CUTE_UNROLL
    for (int c = 0; c < size<1>(store_ok); ++c) {
      const auto coord = thr_cB(r, c);
      store_ok(r, c) = (get<0>(coord) < N) && (get<1>(coord) < M);
    }
  }

  copy_if(load_ok, thr_gA, thr_sA);
  // Every thread must finish writing the tile before any thread reads it back
  // through the B-side view.
  __syncthreads();
  copy_if(store_ok, thr_sB, thr_gB);
}
// Base-2 logarithm for exact powers of two, usable in constant expressions.
//
// Returns k such that x == (1 << k). Any other argument (zero, negative, or
// not a power of two) reaches the `throw`: a constant-evaluated call then
// fails to compile, and a run-time call raises a `const char *`.
//
// The positivity test is part of the condition instead of an `assert` so it
// still applies when NDEBUG removes asserts. Otherwise x == 0 would pass the
// bit test and reach __builtin_ctz(0), whose result is undefined.
constexpr int log2(int x) {
  return (x > 0 && (x & (x - 1)) == 0)
             ? __builtin_ctz(x)
             : (throw "x is not a power of 2", 0);
}
// Launches mat_transpose_cute_smem_kernel on float data with a
// UNIT_BLK_SIZE x UNIT_BLK_SIZE tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Both thread layouts are column-major. The shared tile is viewed row-major as
// (BM, BN) for the load and column-major as (BN, BM) for the store.
void mat_transpose_cute_col_smem(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  // Thread layouts for the A-side and B-side partitions.
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  // Two views of the same shared-memory tile.
  auto sA_layout = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto sB_layout = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tA) == size(tB));

  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_smem_kernel<float, BM, BN, decltype(tA), decltype(tB),
                                 decltype(sA_layout), decltype(sB_layout)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Launches mat_transpose_cute_smem_kernel on float data with a
// UNIT_BLK_SIZE x UNIT_BLK_SIZE tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Both thread layouts are row-major. The shared tile is viewed row-major as
// (BM, BN) for the load and column-major as (BN, BM) for the store.
void mat_transpose_cute_row_smem(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  // Thread layouts for the A-side and B-side partitions.
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});
  // Two views of the same shared-memory tile.
  auto sA_layout = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto sB_layout = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tA) == size(tB));

  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_smem_kernel<float, BM, BN, decltype(tA), decltype(tB),
                                 decltype(sA_layout), decltype(sB_layout)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Same launch as the column-major shared-memory variant (both thread layouts
// column-major, UNIT_BLK_SIZE x UNIT_BLK_SIZE tile), except that both
// shared-memory views are composed with Swizzle<S, 0, S>, S = log2(BM).
// NOTE(review): the swizzle is presumably there to reduce shared-memory bank
// conflicts; nothing in this file measures that.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
void mat_transpose_cute_col_smem_swizzled(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  // Thread layouts for the A-side and B-side partitions.
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});

  // Swizzled views of the shared tile.
  constexpr int S = log2(BM);
  auto swz = Swizzle<S, 0, S>{};
  auto sA_layout = composition(
      swz, make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{}));
  auto sB_layout = composition(
      swz, make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{}));
  static_assert(size(tA) == size(tB));

  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_smem_kernel<float, BM, BN, decltype(tA), decltype(tB),
                                 decltype(sA_layout), decltype(sB_layout)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Same launch as the row-major shared-memory variant (both thread layouts
// row-major, UNIT_BLK_SIZE x UNIT_BLK_SIZE tile), except that both
// shared-memory views are composed with Swizzle<S, 0, S>, S = log2(BM).
// NOTE(review): the swizzle is presumably there to reduce shared-memory bank
// conflicts; nothing in this file measures that.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
void mat_transpose_cute_row_smem_swizzled(torch::Tensor x, torch::Tensor y) {
  constexpr int BM = UNIT_BLK_SIZE;
  constexpr int BN = UNIT_BLK_SIZE;
  const int M = x.size(0);
  const int N = x.size(1);

  // Thread layouts for the A-side and B-side partitions.
  auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});

  // Swizzled views of the shared tile.
  constexpr int S = log2(BM);
  auto swz = Swizzle<S, 0, S>{};
  auto sA_layout = composition(
      swz, make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{}));
  auto sB_layout = composition(
      swz, make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{}));
  static_assert(size(tA) == size(tB));

  const dim3 threads(size(tA));
  const dim3 blocks((M + BM - 1) / BM, (N + BN - 1) / BN);

  mat_transpose_cute_smem_kernel<float, BM, BN, decltype(tA), decltype(tB),
                                 decltype(sA_layout), decltype(sB_layout)>
      <<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA,
                            tB, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Returns true when `ptr` is aligned to 16 bytes (128 bits), i.e. its low four
// address bits are zero. A null pointer counts as aligned.
__host__ __device__ inline bool is_aligned_128(const void *ptr) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  return address % 16 == 0;
}
// Tiled transpose staged through shared memory, driven by two tiled-copy
// objects instead of plain thread layouts.
//
// pA is viewed as a row-major (M, N) matrix and pB as a row-major (N, M)
// matrix. Block (blockIdx.x, blockIdx.y) = (bm, bn) handles the
// (BLK_M, BLK_N) tile of A at (bm, bn) and the (BLK_N, BLK_M) tile of B at
// (bn, bm). copy_a moves the A tile from global to shared memory (viewed
// through sA_layout); copy_b moves the same buffer (viewed through sB_layout)
// to the B tile in global memory.
//
// There is no predication here: every block copies a full tile, so the caller
// must guarantee that M is a multiple of BLK_M and N a multiple of BLK_N.
template <typename T, int BLK_M, int BLK_N, typename TiledCopyA,
          typename TiledCopyB, typename SmemLayoutA, typename SmemLayoutB>
__global__ void mat_transpose_cute_smem_vectorized_kernel(
    const T *pA, T *pB, int M, int N, TiledCopyA copy_a, TiledCopyB copy_b,
    SmemLayoutA sA_layout, SmemLayoutB sB_layout) {
  const int tid = threadIdx.x;
  const int blk_m = blockIdx.x;
  const int blk_n = blockIdx.y;

  // Whole-matrix views over global memory.
  auto A_full = make_tensor(make_gmem_ptr(pA),
                            make_layout(make_shape(M, N), GenRowMajor{})); // (M, N)
  auto B_full = make_tensor(make_gmem_ptr(pB),
                            make_layout(make_shape(N, M), GenRowMajor{})); // (N, M)

  // This block's tiles.
  auto A_tile = local_tile(A_full, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                           make_coord(blk_m, blk_n)); // (BM, BN)
  auto B_tile = local_tile(B_full, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                           make_coord(blk_n, blk_m)); // (BN, BM)

  // One shared buffer, viewed once per phase.
  __shared__ T staging[BLK_M * BLK_N];
  auto smem_as_A = make_tensor(make_smem_ptr(staging), sA_layout); // (BM, BN)
  auto smem_as_B = make_tensor(make_smem_ptr(staging), sB_layout); // (BN, BM)

  // Load phase partitions: global A -> shared.
  auto loader = copy_a.get_slice(tid);
  Tensor load_src = loader.partition_S(A_tile);
  Tensor load_dst = loader.partition_D(smem_as_A);

  // Store phase partitions: shared -> global B.
  auto storer = copy_b.get_slice(tid);
  Tensor store_src = storer.partition_S(smem_as_B);
  Tensor store_dst = storer.partition_D(B_tile);

  copy(copy_a, load_src, load_dst);
  // The whole tile must be in shared memory before it is read back through the
  // B-side partition.
  __syncthreads();
  copy(copy_b, store_src, store_dst);
}
// Launches mat_transpose_cute_smem_vectorized_kernel on float data with a
// (BM, BN) = (4 * UNIT_BLK_SIZE, UNIT_BLK_SIZE) tile.
//
// tile_copy_a gives each thread a 4x1 patch of the A tile and tile_copy_b a
// 1x4 patch of the B tile, so each thread stores four consecutive floats of a
// row of y.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Raises (via TORCH_CHECK) when M or N is not a whole number of tiles or when
// either pointer is not 16-byte aligned: the kernel copies full tiles without
// bounds predication, so a partial tile would access memory outside x and y.
void mat_transpose_cute_row_cvectorized(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE * 4;
  const int BN = UNIT_BLK_SIZE;
  auto ptr_A = x.data_ptr<float>();
  auto ptr_B = y.data_ptr<float>();
  const int M = x.size(0);
  const int N = x.size(1);
  // sanity checks
  static_assert(BM % 4 == 0);
  static_assert(BN % 4 == 0);
  // Tile divisibility also implies M % 4 == 0 and N % 4 == 0, which keeps each
  // row start 16-byte aligned once the base pointers are.
  TORCH_CHECK(M % BM == 0 && N % BN == 0,
              "mat_transpose_cute_row_cvectorized: M (", M,
              ") must be a multiple of ", BM, " and N (", N,
              ") must be a multiple of ", BN);
  TORCH_CHECK(is_aligned_128(ptr_A) && is_aligned_128(ptr_B),
              "mat_transpose_cute_row_cvectorized: x and y must be 16-byte "
              "aligned");
  auto tile_copy_a = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BM / 4>{}, Int<BN>{}), GenRowMajor{}),
      make_layout(make_shape(Int<4>{}, Int<1>{}), GenRowMajor{}));
  auto tile_copy_b = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BN>{}, Int<BM / 4>{}), GenRowMajor{}),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  // Two views of the same shared-memory tile.
  auto sA_layout = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto sB_layout = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tile_copy_a) == size(tile_copy_b));
  dim3 block(size(tile_copy_a));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_smem_vectorized_kernel<
      float, BM, BN, decltype(tile_copy_a), decltype(tile_copy_b),
      decltype(sA_layout), decltype(sB_layout)><<<grid, block>>>(
      ptr_A, ptr_B, M, N, tile_copy_a, tile_copy_b, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Launches mat_transpose_cute_smem_vectorized_kernel on float data with a
// (BM, BN) = (4 * UNIT_BLK_SIZE, UNIT_BLK_SIZE) tile and shared-memory views
// composed with Swizzle<S, 0, S>, S = log2(BN).
//
// tile_copy_a gives each thread a 4x1 patch of the A tile and tile_copy_b a
// 1x4 patch of the B tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Raises (via TORCH_CHECK) when M or N is not a whole number of tiles or when
// either pointer is not 16-byte aligned: the kernel copies full tiles without
// bounds predication, so a partial tile would access memory outside x and y.
void mat_transpose_cute_row_cvectorized_swizzled(torch::Tensor x,
                                                 torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE * 4;
  const int BN = UNIT_BLK_SIZE;
  auto ptr_A = x.data_ptr<float>();
  auto ptr_B = y.data_ptr<float>();
  const int M = x.size(0);
  const int N = x.size(1);
  // sanity checks
  static_assert(BM % 4 == 0);
  static_assert(BN % 4 == 0);
  // Tile divisibility also implies M % 4 == 0 and N % 4 == 0, which keeps each
  // row start 16-byte aligned once the base pointers are.
  TORCH_CHECK(M % BM == 0 && N % BN == 0,
              "mat_transpose_cute_row_cvectorized_swizzled: M (", M,
              ") must be a multiple of ", BM, " and N (", N,
              ") must be a multiple of ", BN);
  TORCH_CHECK(is_aligned_128(ptr_A) && is_aligned_128(ptr_B),
              "mat_transpose_cute_row_cvectorized_swizzled: x and y must be "
              "16-byte aligned");
  auto tile_copy_a = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BM / 4>{}, Int<BN>{}), GenRowMajor{}),
      make_layout(make_shape(Int<4>{}, Int<1>{}), GenRowMajor{}));
  auto tile_copy_b = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BN>{}, Int<BM / 4>{}), GenRowMajor{}),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  // Swizzled views of the same shared-memory tile.
  const int S = log2(BN);
  auto swizzle_func = Swizzle<S, 0, S>{};
  auto sA_layout =
      composition(swizzle_func,
                  make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{}));
  auto sB_layout =
      composition(swizzle_func,
                  make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{}));
  static_assert(size(tile_copy_a) == size(tile_copy_b));
  dim3 block(size(tile_copy_a));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_smem_vectorized_kernel<
      float, BM, BN, decltype(tile_copy_a), decltype(tile_copy_b),
      decltype(sA_layout), decltype(sB_layout)><<<grid, block>>>(
      ptr_A, ptr_B, M, N, tile_copy_a, tile_copy_b, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Launches mat_transpose_cute_smem_vectorized_kernel on float data with a
// (BM, BN) = (UNIT_BLK_SIZE, 4 * UNIT_BLK_SIZE) tile.
//
// tile_copy_a gives each thread a 1x4 patch of the A tile, so each thread
// loads four consecutive floats of a row of x; tile_copy_b gives each thread a
// 4x1 patch of the B tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Raises (via TORCH_CHECK) when M or N is not a whole number of tiles or when
// either pointer is not 16-byte aligned: the kernel copies full tiles without
// bounds predication, so a partial tile would access memory outside x and y.
void mat_transpose_cute_row_rvectorized(torch::Tensor x, torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE * 4;
  auto ptr_A = x.data_ptr<float>();
  auto ptr_B = y.data_ptr<float>();
  const int M = x.size(0);
  const int N = x.size(1);
  // sanity checks
  static_assert(BM % 4 == 0);
  static_assert(BN % 4 == 0);
  // Tile divisibility also implies M % 4 == 0 and N % 4 == 0, which keeps each
  // row start 16-byte aligned once the base pointers are.
  TORCH_CHECK(M % BM == 0 && N % BN == 0,
              "mat_transpose_cute_row_rvectorized: M (", M,
              ") must be a multiple of ", BM, " and N (", N,
              ") must be a multiple of ", BN);
  TORCH_CHECK(is_aligned_128(ptr_A) && is_aligned_128(ptr_B),
              "mat_transpose_cute_row_rvectorized: x and y must be 16-byte "
              "aligned");
  auto tile_copy_a = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BM>{}, Int<BN / 4>{}), GenRowMajor{}),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  auto tile_copy_b = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BN / 4>{}, Int<BM>{}), GenRowMajor{}),
      make_layout(make_shape(Int<4>{}, Int<1>{}), GenRowMajor{}));
  // Two views of the same shared-memory tile.
  auto sA_layout = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
  auto sB_layout = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
  static_assert(size(tile_copy_a) == size(tile_copy_b));
  dim3 block(size(tile_copy_a));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_smem_vectorized_kernel<
      float, BM, BN, decltype(tile_copy_a), decltype(tile_copy_b),
      decltype(sA_layout), decltype(sB_layout)><<<grid, block>>>(
      ptr_A, ptr_B, M, N, tile_copy_a, tile_copy_b, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Launches mat_transpose_cute_smem_vectorized_kernel on float data with a
// (BM, BN) = (UNIT_BLK_SIZE, 4 * UNIT_BLK_SIZE) tile and shared-memory views
// composed with Swizzle<S, 0, S>, S = log2(BM).
//
// tile_copy_a gives each thread a 1x4 patch of the A tile and tile_copy_b a
// 4x1 patch of the B tile.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Raises (via TORCH_CHECK) when M or N is not a whole number of tiles or when
// either pointer is not 16-byte aligned: the kernel copies full tiles without
// bounds predication, so a partial tile would access memory outside x and y.
void mat_transpose_cute_row_rvectorized_swizzled(torch::Tensor x,
                                                 torch::Tensor y) {
  const int BM = UNIT_BLK_SIZE;
  const int BN = UNIT_BLK_SIZE * 4;
  auto ptr_A = x.data_ptr<float>();
  auto ptr_B = y.data_ptr<float>();
  const int M = x.size(0);
  const int N = x.size(1);
  // sanity checks
  static_assert(BM % 4 == 0);
  static_assert(BN % 4 == 0);
  // Tile divisibility also implies M % 4 == 0 and N % 4 == 0, which keeps each
  // row start 16-byte aligned once the base pointers are.
  TORCH_CHECK(M % BM == 0 && N % BN == 0,
              "mat_transpose_cute_row_rvectorized_swizzled: M (", M,
              ") must be a multiple of ", BM, " and N (", N,
              ") must be a multiple of ", BN);
  TORCH_CHECK(is_aligned_128(ptr_A) && is_aligned_128(ptr_B),
              "mat_transpose_cute_row_rvectorized_swizzled: x and y must be "
              "16-byte aligned");
  auto tile_copy_a = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BM>{}, Int<BN / 4>{}), GenRowMajor{}),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  auto tile_copy_b = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BN / 4>{}, Int<BM>{}), GenRowMajor{}),
      make_layout(make_shape(Int<4>{}, Int<1>{}), GenRowMajor{}));
  // Swizzled views of the same shared-memory tile.
  const int S = log2(BM);
  auto swizzle_func = Swizzle<S, 0, S>{};
  auto sA_layout =
      composition(swizzle_func,
                  make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{}));
  auto sB_layout =
      composition(swizzle_func,
                  make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{}));
  static_assert(size(tile_copy_a) == size(tile_copy_b));
  dim3 block(size(tile_copy_a));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_smem_vectorized_kernel<
      float, BM, BN, decltype(tile_copy_a), decltype(tile_copy_b),
      decltype(sA_layout), decltype(sB_layout)><<<grid, block>>>(
      ptr_A, ptr_B, M, N, tile_copy_a, tile_copy_b, sA_layout, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}
// Tiled transpose in three copies: global A -> per-thread fragment, fragment
// -> shared memory (already in the B orientation), shared memory -> global B.
//
// pA is viewed as a row-major (M, N) matrix and pB as a row-major (N, M)
// matrix. Block (blockIdx.x, blockIdx.y) = (bm, bn) handles the
// (BLK_M, BLK_N) tile of A at (bm, bn) and the (BLK_N, BLK_M) tile of B at
// (bn, bm). The shared buffer holds BLK_M * BLK_N elements and is viewed
// through sB_layout as a (BN, BM) tile.
//
//   copy_a     : partitions the A tile for the global load.
//   copy_trans : re-tiles the loaded fragment and partitions the shared tile
//                for the fragment -> shared store.
//   copy_b     : partitions the shared tile and the B tile for the final
//                store.
//
// There is no predication: the caller must guarantee that M is a multiple of
// BLK_M and N a multiple of BLK_N.
template <typename T, int BLK_M, int BLK_N, typename TiledCopyA,
          typename TiledCopyTrans, typename TiledCopyB, typename SmemLayoutB>
__global__ void mat_transpose_cute_smem_vectorized_optimized_kernel(
    const T *pA, T *pB, int M, int N, TiledCopyA copy_a,
    TiledCopyTrans copy_trans, TiledCopyB copy_b, SmemLayoutB sB_layout) {
  int tx = threadIdx.x;
  int bx = blockIdx.x, by = blockIdx.y;
  auto mA = make_tensor(make_gmem_ptr(pA),
                        make_layout(make_shape(M, N), GenRowMajor{})); // (M, N)
  auto mB = make_tensor(make_gmem_ptr(pB),
                        make_layout(make_shape(N, M), GenRowMajor{})); // (N, M)
  auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
                       make_coord(bx, by)); // (BM, BN)
  auto gB = local_tile(mB, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
                       make_coord(by, bx)); // (BN, BM)
  __shared__ T smem[BLK_M * BLK_N];
  auto sB = make_tensor(make_smem_ptr(smem),
                        sB_layout); // (BN, BM)

  // Stage 1: load this thread's part of the A tile into a fragment created
  // with make_tensor_like (despite its name, tAsA is not a shared-memory
  // tensor).
  auto thr_copy_a = copy_a.get_slice(tx);
  Tensor tAgA = thr_copy_a.partition_S(gA);
  auto tAsA = make_tensor_like(tAgA);
  Tensor tAsA_view = thr_copy_a.retile_D(tAsA);
  copy(copy_a, tAgA, tAsA_view);

  // Stage 2: write the fragment into the shared (BN, BM) tile.
  auto thr_copy_trans = copy_trans.get_slice(tx);
  auto tAsB = thr_copy_trans.retile_S(tAsA);
  auto tBsB_trans = thr_copy_trans.partition_D(sB);
  copy(copy_trans, tAsB, tBsB_trans);

  // The store partition below may read shared-memory elements written by
  // other threads in stage 2, so all writes must be complete and visible
  // before stage 3 starts. Every thread reaches this barrier (no divergent
  // control flow above).
  __syncthreads();

  // Stage 3: store the shared tile to the B tile in global memory.
  auto thr_copy_b = copy_b.get_slice(tx);
  Tensor tBsB = thr_copy_b.partition_S(sB);
  Tensor tBgB = thr_copy_b.partition_D(gB);
  copy(copy_b, tBsB, tBgB);
}
// Launches mat_transpose_cute_smem_vectorized_optimized_kernel on float data
// with a fixed (BM, BN) = (8, 128) tile and a Swizzle<2, 3, 2> row-major
// (BN, BM) shared-memory layout.
//
// x supplies the source matrix (M = x.size(0), N = x.size(1)) and y receives
// the (N, M) result; both are accessed through data_ptr<float>().
// Raises (via TORCH_CHECK) when M or N is not a whole number of tiles or when
// either pointer is not 16-byte aligned: the kernel copies full tiles without
// bounds predication, so a partial tile would access memory outside x and y.
void mat_transpose_cute_row_rvectorized_swizzled_optimized(torch::Tensor x,
                                                           torch::Tensor y) {
  const int BM = 8;
  const int BN = 16 * 8;
  auto ptr_A = x.data_ptr<float>();
  auto ptr_B = y.data_ptr<float>();
  const int M = x.size(0);
  const int N = x.size(1);
  // sanity checks
  static_assert(BM % 4 == 0);
  static_assert(BN % 4 == 0);
  // Tile divisibility also implies M % 4 == 0 and N % 4 == 0, which keeps each
  // row start 16-byte aligned once the base pointers are.
  TORCH_CHECK(M % BM == 0 && N % BN == 0,
              "mat_transpose_cute_row_rvectorized_swizzled_optimized: M (", M,
              ") must be a multiple of ", BM, " and N (", N,
              ") must be a multiple of ", BN);
  TORCH_CHECK(is_aligned_128(ptr_A) && is_aligned_128(ptr_B),
              "mat_transpose_cute_row_rvectorized_swizzled_optimized: x and y "
              "must be 16-byte aligned");
  // Load an 8x16 matrix at a time.
  auto tile_copy_a = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BM>{}, make_shape(Int<4>{}, Int<BN / 16>{})),
                  make_stride(Int<4>{}, make_stride(Int<1>{}, Int<32>{}))),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  // Transpose the data (fragment -> shared memory).
  auto tile_copy_trans = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(make_shape(Int<4>{}, Int<BN / 16>{}), Int<BM>{}),
                  make_stride(make_stride(Int<1>{}, Int<32>{}), Int<4>{})),
      make_layout(make_shape(Int<4>{}, Int<1>{}), GenRowMajor{}));
  // Store a 16x8 matrix at a time.
  auto tile_copy_b = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopy, float>{},
      make_layout(make_shape(Int<BN>{}, Int<BM / 4>{}), GenRowMajor{}),
      make_layout(make_shape(Int<1>{}, Int<4>{}), GenRowMajor{}));
  auto swizzle_func = Swizzle<2, 3, 2>{};
  auto sB_layout =
      composition(swizzle_func,
                  make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{}));
  // All three tiled copies are sliced with the same threadIdx.x in the kernel,
  // so they must agree on the number of threads.
  static_assert(size(tile_copy_a) == size(tile_copy_b));
  static_assert(size(tile_copy_a) == size(tile_copy_trans));
  dim3 block(size(tile_copy_a));
  dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
  mat_transpose_cute_smem_vectorized_optimized_kernel<
      float, BM, BN, decltype(tile_copy_a), decltype(tile_copy_trans),
      decltype(tile_copy_b), decltype(sB_layout)><<<grid, block>>>(
      ptr_A, ptr_B, M, N, tile_copy_a, tile_copy_trans, tile_copy_b, sB_layout);
  CUDA_CHECK(cudaGetLastError());
}