Skip to content

Commit 23b3f79

Browse files
committed
perf(autoware_tensorrt_plugins): remove Thrust from sort kernels
1 parent 54af299 commit 23b3f79

8 files changed

Lines changed: 417 additions & 69 deletions

File tree

perception/autoware_tensorrt_plugins/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,25 @@ if(TRT_AVAIL AND CUDA_AVAIL AND SPCONV_AVAIL)
147147
spconv::spconv
148148
)
149149

150+
if(BUILD_TESTING)
151+
find_package(ament_cmake_gtest REQUIRED)
152+
153+
ament_add_gtest(reference_kernels_test
154+
test/reference_kernels_test.cpp
155+
)
156+
if(TARGET reference_kernels_test)
157+
target_link_libraries(reference_kernels_test
158+
CUDA::cudart
159+
cuda_ops
160+
)
161+
target_include_directories(reference_kernels_test PRIVATE
162+
include
163+
${CUDA_INCLUDE_DIRS}
164+
)
165+
target_compile_definitions(reference_kernels_test PRIVATE _GLIBCXX_USE_CXX11_ABI=1)
166+
endif()
167+
endif()
168+
150169
install(
151170
TARGETS ${PROJECT_NAME}
152171
DESTINATION share/${PROJECT_NAME}/plugins

perception/autoware_tensorrt_plugins/include/autoware/unique_ops/unique.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ std::int64_t unique(
2424
std::int64_t * unique_counts, void * workspace, std::size_t num_input_elements,
2525
std::size_t workspace_size, cudaStream_t stream);
2626

27+
std::size_t get_unique_temp_storage_size(std::size_t num_elements);
2728
std::size_t get_unique_workspace_size(std::size_t num_elements);
2829

2930
#endif // AUTOWARE__UNIQUE_OPS__UNIQUE_HPP_

perception/autoware_tensorrt_plugins/package.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
<depend>autoware_cuda_utils</depend>
2222

2323
<test_depend>ament_cmake_ros</test_depend>
24+
<test_depend>ament_cmake_gtest</test_depend>
2425
<test_depend>ament_lint_auto</test_depend>
2526
<test_depend>autoware_lint_common</test_depend>
2627

perception/autoware_tensorrt_plugins/src/argsort_ops/argsort.cu

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,52 @@
1616

1717
#include <cub/cub.cuh>
1818

19-
#include <thrust/device_ptr.h>
20-
#include <thrust/execution_policy.h>
21-
#include <thrust/sequence.h>
19+
namespace
20+
{
21+
22+
constexpr int kThreadsPerBlock = 256;
23+
24+
std::size_t align_up(const std::size_t size, const std::size_t alignment)
25+
{
26+
return ((size + alignment - 1U) / alignment) * alignment;
27+
}
28+
29+
__global__ void fill_iota(std::int64_t * output, const std::size_t num_elements)
30+
{
31+
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
32+
if (index >= num_elements) {
33+
return;
34+
}
35+
36+
output[index] = static_cast<std::int64_t>(index);
37+
}
38+
39+
} // namespace
2240

2341
cudaError_t argsort(
2442
const std::int64_t * input_d, std::int64_t * output_d, void * workspace, std::size_t num_elements,
2543
std::size_t argsort_workspace_size, cudaStream_t stream)
2644
{
27-
int workspace_offset = (argsort_workspace_size + sizeof(std::int64_t) - 1) / sizeof(std::int64_t);
28-
thrust::device_ptr<std::int64_t> idx_ptr(
29-
&reinterpret_cast<std::int64_t *>(workspace)[workspace_offset]);
45+
if (num_elements == 0U) {
46+
return cudaSuccess;
47+
}
3048

31-
thrust::sequence(thrust::cuda::par.on(stream), idx_ptr, idx_ptr + num_elements, 0);
49+
const auto scratch_offset = align_up(argsort_workspace_size, alignof(std::int64_t));
50+
auto * input_idx_d =
51+
reinterpret_cast<std::int64_t *>(reinterpret_cast<char *>(workspace) + scratch_offset);
52+
auto * input_sorted_d = input_idx_d + num_elements;
3253

33-
std::int64_t * input_sorted_d = thrust::raw_pointer_cast(idx_ptr) + num_elements;
54+
const auto num_blocks =
55+
static_cast<unsigned int>((num_elements + kThreadsPerBlock - 1U) / kThreadsPerBlock);
56+
fill_iota<<<num_blocks, kThreadsPerBlock, 0, stream>>>(input_idx_d, num_elements);
57+
cudaError_t status = cudaGetLastError();
58+
if (status != cudaSuccess) {
59+
return status;
60+
}
3461

3562
return cub::DeviceRadixSort::SortPairs(
36-
workspace, argsort_workspace_size, input_d, input_sorted_d, thrust::raw_pointer_cast(idx_ptr),
37-
output_d, num_elements, 0, 64, stream);
63+
workspace, argsort_workspace_size, input_d, input_sorted_d, input_idx_d, output_d, num_elements,
64+
0, 64, stream);
3865
}
3966

