Skip to content

Commit 971643a

Browse files
committed
perf(autoware_tensorrt_plugins): keep SegmentCSR allocation-free
1 parent 23b3f79 commit 971643a

4 files changed

Lines changed: 167 additions & 107 deletions

File tree

perception/autoware_tensorrt_plugins/include/autoware/scatter_ops/segment_csr.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,10 @@
2626
#include <cuda_runtime.h>
2727

2828
#include <tuple>
29-
#include <vector>
3029

3130
template <typename scalar_t, ReductionType REDUCE>
3231
int32_t segment_csr_launch(
33-
const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
34-
const std::vector<int32_t> & indptr_size, const scalar_t * base,
35-
std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream);
36-
37-
template <typename scalar_t, ReductionType REDUCE>
38-
int32_t segment_csr_launch(
39-
const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
40-
const std::vector<int32_t> & indptr_size, std::tuple<scalar_t *, int64_t *> out,
41-
cudaStream_t stream);
32+
const scalar_t * src, int32_t num_rows, int32_t num_cols, const int64_t * indptr,
33+
int32_t indptr_size, std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream);
4234

4335
#endif // AUTOWARE__SCATTER_OPS__SEGMENT_CSR_H_

perception/autoware_tensorrt_plugins/src/scatter_ops/segment_csr.cu

Lines changed: 36 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,16 @@
2323
#include "autoware/scatter_ops/segment_csr.h"
2424
#include "autoware/scatter_ops/utils.cuh"
2525

26-
#include <numeric>
2726
#include <string>
2827
#include <tuple>
29-
#include <vector>
3028

3129
#define THREADS 256
3230
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
3331
#define FULL_MASK 0xffffffff
34-
#define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, R) \
35-
template int32_t segment_csr_launch<T, R>( \
36-
const T * src, const std::vector<int32_t> & src_size, const int64_t * indptr, \
37-
const std::vector<int32_t> & indptr_size, std::tuple<T *, int64_t *> out, \
38-
cudaStream_t stream); \
39-
template int32_t segment_csr_launch<T, R>( \
40-
const T * src, const std::vector<int32_t> & src_size, const int64_t * indptr, \
41-
const std::vector<int32_t> & indptr_size, const T * base, std::tuple<T *, int64_t *> out, \
42-
cudaStream_t stream);
32+
#define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, R) \
33+
template int32_t segment_csr_launch<T, R>( \
34+
const T * src, int32_t num_rows, int32_t num_cols, const int64_t * indptr, \
35+
int32_t indptr_size, std::tuple<T *, int64_t *> out, cudaStream_t stream);
4336
#define SEGMENT_CSR_LAUNCH_INSTANTIATION(T) \
4437
SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::SUM) \
4538
SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::MEAN) \
@@ -50,27 +43,25 @@
5043

5144
template <typename scalar_t, ReductionType REDUCE, int TB>
5245
__global__ void segment_csr_kernel(
53-
const scalar_t * src, const int64_t * indptr, const int64_t * indptr_size, int32_t indptr_dim,
54-
scalar_t * out, int64_t * arg_out, size_t N, size_t E)
46+
const scalar_t * src, const int64_t * indptr, scalar_t * out, int64_t * arg_out,
47+
size_t num_segments)
5548
{
5649
// Each warp processes exactly `32/TB` rows and aggregates all row values
5750
// via a parallel reduction.
5851

5952
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
6053
int row_idx = thread_idx / TB;
6154
int lane_idx = thread_idx & (TB - 1);
62-
if (row_idx >= N) return;
55+
if (row_idx >= num_segments) return;
6356

64-
int offset = indptr_to_offset(indptr_size, indptr_dim, row_idx);
65-
int64_t row_start = __ldg(indptr + offset);
66-
int64_t row_end = __ldg(indptr + offset + 1);
57+
int64_t row_start = __ldg(indptr + row_idx);
58+
int64_t row_end = __ldg(indptr + row_idx + 1);
6759

6860
scalar_t val = Reducer<scalar_t, REDUCE>::init();
69-
int64_t arg, arg_tmp;
61+
int64_t arg{0}, arg_tmp{0};
7062

71-
offset = (row_idx / (indptr_size[indptr_dim - 1] - 1)) * E;
7263
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB)
73-
Reducer<scalar_t, REDUCE>::update(&val, src[offset + src_idx], &arg, src_idx);
64+
Reducer<scalar_t, REDUCE>::update(&val, src[src_idx], &arg, src_idx);
7465

