forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargsort_kernel.cu
More file actions
496 lines (451 loc) · 17.7 KB
/
argsort_kernel.cu
File metadata and controls
496 lines (451 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/argsort_kernel.h"
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/cub.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#ifdef __HIPCC__
namespace rocprim {
namespace detail {
template <>
struct radix_key_codec_base<phi::float16>
: radix_key_codec_integral<phi::float16, uint16_t> {};
template <>
struct radix_key_codec_base<phi::bfloat16>
: radix_key_codec_integral<phi::bfloat16, uint16_t> {};
#if HIP_VERSION >= 50400000
template <>
struct float_bit_mask<phi::float16> : float_bit_mask<rocprim::half> {};
template <>
struct float_bit_mask<phi::bfloat16> : float_bit_mask<rocprim::bfloat16> {};
#endif
} // namespace detail
} // namespace rocprim
#else
// set cub base traits in order to handle float16
namespace cub {
template <>
struct NumericTraits<phi::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::float16> {};
template <>
struct NumericTraits<phi::bfloat16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::bfloat16> {};
} // namespace cub
#endif
namespace phi {
// Iter for move to next row
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
template <typename T, typename IndType>
__global__ void merge_kernel(const T* A,
size_t sizeA,
const T* B,
size_t sizeB,
const IndType* ids_A,
const IndType* ids_B,
T* out,
IndType* out_ids,
bool descending) {
int64_t thread =
static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(gridDim.x);
int64_t num_per_thread = (sizeA + sizeB + thread) / thread;
for (int64_t offset = 0; offset < num_per_thread; offset++) {
size_t idx =
static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
static_cast<size_t>(threadIdx.x) + offset * thread;
size_t total = sizeA + sizeB;
if (idx >= total) return;
size_t left = (idx > sizeB) ? idx - sizeB : 0;
size_t right = (idx < sizeA) ? idx : sizeA;
while (left < right) {
size_t mid = (left + right) / 2;
size_t b_idx = idx - mid;
T A_mid, B_bidx;
if (descending) {
A_mid = (mid >= sizeA) ? std::numeric_limits<T>::lowest() : A[mid];
B_bidx = (b_idx >= sizeB) ? std::numeric_limits<T>::lowest() : B[b_idx];
} else {
A_mid = (mid >= sizeA) ? std::numeric_limits<T>::max() : A[mid];
B_bidx = (b_idx >= sizeB) ? std::numeric_limits<T>::max() : B[b_idx];
}
if (descending ? (A_mid >= B_bidx) : (A_mid <= B_bidx))
left = mid + 1;
else
right = mid;
}
size_t a_idx = left;
size_t b_idx = idx - a_idx;
if (a_idx >= sizeA) {
if (descending ? (A[sizeA - 1] < B[b_idx]) : (A[sizeA - 1] > B[b_idx])) {
out[idx] = A[sizeA - 1];
out_ids[idx] = ids_A[sizeA - 1];
} else {
out[idx] = B[b_idx];
out_ids[idx] = ids_B[b_idx];
}
} else if (b_idx >= sizeB) {
out[idx] = A[a_idx];
out_ids[idx] = ids_A[a_idx];
} else {
if (descending ? (A[a_idx] >= B[b_idx]) : (A[a_idx] <= B[b_idx])) {
out[idx] = A[a_idx];
out_ids[idx] = ids_A[a_idx];
} else if (descending ? (a_idx > 0 && (A[a_idx - 1] < B[b_idx]))
: (a_idx > 0 && (A[a_idx - 1] > B[b_idx]))) {
out[idx] = A[a_idx - 1];
out_ids[idx] = ids_A[a_idx - 1];
} else {
out[idx] = B[b_idx];
out_ids[idx] = ids_B[b_idx];
}
}
}
}
template <typename T>
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (T j = row_id; j < num_rows; j += gridDim.x) {
for (T i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
#define CUB_ARGSORT_WRAPPER(func, ...) \
{ \
size_t temp_storage_bytes = 0; \
PADDLE_ENFORCE_GPU_SUCCESS( \
func(nullptr, temp_storage_bytes, __VA_ARGS__)); \
DenseTensor temp_storage; \
int64_t temp_size = static_cast<int64_t>(temp_storage_bytes); \
PADDLE_ENFORCE_GT( \
temp_size, \
0, \
common::errors::InvalidArgument( \
"Argsort temp storage size is %d, but should be greater than 0.", \
temp_size)); \
temp_storage.Resize({temp_size}); \
dev_ctx.template Alloc<uint8_t>(&temp_storage); \
PADDLE_ENFORCE_GPU_SUCCESS( \
func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__)); \
}
#define PREDICATE_CUB_ARGSORT(predicate, if_func, else_func, ...) \
if (predicate) \
CUB_ARGSORT_WRAPPER(if_func, __VA_ARGS__) \
else \
CUB_ARGSORT_WRAPPER(else_func, __VA_ARGS__)
// Sort by flag descending, True: descending. False: Ascending.
// Default is false.
template <typename T, typename IndType>
void ArgFullSort(const GPUContext& dev_ctx,
const DenseTensor* input,
DenseTensor* output,
DenseTensor* indices,
const int64_t num_rows,
const int64_t num_cols,
const bool descending) {
PADDLE_ENFORCE_LE_INT_MAX(num_cols, "num_cols");
auto cu_stream = dev_ctx.stream();
auto ComputeBlockSize = [](IndType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else
return 128;
};
const int block_size = ComputeBlockSize(num_cols);
const int64_t maxGridDimX = dev_ctx.GetCUDAMaxGridDimSize()[0];
const T* inp = input->data<T>();
IndType* sorted_indices_ptr = indices->data<IndType>();
// create iter for counting input
cub::CountingInputIterator<IndType> counting_iter(0);
// segment_offset is used for move to next row
cub::TransformInputIterator<IndType,
SegmentOffsetIter,
cub::CountingInputIterator<IndType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
// num_rows is the total segments to be sorted
constexpr int64_t max_elements = 1 << 30;
const int64_t total_elements = num_cols * num_rows;
const int64_t segment_size = num_cols;
const int64_t element_per_call = std::min(max_elements, total_elements);
// make sure element_per_call >= segment_size
const int64_t adjusted_elements_per_call =
std::max(max_elements, segment_size);
// make sure batch size is the multiple of segment_size
const int64_t batch_size =
(adjusted_elements_per_call / segment_size) * segment_size;
int64_t offset = 0;
DenseTensor input_indices;
T* sorted_out_ptr = sorted_out_ptr = output->data<T>();
IndType* ind_ptr = nullptr;
while (offset < total_elements) {
const int64_t n_elements = std::min(batch_size, total_elements - offset);
const int64_t n_segments = n_elements / segment_size;
// allocate a temporary storage for input indices, with shape:
// [num_segments = n_elements / segment_size, segment_size]
// will be de-allocated once the sort is done, to save memory and
// avoid repeated allocation and deallocation
if (input_indices.initialized()) {
ind_ptr = input_indices.data<IndType>();
} else {
input_indices.Resize({n_segments, segment_size});
ind_ptr = dev_ctx.template Alloc<IndType>(&input_indices);
}
const int64_t grid_size = std::min(n_segments, maxGridDimX);
// Init a index array
FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
ind_ptr, n_segments, segment_size);
PREDICATE_CUB_ARGSORT(descending,
cub::DeviceSegmentedRadixSort::SortPairsDescending,
cub::DeviceSegmentedRadixSort::SortPairs,
inp + offset,
sorted_out_ptr + offset,
ind_ptr,
sorted_indices_ptr + offset,
n_elements,
n_segments,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
offset += n_elements;
}
}
template <typename T, typename IndType>
void PerSort(const GPUContext& dev_ctx,
T* out_data,
int64_t* ids_data,
IndType start,
IndType end,
bool stable,
bool descending) {
#ifdef PADDLE_WITH_CUDA
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(
dev_ctx.GetPlace(), dev_ctx.stream());
const auto& exec_policy =
thrust::cuda::par(allocator).on(dev_ctx.stream());
#else
phi::memory_utils::ThrustAllocator<hipStream_t> allocator(
dev_ctx.GetPlace(), dev_ctx.stream());
const auto& exec_policy =
thrust::hip::par(allocator).on(dev_ctx.stream());
#endif
if (stable) {
if (descending) {
thrust::stable_sort_by_key(exec_policy,
out_data + start,
out_data + end,
ids_data + start,
thrust::greater<T>());
} else {
thrust::stable_sort_by_key(
exec_policy, out_data + start, out_data + end, ids_data + start);
}
return;
} else {
thrust::sort_by_key(
exec_policy, out_data + start, out_data + end, ids_data + start);
if (descending) {
thrust::reverse(exec_policy, out_data + start, out_data + end);
thrust::reverse(exec_policy, ids_data + start, ids_data + end);
}
return;
}
}
template <typename T, typename Context>
void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
auto rank = in_dims.size();
if (input.numel() == 0) {
output->Resize(in_dims);
indices->Resize(in_dims);
dev_ctx.template Alloc<T>(output);
dev_ctx.template Alloc<int64_t>(indices);
return;
}
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input.data<T>();
auto size = input.numel();
if (rank == 0) {
dev_ctx.template Alloc<T>(output);
dev_ctx.template Alloc<int64_t>(indices);
Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
return;
}
// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
// Compared to the following 'Special case for full sort', ascending sort is
// 34 times faster and descending sort is 31 times faster.
if (size == in_dims[axis]) {
T* out_data = dev_ctx.template Alloc<T>(output);
int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
#ifdef PADDLE_WITH_CUDA
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(
dev_ctx.GetPlace(), dev_ctx.stream());
const auto& exec_policy =
thrust::cuda::par(allocator).on(dev_ctx.stream());
#else
phi::memory_utils::ThrustAllocator<hipStream_t> allocator(
dev_ctx.GetPlace(), dev_ctx.stream());
const auto& exec_policy =
thrust::hip::par(allocator).on(dev_ctx.stream());
#endif
auto cu_stream = dev_ctx.stream();
thrust::sequence(exec_policy, ids_data, ids_data + size);
thrust::copy(exec_policy, in_data, in_data + size, out_data);
const int64_t per_number = (1LL << 31) - 1;
int64_t start = 0;
int64_t end = std::min(start + per_number, size);
if (end == size) {
PerSort<T, int64_t>(
dev_ctx, out_data, ids_data, start, end, stable, descending);
} else {
// Sorting the segments and then merging them
DenseTensor temp;
DenseTensor ids;
temp.Resize(in_dims);
ids.Resize(in_dims);
T* temp_data = dev_ctx.template Alloc<T>(&temp);
int64_t* temp_ids = dev_ctx.template Alloc<int64_t>(&ids);
while (start != size) {
PerSort<T, int64_t>(
dev_ctx, out_data, ids_data, start, end, stable, descending);
if (start != 0) {
auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, end);
merge_kernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
cu_stream>>>(out_data,
start,
out_data + start,
end - start,
ids_data,
ids_data + start,
temp_data,
temp_ids,
descending);
thrust::copy(exec_policy, temp_ids, temp_ids + end, ids_data);
thrust::copy(exec_policy, temp_data, temp_data + end, out_data);
}
start = end;
end = std::min(start + per_number, size);
}
}
return;
}
// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height =
common::product(slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
dev_ctx.template Alloc<int64_t>(indices);
dev_ctx.template Alloc<T>(output);
ArgFullSort<T, int64_t>(dev_ctx,
&input,
output,
indices,
input_height,
input_width,
descending);
} else {
// if not full sort, do transpose first
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
DDim trans_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
DenseTensor trans_inp;
trans_inp.Resize(trans_dims);
T* trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
// Do transpose
TransposeKernel<T, Context>(dev_ctx, input, trans, &trans_inp);
const int64_t input_height =
common::product(slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
DenseTensor tmp_out;
tmp_out.Resize(trans_dims);
dev_ctx.template Alloc<T>(&tmp_out);
DenseTensor tmp_indices;
// temp indices for sorting
tmp_indices.Resize(trans_dims);
dev_ctx.template Alloc<int64_t>(&tmp_indices);
ArgFullSort<T, int64_t>(dev_ctx,
&trans_inp,
&tmp_out,
&tmp_indices,
input_height,
input_width,
descending);
// delay output allocation until after transpose, to avoid
// allocating too much memory
dev_ctx.template Alloc<T>(output);
dev_ctx.template Alloc<int64_t>(indices);
// transpose back
TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
return;
}
}
} // namespace phi
PD_REGISTER_KERNEL(argsort,
GPU,
ALL_LAYOUT,
phi::ArgsortKernel,
float,
double,
int,
int64_t,
uint8_t,
int16_t,
phi::float16,
phi::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}