4067
std::size_t get_argsort_workspace_size(std::size_t num_elements)

perception/autoware_tensorrt_plugins/src/argsort_plugin.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,11 @@ std::int32_t ArgsortPlugin::enqueue(
149149
cudaStream_t stream) noexcept
150150
{
151151
auto num_elements = static_cast<std::size_t>(input_desc[0].dims.d[0]);
152-
if (max_num_elements_ < num_elements) {
153-
max_num_elements_ = num_elements;
154-
argsort_workspace_size_ = get_argsort_workspace_size(max_num_elements_);
155-
}
152+
const auto workspace_size = get_argsort_workspace_size(num_elements);
156153

157154
return argsort(
158155
reinterpret_cast<std::int64_t const *>(inputs[0]), reinterpret_cast<std::int64_t *>(outputs[0]),
159-
workspace, num_elements, argsort_workspace_size_, stream);
156+
workspace, num_elements, workspace_size, stream);
160157
}
161158

162159
std::int32_t ArgsortPlugin::onShapeChange(
@@ -183,8 +180,10 @@ std::size_t ArgsortPlugin::getWorkspaceSize(
183180
[[maybe_unused]] std::int32_t num_outputs) const noexcept
184181
{
185182
std::int64_t max_num_elements = inputs[0].max.d[0];
186-
return get_argsort_workspace_size(max_num_elements) +
187-
sizeof(std::int64_t) * 2 * (max_num_elements + 1);
183+
const auto temp_size = get_argsort_workspace_size(max_num_elements);
184+
const auto scratch_offset =
185+
((temp_size + alignof(std::int64_t) - 1U) / alignof(std::int64_t)) * alignof(std::int64_t);
186+
return scratch_offset + sizeof(std::int64_t) * 2 * max_num_elements;
188187
}
189188

190189
} // namespace nvinfer1::plugin

perception/autoware_tensorrt_plugins/src/unique_ops/unique.cu

Lines changed: 125 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -101,78 +101,152 @@
101101

102102
#include <cub/cub.cuh>
103103

104-
#include <thrust/adjacent_difference.h>
105-
#include <thrust/device_ptr.h>
106-
#include <thrust/execution_policy.h>
107-
#include <thrust/scan.h>
108-
#include <thrust/scatter.h>
109-
#include <thrust/sequence.h>
110-
#include <thrust/sort.h>
111-
#include <thrust/unique.h>
104+
#include <algorithm>
105+
#include <cstdint>
112106

113-
std::int64_t unique(
114-
const std::int64_t * input, std::int64_t * unique, std::int64_t * inverse_indices,
115-
std::int64_t * unique_counts, void * workspace, std::size_t num_input_elements,
116-
std::size_t unique_workspace_size, cudaStream_t stream)
107+
namespace
117108
{
118-
auto policy = thrust::cuda::par.on(stream);
119-
120-
thrust::device_ptr<std::int64_t> idx_ptr(reinterpret_cast<std::int64_t *>(workspace));
121109

122-
thrust::sequence(policy, idx_ptr, idx_ptr + num_input_elements + 1, 0);
110+
constexpr int kThreadsPerBlock = 256;
123111

124-
std::int64_t * sorted_input = unique;
125-
std::int64_t * sorted_idx = thrust::raw_pointer_cast(idx_ptr) + 2 * num_input_elements + 1;
126-
std::int64_t * inv_loc_ptr = thrust::raw_pointer_cast(idx_ptr) + 3 * num_input_elements + 1;
112+
std::size_t align_up(const std::size_t size, const std::size_t alignment)
113+
{
114+
return ((size + alignment - 1U) / alignment) * alignment;
115+
}
127116

128-
void * sort_workspace_ptr =
129-
reinterpret_cast<void *>(thrust::raw_pointer_cast(idx_ptr) + 4 * num_input_elements + 1);
117+
std::size_t query_unique_temp_storage_size(const std::size_t num_elements)
118+
{
119+
std::size_t sort_temp_size = 0;
120+
std::size_t scan_temp_size = 0;
121+
std::size_t unique_temp_size = 0;
130122

131-
auto sort_workspace_size =
132-
unique_workspace_size - (4 * num_input_elements + 1) * sizeof(std::int64_t);
123+
std::int64_t * int64_nullptr = nullptr;
124+
std::int32_t * int32_nullptr = nullptr;
133125

134126
cub::DeviceRadixSort::SortPairs(
135-
sort_workspace_ptr, sort_workspace_size, input, sorted_input, thrust::raw_pointer_cast(idx_ptr),
136-
sorted_idx, num_input_elements, 0, 64, stream);
137-
138-
auto equal = [] __device__(const std::int64_t a, const std::int64_t b) { return a == b; };
127+
nullptr, sort_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
128+
num_elements, 0, 64, nullptr);
129+
cub::DeviceScan::InclusiveSum(
130+
nullptr, scan_temp_size, int32_nullptr, int32_nullptr, num_elements, nullptr);
131+
cub::DeviceSelect::UniqueByKey(
132+
nullptr, unique_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
133+
int64_nullptr, num_elements, nullptr);
134+
135+
return std::max(sort_temp_size, std::max(scan_temp_size, unique_temp_size));
136+
}
139137

