
Commit 58b0fd8

[PHI] Fix cum kernel for big tensor (PaddlePaddle#72562)
1 parent 5861e14 commit 58b0fd8

1 file changed: +101 -125 lines

paddle/phi/kernels/gpu/cum_kernel.cu (+101 -125)
@@ -36,53 +36,23 @@ namespace cub = hipcub;
 
 namespace phi {
 
-template <typename T, int BLOCK_SIZE>
-__device__ void BlockReverse(
-    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
-  __shared__ T sh_mem[BLOCK_SIZE];
-  int tx = threadIdx.x;
-
-  int offset = tx;
-  T src_data = static_cast<T>(0);
-  int src_offset = BLOCK_SIZE - offset - 1;
-  if (src_offset < valid_item) {
-    src_data = idata[src_base + src_offset];
-  }
-  sh_mem[offset] = src_data;
-
-  __syncthreads();
-  int out_index = dst_base - offset;
-  if (offset < valid_item) {
-    int sh_mem_index = BLOCK_SIZE - offset - 1;
-    odata[out_index] = sh_mem[sh_mem_index];
-  }
-}
-
 template <typename T>
 __global__ void MatrixRowReverse(const T* matrix_data,
                                  T* reverse_data,
-                                 int reverse_size,
-                                 int outer_size,
-                                 int inner_size) {
-  int bx = blockIdx.x;
-  int by = blockIdx.y;
+                                 int64_t grid_size,
+                                 int64_t reverse_size) {
   int item_per_block = 1024;
-
-  for (int block_offset = 0; block_offset < reverse_size;
-       block_offset += item_per_block) {
-    int valid_item = (reverse_size - block_offset > item_per_block)
-                         ? item_per_block
-                         : reverse_size - block_offset;
-    int src_offset =
-        bx * reverse_size + block_offset + by * (inner_size * reverse_size);
-    int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
-                     reverse_size - 1 - block_offset;
-    if (reverse_size < item_per_block) {
-      valid_item = reverse_size;
+  for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
+    for (int64_t block_offset = 0; block_offset < reverse_size;
+         block_offset += item_per_block) {
+      int64_t reverse_offset = block_offset + threadIdx.x;
+      int64_t src_offset = bx * reverse_size + reverse_offset;
+      int64_t dst_offset =
+          bx * reverse_size + (reverse_size - reverse_offset - 1);
+      if (reverse_offset < reverse_size) {
+        reverse_data[dst_offset] = matrix_data[src_offset];
+      }
     }
-
-    BlockReverse<T, 1024>(
-        matrix_data, reverse_data, src_offset, dst_offset, valid_item);
   }
 }
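
The rewritten MatrixRowReverse folds the old (outer_size, inner_size) 2-D launch into one flattened row index, computes every offset in int64_t, and walks the rows with a grid-stride loop, which also makes the shared-memory BlockReverse helper unnecessary. Presumably this is what the "big tensor" fix targets: 32-bit int offsets overflow past 2^31 elements, and the old 2-D grid put inner_size on gridDim.y, which CUDA caps at 65535. A minimal standalone sketch of the same pattern (RowReverseSketch and its parameter names are illustrative, not the Paddle kernel):

#include <cstdint>

// Reverse each of `num_rows` rows of length `row_len`: one flattened row
// index advanced by gridDim.x, all offsets computed in 64-bit arithmetic.
template <typename T>
__global__ void RowReverseSketch(const T* in,
                                 T* out,
                                 int64_t num_rows,
                                 int64_t row_len) {
  for (int64_t row = blockIdx.x; row < num_rows; row += gridDim.x) {
    for (int64_t base = 0; base < row_len; base += blockDim.x) {
      int64_t i = base + threadIdx.x;  // position within the row
      if (i < row_len) {
        out[row * row_len + (row_len - i - 1)] = in[row * row_len + i];
      }
    }
  }
}

Because the loop is keyed on blockIdx.x rather than assuming one block per row, the launch in ScanKernel below can clamp the grid to the device limit without dropping rows.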

@@ -112,24 +82,30 @@ __global__ void MatrixTranspose(T* odata,
                                 size_t width) {
   __shared__ T tile[TILE_DIM][TILE_DIM + 1];
 
-  int x = blockIdx.x * TILE_DIM + threadIdx.x;
-  int y = blockIdx.y * TILE_DIM + threadIdx.y;
-  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
-    if (x < width && (y + j) < height) {
-      tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
-    } else {
-      tile[threadIdx.y + j][threadIdx.x] = 0;
-    }
-  }
+  int64_t wblocks = (width + TILE_DIM - 1) / TILE_DIM;
+  int64_t hblocks = (height + TILE_DIM - 1) / TILE_DIM;
+
+  int64_t block_i = blockIdx.x;
+  for (; block_i < wblocks * hblocks; block_i += gridDim.x) {
+    int64_t block_y = block_i / wblocks;
+    int64_t block_x = block_i % wblocks;
+    int64_t x = block_x * TILE_DIM + threadIdx.x;
+    int64_t y = block_y * TILE_DIM + threadIdx.y;
 
-  __syncthreads();
+    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+      if (x < width && (y + j) < height) {
+        tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
+      }
+    }
+    __syncthreads();
 
-  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
-  y = blockIdx.x * TILE_DIM + threadIdx.y;
+    x = block_y * TILE_DIM + threadIdx.x;  // transpose block offset
+    y = block_x * TILE_DIM + threadIdx.y;
 
-  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
-    if (x < height && (y + j) < width) {
-      odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
+    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+      if (x < height && (y + j) < width) {
+        odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
+      }
     }
   }
 }
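
MatrixTranspose keeps its 32x33 shared-memory tile but now derives the tile coordinates from a single linear block index, so the same clamped 1-D grid-stride launch works here too; the old else-branch that zero-filled out-of-range tile slots could be dropped because the store loop re-checks x < height && (y + j) < width before writing. A small sketch of just the index decomposition used above (TileCoords is an illustrative helper, not part of the file):

#include <cstdint>

constexpr int kTileDim = 32;  // matches the kernel's TILE_DIM

// Map a flattened block index onto (tile_x, tile_y) for a width x height
// matrix: row-major over ceil(width / kTileDim) tiles per row.
__host__ __device__ inline void TileCoords(int64_t block_i,
                                           int64_t width,
                                           int64_t* tile_x,
                                           int64_t* tile_y) {
  int64_t wblocks = (width + kTileDim - 1) / kTileDim;
  *tile_y = block_i / wblocks;
  *tile_x = block_i % wblocks;
}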
@@ -172,9 +148,8 @@ struct Identity<T, ComplexSum> {
 template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
 __global__ void BlockScanKernel(T* d_out,
                                 const T* d_in,
-                                int inner_size,
-                                int outer_size,
-                                int scan_size,
+                                int64_t grid_size,
+                                int64_t scan_size,
                                 bool exclusive,
                                 Op op) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -196,38 +171,40 @@ __global__ void BlockScanKernel(T* d_out,
     typename BlockScanT::TempStorage scan;
   } temp_storage;
 
-  int bx = blockIdx.x;
-  BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
-
   // Obtain this block's segment of consecutive keys (blocked across threads)
-  int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
-  for (int block_offset = 0; block_offset < scan_size;
-       block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
-    int valid_item = (scan_size - block_offset > item_per_block)
-                         ? item_per_block
-                         : (scan_size - block_offset);
-    if (scan_size < item_per_block) {
-      valid_item = scan_size;
+  int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
+
+  for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
+    BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
+
+    for (int64_t block_offset = 0; block_offset < scan_size;
+         block_offset += item_per_block) {
+      int64_t valid_item = (scan_size - block_offset > item_per_block)
+                               ? item_per_block
+                               : (scan_size - block_offset);
+      if (scan_size < item_per_block) {
+        valid_item = scan_size;
+      }
+
+      int64_t offset = bx * scan_size + block_offset;
+
+      MT thread_keys[ITEMS_PER_THREAD];
+      BlockLoadT(temp_storage.load)
+          .Load(d_in + offset, thread_keys, valid_item, 0);
+
+      __syncthreads();
+      if (exclusive) {
+        BlockScanT(temp_storage.scan)
+            .ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
+      } else {
+        BlockScanT(temp_storage.scan)
+            .InclusiveScan(thread_keys, thread_keys, op, prefix_op);
+      }
+      __syncthreads();
+
+      BlockStoreT(temp_storage.store)
+          .Store(d_out + offset, thread_keys, valid_item);
     }
-
-    int offset = block_offset + bx * scan_size;
-
-    MT thread_keys[ITEMS_PER_THREAD];
-    BlockLoadT(temp_storage.load)
-        .Load(d_in + offset, thread_keys, valid_item, 0);
-
-    __syncthreads();
-    if (exclusive) {
-      BlockScanT(temp_storage.scan)
-          .ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
-    } else {
-      BlockScanT(temp_storage.scan)
-          .InclusiveScan(thread_keys, thread_keys, op, prefix_op);
-    }
-    __syncthreads();
-
-    BlockStoreT(temp_storage.store)
-        .Store(d_out + offset, thread_keys, valid_item);
   }
 }
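
BlockScanKernel keeps the CUB block-scan-with-running-prefix structure, but the block index and all offsets are now int64_t, and the per-row state (prefix_op) is re-initialized inside the grid-stride loop so a single block can scan several rows. A self-contained sketch of the same chunked-scan technique with cub (an inclusive sum over one long row; RowInclusiveSum and float are illustrative choices, the real kernel is templated on the scan op and uses MPTypeTrait):

#include <cstdint>
#include <cub/cub.cuh>

// Running-prefix functor: thread 0 of the block receives the previous
// chunks' aggregate and seeds the next chunk's scan with it.
struct RunningPrefix {
  float total;
  __device__ explicit RunningPrefix(float init) : total(init) {}
  __device__ float operator()(float block_aggregate) {
    float old = total;
    total += block_aggregate;
    return old;
  }
};

// One block scans one row of length `row_len` in chunks of
// BLOCK_THREADS * ITEMS_PER_THREAD elements, carrying the prefix across
// chunks the same way BlockScanKernel does above.
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void RowInclusiveSum(const float* d_in,
                                float* d_out,
                                int64_t row_len) {
  using BlockLoadT = cub::BlockLoad<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                                    cub::BLOCK_LOAD_TRANSPOSE>;
  using BlockStoreT = cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                                      cub::BLOCK_STORE_TRANSPOSE>;
  using BlockScanT = cub::BlockScan<float, BLOCK_THREADS>;

  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockScanT::TempStorage scan;
    typename BlockStoreT::TempStorage store;
  } temp_storage;

  RunningPrefix prefix(0.0f);
  constexpr int64_t kChunk = BLOCK_THREADS * ITEMS_PER_THREAD;

  for (int64_t offset = 0; offset < row_len; offset += kChunk) {
    int64_t remaining = row_len - offset;
    int valid = static_cast<int>(remaining < kChunk ? remaining : kChunk);

    float items[ITEMS_PER_THREAD];
    // Out-of-range slots are filled with 0 so they do not disturb the sum.
    BlockLoadT(temp_storage.load).Load(d_in + offset, items, valid, 0.0f);
    __syncthreads();
    BlockScanT(temp_storage.scan).InclusiveSum(items, items, prefix);
    __syncthreads();
    BlockStoreT(temp_storage.store).Store(d_out + offset, items, valid);
    __syncthreads();
  }
}

// Usage: RowInclusiveSum<128, 4><<<1, 128>>>(d_in, d_out, n);
// the Paddle kernel instead launches min(grid_size, max grid dim) blocks
// and strides them over all rows.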

@@ -347,14 +324,24 @@ void ScanKernel(const Context& dev_ctx,
   int scan_size = out_dims[axis];
   bool transpose = (axis != out_dims.size() - 1);
 
-  int tile_size = 32;
-  dim3 blocks(32, 8);
-  dim3 transpose_grids((width + tile_size - 1) / tile_size,
-                       (height + tile_size - 1) / tile_size);
   DenseTensor tmp_tensor;
   tmp_tensor.Resize(out_dims);
   auto* tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);
 
+  auto swap_ptr = [](T*& ptr1, T*& ptr2) {
+    T* tmp = ptr2;
+    ptr2 = ptr1;
+    ptr1 = tmp;
+  };
+
+  int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0];
+
+  // Do pre-process transpose
+  int tile_size = 32;
+  dim3 blocks(32, 8);
+  int64_t transpose_grids = ((width + tile_size - 1) / tile_size) *
+                            ((height + tile_size - 1) / tile_size);
+  transpose_grids = std::min(transpose_grids, max_grid_x);
   T* next_in_data = out_data;
   T* next_out_data = tmp_data;
   if (transpose) {
@@ -363,53 +350,42 @@ void ScanKernel(const Context& dev_ctx,
     next_in_data = out_data;
     next_out_data = tmp_data;
   }
-  auto swap_ptr = [](T*& ptr1, T*& ptr2) {
-    T* tmp = ptr2;
-    ptr2 = ptr1;
-    ptr1 = tmp;
-  };
-  int outer_size = height / scan_size;
-  int inner_size = width;
-  // Consider the size of shared memory, here block size is 128
-  dim3 scan_grid(outer_size, inner_size);
-  dim3 reverse_grid = scan_grid;
+
+  // Do pre-process reverse
+  int64_t outer_size = height / scan_size;
+  int64_t inner_size = width;
+  int64_t grid_size = outer_size * inner_size;
+  int64_t scan_grid = std::min(grid_size, max_grid_x);
   if (reverse) {
     if (transpose) {
-      reverse_grid.x = scan_grid.y;
-      reverse_grid.y = scan_grid.x;
-      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
-          next_in_data, next_out_data, scan_size, outer_size, inner_size);
+      MatrixRowReverse<T><<<scan_grid, 1024, 0, dev_ctx.stream()>>>(
+          next_in_data, next_out_data, grid_size, scan_size);
       if (!transpose) next_in_data = tmp_data;
       swap_ptr(next_in_data, next_out_data);
     } else {
-      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
-          in_data, out_data, scan_size, outer_size, inner_size);
+      MatrixRowReverse<T><<<scan_grid, 1024, 0, dev_ctx.stream()>>>(
+          in_data, out_data, grid_size, scan_size);
     }
   }
-  int64_t grid_size = outer_size * inner_size;
+
+  // Do scan
   if (!transpose && !reverse) {
-    BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
-        out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);
+    BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+        out_data, in_data, grid_size, scan_size, exclusive, op);
 
   } else {
-    BlockScanKernel<T, 128, 4, Op>
-        <<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
-                                                  next_in_data,
-                                                  outer_size,
-                                                  inner_size,
-                                                  scan_size,
-                                                  exclusive,
-                                                  op);
+    BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+        next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
   }
   swap_ptr(next_in_data, next_out_data);
+
+  // Do post-process reverse and transpose
   if (reverse) {
-    MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
-        next_in_data, next_out_data, scan_size, outer_size, inner_size);
+    MatrixRowReverse<T><<<scan_grid, 1024, 0, dev_ctx.stream()>>>(
+        next_in_data, next_out_data, grid_size, scan_size);
     swap_ptr(next_in_data, next_out_data);
   }
   if (transpose) {
-    transpose_grids.x = (height + tile_size - 1) / tile_size;
-    transpose_grids.y = (width + tile_size - 1) / tile_size;
     MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
         next_out_data, next_in_data, width, height);
   }
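
On the host side, ScanKernel now flattens every launch into a 1-D block count held in int64_t and clamps it with std::min(..., max_grid_x), where max_grid_x comes from dev_ctx.GetCUDAMaxGridDimSize()[0]; the grid-stride loops inside the kernels cover whatever the clamp cuts off. A generic sketch of the same clamping with the plain CUDA runtime (ClampGrid and SomeGridStrideKernel are illustrative names):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Clamp a logical block count to the device's maximum 1-D grid extent.
// Kernels written as grid-stride loops still cover all `work_items`.
inline int64_t ClampGrid(int64_t work_items, int device_id = 0) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return std::min<int64_t>(work_items, prop.maxGridSize[0]);
}

// Usage:
//   int64_t grid = ClampGrid(outer_size * inner_size);
//   SomeGridStrideKernel<<<grid, 128>>>(d_out, d_in, outer_size * inner_size);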