7566
#pragma unroll
7667
for (int i = TB / 2; i > 0; i /= 2) {
@@ -90,28 +81,26 @@ __global__ void segment_csr_kernel(
9081

9182
template <typename scalar_t, ReductionType REDUCE>
9283
__global__ void segment_csr_broadcast_kernel(
93-
const scalar_t * src, const int64_t * indptr, const int64_t * indptr_size, int32_t indptr_dim,
94-
scalar_t * out, int64_t * arg_out, size_t N, size_t K, size_t E)
84+
const scalar_t * src, const int64_t * indptr, scalar_t * out, int64_t * arg_out,
85+
size_t num_segments, size_t num_cols)
9586
{
9687
// Each thread processes exactly one row. It turned out that this is more
9788
// efficient than using shared memory due to avoiding synchronization
9889
// barriers.
9990

10091
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
101-
int row_idx = thread_idx / K;
102-
int lane_idx = thread_idx % K;
103-
if (thread_idx >= N * K) return;
92+
int row_idx = thread_idx / num_cols;
93+
int lane_idx = thread_idx % num_cols;
94+
if (thread_idx >= num_segments * num_cols) return;
10495

105-
int offset = indptr_to_offset(indptr_size, indptr_dim, row_idx);
106-
int64_t row_start = __ldg(indptr + offset);
107-
int64_t row_end = __ldg(indptr + offset + 1);
96+
int64_t row_start = __ldg(indptr + row_idx);
97+
int64_t row_end = __ldg(indptr + row_idx + 1);
10898

10999
scalar_t val = Reducer<scalar_t, REDUCE>::init();
110-
int64_t arg;
100+
int64_t arg{0};
111101

112-
offset = (row_idx / (indptr_size[indptr_dim - 1] - 1)) * E * K;
113102
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++)
114-
Reducer<scalar_t, REDUCE>::update(&val, src[offset + K * src_idx + lane_idx], &arg, src_idx);
103+
Reducer<scalar_t, REDUCE>::update(&val, src[num_cols * src_idx + lane_idx], &arg, src_idx);
115104

116105
if (arg_out != nullptr)
117106
Reducer<scalar_t, REDUCE>::write(
@@ -124,71 +113,29 @@ __global__ void segment_csr_broadcast_kernel(
124113
//! \todo expand index
125114
template <typename scalar_t, ReductionType REDUCE>
126115
int32_t segment_csr_launch(
127-
const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
128-
const std::vector<int32_t> & indptr_size, const scalar_t * base,
129-
std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream)
116+
const scalar_t * src, int32_t num_rows, int32_t num_cols, const int64_t * indptr,
117+
int32_t indptr_size, std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream)
130118
{
131-
if (src_size.size() < indptr_size.size()) return -1;
132-
133-
if (!std::equal(indptr_size.begin(), indptr_size.end() - 1, src_size.begin())) return -1;
134-
135-
auto dim = indptr_size.size() - 1;
136-
137-
auto _mul = [](int a, int b) { return a * b; };
138-
auto src_numel = std::accumulate(src_size.begin(), src_size.end(), 1, _mul);
139-
auto indptr_numel = std::accumulate(indptr_size.begin(), indptr_size.end(), 1, _mul);
140-
auto out_numel = src_numel / src_size[dim] * std::max<int32_t>(indptr_size[dim] - 1, 0);
141-
142-
cudaMemcpyAsync(
143-
std::get<0>(out), base, sizeof(scalar_t) * out_numel, cudaMemcpyDeviceToDevice, stream);
119+
if (num_rows < 0 || num_cols < 0 || indptr_size < 0) return -1;
144120

121+
auto num_segments = std::max<int32_t>(indptr_size - 1, 0);
122+
auto out_numel = static_cast<size_t>(num_segments) * static_cast<size_t>(num_cols);
145123
if ((REDUCE == ReductionType::MIN || REDUCE == ReductionType::MAX) && std::get<1>(out) != nullptr)
146124
fill_kernel<int64_t>
147-
<<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(std::get<1>(out), out_numel, src_size[dim]);
148-
149-
if (src_numel == 0) return 0;
150-
151-
auto N = max(indptr_size[dim] - 1, 0) * (indptr_numel / indptr_size[dim]);
152-
auto K = out_numel / N;
153-
auto E = src_size[dim];
154-
int64_t * indptr_size_dev;
155-
cudaMallocAsync(&indptr_size_dev, sizeof(int64_t) * indptr_size.size(), stream);
156-
cudaMemcpyAsync(
157-
indptr_size_dev, indptr_size.data(), sizeof(int64_t) * indptr_size.size(),
158-
cudaMemcpyHostToDevice, stream);
159-
160-
if (K == 1)
161-
segment_csr_kernel<scalar_t, REDUCE, 1><<<BLOCKS(32, N), THREADS, 0, stream>>>(
162-
src, indptr, indptr_size_dev, indptr_size.size(), std::get<0>(out), std::get<1>(out), N, E);
163-
else
164-
segment_csr_broadcast_kernel<scalar_t, REDUCE><<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
165-
src, indptr, indptr_size_dev, indptr_size.size(), std::get<0>(out), std::get<1>(out), N, K,
166-
E);
125+
<<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(std::get<1>(out), out_numel, num_rows);
167126

168-
cudaFreeAsync(indptr_size_dev, stream);
169-
return 0;
170-
}
127+
if (num_segments == 0 || num_cols == 0) return 0;
171128

172-
template <typename scalar_t, ReductionType REDUCE>
173-
int32_t segment_csr_launch(
174-
const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
175-
const std::vector<int32_t> & indptr_size, std::tuple<scalar_t *, int64_t *> out,
176-
cudaStream_t stream)
177-
{
178-
auto dim = indptr_size.size() - 1;
179-
auto src_numel =
180-
std::accumulate(src_size.begin(), src_size.end(), 1, [](int a, int b) { return a * b; });
181-
auto out_numel = src_numel / src_size[dim] * std::max<int32_t>(indptr_size[dim] - 1, 0);
129+
fill_kernel<scalar_t><<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(
130+
std::get<0>(out), out_numel, static_cast<scalar_t>(0));
182131

183-
scalar_t * base;
184-
cudaMallocAsync(&base, sizeof(scalar_t) * out_numel, stream);
185-
fill_kernel<scalar_t><<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(base, out_numel, (scalar_t)0);
186-
187-
auto status =
188-
segment_csr_launch<scalar_t, REDUCE>(src, src_size, indptr, indptr_size, base, out, stream);
189-
if (status != 0) return status;
190-
191-
cudaFreeAsync(base, stream);
132+
if (num_cols == 1)
133+
segment_csr_kernel<scalar_t, REDUCE, 1><<<BLOCKS(32, num_segments), THREADS, 0, stream>>>(
134+
src, indptr, std::get<0>(out), std::get<1>(out), num_segments);
135+
else
136+
segment_csr_broadcast_kernel<scalar_t, REDUCE>
137+
<<<BLOCKS(1, num_segments * num_cols), THREADS, 0, stream>>>(
138+
src, indptr, std::get<0>(out), std::get<1>(out), num_segments, num_cols);
192139
return 0;
193140
}
194141

perception/autoware_tensorrt_plugins/src/segment_csr_plugin.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <exception>
2828
#include <string>
2929
#include <tuple>
30-
#include <vector>
3130
namespace nvinfer1::plugin
3231
{
3332

@@ -178,9 +177,9 @@ std::int32_t SegmentCSRPlugin::enqueue(
178177
void const * const * inputs, void * const * outputs, [[maybe_unused]] void * workspace,
179178
cudaStream_t stream) noexcept
180179
{
181-
std::vector<int32_t> src_size{
182-
static_cast<int32_t>(input_desc[0].dims.d[0]), static_cast<int32_t>(input_desc[0].dims.d[1])};
183-
std::vector<int32_t> indptr_size{static_cast<int32_t>(input_desc[1].dims.d[0])};
180+
const auto num_rows = static_cast<int32_t>(input_desc[0].dims.d[0]);
181+
const auto num_cols = static_cast<int32_t>(input_desc[0].dims.d[1]);
182+
const auto indptr_size = static_cast<int32_t>(input_desc[1].dims.d[0]);
184183

185184
std::int32_t result = 0;
186185

@@ -192,8 +191,8 @@ std::int32_t SegmentCSRPlugin::enqueue(
192191
std::make_tuple(static_cast<float *>(outputs[0]), nullptr);
193192

194193
AT_DISPATCH_REDUCTION_TYPES(reduce_, [&] {
195-
result =
196-
segment_csr_launch<float, REDUCE>(src_ptr, src_size, indptr_ptr, indptr_size, out, stream);
194+
result = segment_csr_launch<float, REDUCE>(
195+
src_ptr, num_rows, num_cols, indptr_ptr, indptr_size, out, stream);
197196
});
198197
} else if (input_desc[0].type == nvinfer1::DataType::kHALF) {
199198
const half * src_ptr = reinterpret_cast<const half *>(inputs[0]);
@@ -203,8 +202,8 @@ std::int32_t SegmentCSRPlugin::enqueue(
203202
std::make_tuple(static_cast<half *>(outputs[0]), nullptr);
204203

205204
AT_DISPATCH_REDUCTION_TYPES(reduce_, [&] {
206-
result =
207-
segment_csr_launch<half, REDUCE>(src_ptr, src_size, indptr_ptr, indptr_size, out, stream);
205+
result = segment_csr_launch<half, REDUCE>(
206+
src_ptr, num_rows, num_cols, indptr_ptr, indptr_size, out, stream);
208207
});
209208
}
210209

0 commit comments

Comments
 (0)