140-
auto not_equal = [] __device__(const std::int64_t a, const std::int64_t b) { return a != b; };
138+
__global__ void mark_run_starts(
139+
const std::int64_t * sorted_input, std::int32_t * run_ids, const std::size_t num_input_elements)
140+
{
141+
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
142+
if (index >= num_input_elements) {
143+
return;
144+
}
141145

142-
thrust::adjacent_difference(
143-
policy, sorted_input, sorted_input + num_input_elements, inv_loc_ptr, not_equal);
146+
run_ids[index] = (index == 0U || sorted_input[index] != sorted_input[index - 1U]) ? 1 : 0;
147+
}
144148

145-
cudaMemsetAsync(inv_loc_ptr, 0, sizeof(int64_t), stream);
149+
__global__ void fill_iota(std::int64_t * output, const std::size_t num_input_elements)
150+
{
151+
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
152+
if (index >= num_input_elements) {
153+
return;
154+
}
146155

147-
thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_input_elements, inv_loc_ptr);
148-
thrust::scatter(
149-
policy, inv_loc_ptr, inv_loc_ptr + num_input_elements, sorted_idx, inverse_indices);
156+
output[index] = static_cast<std::int64_t>(index);
157+
}
150158

151-
std::int64_t num_out;
159+
__global__ void scatter_inverse_indices(
160+
const std::int64_t * sorted_idx, const std::int32_t * run_ids, std::int64_t * inverse_indices,
161+
const std::size_t num_input_elements)
162+
{
163+
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
164+
if (index >= num_input_elements) {
165+
return;
166+
}
152167

153-
std::int64_t * range_ptr = idx_ptr.get();
154-
num_out =
155-
thrust::unique_by_key(policy, sorted_input, sorted_input + num_input_elements, range_ptr, equal)
156-
.first -
157-
sorted_input;
168+
inverse_indices[sorted_idx[index]] = static_cast<std::int64_t>(run_ids[index] - 1);
169+
}
158170

159-
cudaMemcpyAsync(
160-
range_ptr + num_out * sizeof(int64_t), &num_input_elements, sizeof(std::int64_t),
161-
cudaMemcpyHostToDevice, stream);
171+
__global__ void write_unique_offset_sentinel(
172+
std::int64_t * unique_offsets, const std::int64_t * num_unique,
173+
const std::size_t num_input_elements)
174+
{
175+
unique_offsets[*num_unique] = static_cast<std::int64_t>(num_input_elements);
176+
}
162177

163-
thrust::adjacent_difference(policy, range_ptr + 1, range_ptr + num_out + 1, unique_counts);
178+
__global__ void write_unique_counts(
179+
const std::int64_t * unique_offsets, const std::int64_t * num_unique, std::int64_t * unique_counts)
180+
{
181+
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
182+
if (index >= static_cast<std::size_t>(*num_unique)) {
183+
return;
184+
}
164185

165-
return num_out;
186+
unique_counts[index] = unique_offsets[index + 1U] - unique_offsets[index];
166187
}
167188

168-
std::size_t get_unique_workspace_size(std::size_t num_elements)
189+
} // namespace
190+
191+
std::int64_t unique(
192+
const std::int64_t * input, std::int64_t * unique, std::int64_t * inverse_indices,
193+
std::int64_t * unique_counts, void * workspace, std::size_t num_input_elements,
194+
std::size_t unique_workspace_size, cudaStream_t stream)
169195
{
170-
std::size_t temp_size = 0;
171-
std::int64_t * int64_nullptr = nullptr;
196+
if (num_input_elements == 0U) {
197+
return 0;
198+
}
172199

200+
const auto temp_storage_size = get_unique_temp_storage_size(num_input_elements);
201+
const auto scratch_offset = align_up(temp_storage_size, alignof(std::int64_t));
202+
auto * scratch = reinterpret_cast<char *>(workspace) + scratch_offset;
203+
204+
auto * input_positions = reinterpret_cast<std::int64_t *>(scratch);
205+
auto * sorted_input = input_positions + num_input_elements;
206+
auto * unique_offsets = sorted_input + num_input_elements;
207+
auto * num_unique_d = unique_offsets + num_input_elements + 1U;
208+
auto * run_ids = reinterpret_cast<std::int32_t *>(num_unique_d + 1U);
209+
210+
const auto num_blocks =
211+
static_cast<unsigned int>((num_input_elements + kThreadsPerBlock - 1U) / kThreadsPerBlock);
212+
213+
fill_iota<<<num_blocks, kThreadsPerBlock, 0, stream>>>(input_positions, num_input_elements);
173214
cub::DeviceRadixSort::SortPairs(
174-
nullptr, temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr, num_elements, 0,
175-
64, nullptr);
215+
workspace, temp_storage_size, input, sorted_input, input_positions, unique_offsets,
216+
num_input_elements, 0, 64, stream);
217+
218+
mark_run_starts<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
219+
sorted_input, run_ids, num_input_elements);
220+
cub::DeviceScan::InclusiveSum(
221+
workspace, temp_storage_size, run_ids, run_ids, num_input_elements, stream);
222+
223+
scatter_inverse_indices<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
224+
unique_offsets, run_ids, inverse_indices, num_input_elements);
176225

177-
return temp_size + (4 * num_elements + 1) * sizeof(std::int64_t);
226+
cub::DeviceSelect::UniqueByKey(
227+
workspace, temp_storage_size, sorted_input, input_positions, unique, unique_offsets,
228+
num_unique_d, num_input_elements, stream);
229+
230+
write_unique_offset_sentinel<<<1, 1, 0, stream>>>(
231+
unique_offsets, num_unique_d, num_input_elements);
232+
write_unique_counts<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
233+
unique_offsets, num_unique_d, unique_counts);
234+
235+
std::int64_t num_out = 0;
236+
cudaMemcpyAsync(&num_out, num_unique_d, sizeof(std::int64_t), cudaMemcpyDeviceToHost, stream);
237+
cudaStreamSynchronize(stream);
238+
return num_out;
239+
}
240+
241+
std::size_t get_unique_temp_storage_size(std::size_t num_elements)
242+
{
243+
return query_unique_temp_storage_size(num_elements);
244+
}
245+
246+
std::size_t get_unique_workspace_size(std::size_t num_elements)
247+
{
248+
const auto temp_size = query_unique_temp_storage_size(num_elements);
249+
const auto scratch_offset = align_up(temp_size, alignof(std::int64_t));
250+
return scratch_offset + (3 * num_elements + 2U) * sizeof(std::int64_t) +
251+
num_elements * sizeof(std::int32_t);
178252
}

perception/autoware_tensorrt_plugins/src/unique_plugin.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,12 @@ std::int32_t UniquePlugin::enqueue(
164164
cudaStream_t stream) noexcept
165165
{
166166
std::int64_t num_elements = input_desc[0].dims.d[0];
167+
const auto workspace_size = get_unique_workspace_size(static_cast<std::size_t>(num_elements));
167168

168169
std::int64_t num_unique_elements = unique(
169170
reinterpret_cast<const std::int64_t *>(inputs[0]), reinterpret_cast<std::int64_t *>(outputs[0]),
170171
reinterpret_cast<std::int64_t *>(outputs[1]), reinterpret_cast<std::int64_t *>(outputs[2]),
171-
workspace, num_elements, workspace_size_, stream);
172+
workspace, num_elements, workspace_size, stream);
172173

173174
cudaMemcpyAsync(
174175
reinterpret_cast<std::int64_t *>(outputs[3]), &num_unique_elements, sizeof(std::int64_t),

0 commit comments

Comments
 (0)