Skip to content

Commit 7163e59

Browse files
committed
perf(autoware_tensorrt_plugins): cache max sort workspace bounds
1 parent e20caa4 commit 7163e59

8 files changed

Lines changed: 192 additions & 50 deletions

File tree

perception/autoware_tensorrt_plugins/include/autoware/tensorrt_plugins/argsort_plugin.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ArgsortPlugin : public IPluginV3,
3838
{
3939
public:
4040
explicit ArgsortPlugin(const std::string & name) noexcept;
41+
ArgsortPlugin(const std::string & name, std::int64_t max_num_elements) noexcept;
4142

4243
~ArgsortPlugin() override = default;
4344

@@ -97,10 +98,11 @@ class ArgsortPlugin : public IPluginV3,
9798

9899
private:
99100
void initFieldsToSerialize();
101+
void updateMaxNumElements(std::int64_t max_num_elements);
100102

101103
std::string layer_name_;
102-
std::size_t argsort_workspace_size_{0};
103-
std::size_t max_num_elements_{0};
104+
std::size_t max_temp_storage_size_{0};
105+
std::int64_t max_num_elements_{0};
104106
std::vector<nvinfer1::PluginField> data_to_serialize_;
105107
nvinfer1::PluginFieldCollection fc_to_serialize_;
106108
};

perception/autoware_tensorrt_plugins/include/autoware/tensorrt_plugins/unique_plugin.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class UniquePlugin : public IPluginV3,
3838
{
3939
public:
4040
explicit UniquePlugin(const std::string & name) noexcept;
41+
UniquePlugin(const std::string & name, std::int64_t max_num_elements) noexcept;
4142

4243
~UniquePlugin() override = default;
4344

@@ -97,9 +98,11 @@ class UniquePlugin : public IPluginV3,
9798

9899
private:
99100
void initFieldsToSerialize();
101+
void updateMaxNumElements(std::int64_t max_num_elements);
100102

101103
std::string layer_name_;
102-
std::size_t workspace_size_{0};
104+
std::size_t max_temp_storage_size_{0};
105+
std::int64_t max_num_elements_{0};
103106
std::vector<nvinfer1::PluginField> data_to_serialize_;
104107
nvinfer1::PluginFieldCollection fc_to_serialize_;
105108
};

perception/autoware_tensorrt_plugins/src/argsort_plugin.cpp

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <NvInferRuntime.h>
2121
#include <NvInferRuntimePlugin.h>
2222

23+
#include <algorithm>
2324
#include <cstdint>
2425
#include <cstring>
2526
#include <exception>
@@ -29,14 +30,50 @@
2930
namespace nvinfer1::plugin
3031
{
3132

33+
namespace
34+
{
35+
36+
std::size_t alignUp(const std::size_t size, const std::size_t alignment)
37+
{
38+
return ((size + alignment - 1U) / alignment) * alignment;
39+
}
40+
41+
std::size_t getTotalWorkspaceSize(
42+
const std::size_t temp_storage_size, const std::size_t max_num_elements)
43+
{
44+
return alignUp(temp_storage_size, alignof(std::int64_t)) +
45+
sizeof(std::int64_t) * 2U * max_num_elements;
46+
}
47+
48+
} // namespace
49+
3250
ArgsortPlugin::ArgsortPlugin(const std::string & name) noexcept : layer_name_{name}
3351
{
3452
initFieldsToSerialize();
3553
}
3654

55+
ArgsortPlugin::ArgsortPlugin(const std::string & name, const std::int64_t max_num_elements) noexcept
56+
: layer_name_{name}
57+
{
58+
updateMaxNumElements(max_num_elements);
59+
initFieldsToSerialize();
60+
}
61+
62+
void ArgsortPlugin::updateMaxNumElements(const std::int64_t max_num_elements)
63+
{
64+
PLUGIN_ASSERT(max_num_elements >= 0);
65+
66+
max_num_elements_ = std::max(max_num_elements_, max_num_elements);
67+
max_temp_storage_size_ = std::max(
68+
max_temp_storage_size_,
69+
get_argsort_workspace_size(static_cast<std::size_t>(max_num_elements_)));
70+
}
71+
3772
void ArgsortPlugin::initFieldsToSerialize()
3873
{
3974
data_to_serialize_.clear();
75+
data_to_serialize_.emplace_back(
76+
"max_num_elements", &max_num_elements_, PluginFieldType::kINT64, 1);
4077
fc_to_serialize_.nbFields = data_to_serialize_.size();
4178
fc_to_serialize_.fields = data_to_serialize_.data();
4279
}
@@ -61,7 +98,7 @@ IPluginCapability * ArgsortPlugin::getCapabilityInterface(PluginCapabilityType t
6198
IPluginV3 * ArgsortPlugin::clone() noexcept
6299
{
63100
try {
64-
IPluginV3 * const plugin{new ArgsortPlugin{layer_name_}};
101+
IPluginV3 * const plugin{new ArgsortPlugin{layer_name_, max_num_elements_}};
65102
return plugin;
66103
} catch (std::exception const & e) {
67104
caughtError(e);
@@ -100,6 +137,7 @@ std::int32_t ArgsortPlugin::configurePlugin(
100137
PLUGIN_ASSERT(out[0].desc.dims.nbDims == 1);
101138

102139
PLUGIN_ASSERT(out[0].desc.type == in[0].desc.type);
140+
updateMaxNumElements(in[0].max.d[0]);
103141

104142
return 0;
105143
}
@@ -149,11 +187,14 @@ std::int32_t ArgsortPlugin::enqueue(
149187
cudaStream_t stream) noexcept
150188
{
151189
auto num_elements = static_cast<std::size_t>(input_desc[0].dims.d[0]);
152-
const auto workspace_size = get_argsort_workspace_size(num_elements);
190+
PLUGIN_ASSERT(static_cast<std::int64_t>(num_elements) <= max_num_elements_);
191+
const auto temp_storage_size = max_temp_storage_size_ == 0U
192+
? get_argsort_workspace_size(num_elements)
193+
: max_temp_storage_size_;
153194

154195
return argsort(
155196
reinterpret_cast<std::int64_t const *>(inputs[0]), reinterpret_cast<std::int64_t *>(outputs[0]),
156-
workspace, num_elements, workspace_size, stream);
197+
workspace, num_elements, temp_storage_size, stream);
157198
}
158199

159200
std::int32_t ArgsortPlugin::onShapeChange(
@@ -179,11 +220,13 @@ std::size_t ArgsortPlugin::getWorkspaceSize(
179220
[[maybe_unused]] DynamicPluginTensorDesc const * outputs,
180221
[[maybe_unused]] std::int32_t num_outputs) const noexcept
181222
{
182-
std::int64_t max_num_elements = inputs[0].max.d[0];
183-
const auto temp_size = get_argsort_workspace_size(max_num_elements);
184-
const auto scratch_offset =
185-
((temp_size + alignof(std::int64_t) - 1U) / alignof(std::int64_t)) * alignof(std::int64_t);
186-
return scratch_offset + sizeof(std::int64_t) * 2 * max_num_elements;
223+
const auto max_num_elements =
224+
std::max(max_num_elements_, static_cast<std::int64_t>(inputs[0].max.d[0]));
225+
const auto temp_storage_size =
226+
max_temp_storage_size_ == 0U
227+
? get_argsort_workspace_size(static_cast<std::size_t>(max_num_elements))
228+
: max_temp_storage_size_;
229+
return getTotalWorkspaceSize(temp_storage_size, static_cast<std::size_t>(max_num_elements));
187230
}
188231

189232
} // namespace nvinfer1::plugin

perception/autoware_tensorrt_plugins/src/argsort_plugin_creator.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
#include <NvInferRuntimePlugin.h>
2121

22+
#include <cstdint>
23+
#include <exception>
2224
#include <string>
2325

2426
namespace nvinfer1::plugin
@@ -39,10 +41,32 @@ nvinfer1::PluginFieldCollection const * ArgsortPluginCreator::getFieldNames() no
3941
}
4042

4143
IPluginV3 * ArgsortPluginCreator::createPlugin(
42-
char const * name, [[maybe_unused]] PluginFieldCollection const * fc,
43-
[[maybe_unused]] TensorRTPhase phase) noexcept
44+
char const * name, PluginFieldCollection const * fc, TensorRTPhase phase) noexcept
4445
{
45-
return new (std::nothrow) ArgsortPlugin(std::string(name));
46+
try {
47+
PLUGIN_VALIDATE(fc != nullptr);
48+
49+
if (phase == TensorRTPhase::kBUILD) {
50+
PLUGIN_VALIDATE(fc->nbFields == 0);
51+
return new (std::nothrow) ArgsortPlugin(std::string(name));
52+
}
53+
54+
if (phase == TensorRTPhase::kRUNTIME) {
55+
nvinfer1::PluginField const * fields{fc->fields};
56+
PLUGIN_VALIDATE(fc->nbFields == 1);
57+
PLUGIN_VALIDATE(fields[0].name != nullptr);
58+
PLUGIN_VALIDATE(std::string(fields[0].name) == "max_num_elements");
59+
PLUGIN_VALIDATE(fields[0].type == nvinfer1::PluginFieldType::kINT64);
60+
PLUGIN_VALIDATE(fields[0].length == 1);
61+
62+
return new (std::nothrow)
63+
ArgsortPlugin(std::string(name), *static_cast<std::int64_t const *>(fields[0].data));
64+
}
65+
} catch (std::exception const & e) {
66+
caughtError(e);
67+
}
68+
69+
return nullptr;
4670
}
4771

4872
} // namespace nvinfer1::plugin

perception/autoware_tensorrt_plugins/src/unique_ops/unique.cu

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ __global__ void write_unique_offset_sentinel(
176176
}
177177

178178
__global__ void write_unique_counts(
179-
const std::int64_t * unique_offsets, const std::int64_t * num_unique, std::int64_t * unique_counts)
179+
const std::int64_t * unique_offsets, const std::int64_t * num_unique,
180+
std::int64_t * unique_counts)
180181
{
181182
const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
182183
if (index >= static_cast<std::size_t>(*num_unique)) {
@@ -191,15 +192,13 @@ __global__ void write_unique_counts(
191192
cudaError_t unique(
192193
const std::int64_t * input, std::int64_t * unique, std::int64_t * inverse_indices,
193194
std::int64_t * unique_counts, std::int64_t * num_unique, void * workspace,
194-
std::size_t num_input_elements, std::size_t unique_workspace_size, cudaStream_t stream)
195+
std::size_t num_input_elements, std::size_t unique_temp_storage_size, cudaStream_t stream)
195196
{
196-
(void)unique_workspace_size;
197197
if (num_input_elements == 0U) {
198198
return cudaMemsetAsync(num_unique, 0, sizeof(std::int64_t), stream);
199199
}
200200

201-
const auto temp_storage_size = get_unique_temp_storage_size(num_input_elements);
202-
const auto scratch_offset = align_up(temp_storage_size, alignof(std::int64_t));
201+
const auto scratch_offset = align_up(unique_temp_storage_size, alignof(std::int64_t));
203202
auto * scratch = reinterpret_cast<char *>(workspace) + scratch_offset;
204203

205204
auto * input_positions = reinterpret_cast<std::int64_t *>(scratch);
@@ -217,7 +216,7 @@ cudaError_t unique(
217216
}
218217

219218
status = cub::DeviceRadixSort::SortPairs(
220-
workspace, temp_storage_size, input, sorted_input, input_positions, unique_offsets,
219+
workspace, unique_temp_storage_size, input, sorted_input, input_positions, unique_offsets,
221220
num_input_elements, 0, 64, stream);
222221
if (status != cudaSuccess) {
223222
return status;
@@ -231,7 +230,7 @@ cudaError_t unique(
231230
}
232231

233232
status = cub::DeviceScan::InclusiveSum(
234-
workspace, temp_storage_size, run_ids, run_ids, num_input_elements, stream);
233+
workspace, unique_temp_storage_size, run_ids, run_ids, num_input_elements, stream);
235234
if (status != cudaSuccess) {
236235
return status;
237236
}
@@ -244,14 +243,13 @@ cudaError_t unique(
244243
}
245244

246245
status = cub::DeviceSelect::UniqueByKey(
247-
workspace, temp_storage_size, sorted_input, input_positions, unique, unique_offsets,
246+
workspace, unique_temp_storage_size, sorted_input, input_positions, unique, unique_offsets,
248247
num_unique, num_input_elements, stream);
249248
if (status != cudaSuccess) {
250249
return status;
251250
}
252251

253-
write_unique_offset_sentinel<<<1, 1, 0, stream>>>(
254-
unique_offsets, num_unique, num_input_elements);
252+
write_unique_offset_sentinel<<<1, 1, 0, stream>>>(unique_offsets, num_unique, num_input_elements);
255253
status = cudaGetLastError();
256254
if (status != cudaSuccess) {
257255
return status;

perception/autoware_tensorrt_plugins/src/unique_plugin.cpp

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <NvInferRuntime.h>
2121
#include <NvInferRuntimePlugin.h>
2222

23+
#include <algorithm>
2324
#include <cstdint>
2425
#include <exception>
2526
#include <string>
@@ -28,14 +29,51 @@
2829
namespace nvinfer1::plugin
2930
{
3031

32+
namespace
33+
{
34+
35+
std::size_t alignUp(const std::size_t size, const std::size_t alignment)
36+
{
37+
return ((size + alignment - 1U) / alignment) * alignment;
38+
}
39+
40+
std::size_t getTotalWorkspaceSize(
41+
const std::size_t temp_storage_size, const std::size_t max_num_elements)
42+
{
43+
return alignUp(temp_storage_size, alignof(std::int64_t)) +
44+
(3U * max_num_elements + 1U) * sizeof(std::int64_t) +
45+
max_num_elements * sizeof(std::int32_t);
46+
}
47+
48+
} // namespace
49+
3150
UniquePlugin::UniquePlugin(const std::string & name) noexcept : layer_name_(name)
3251
{
3352
initFieldsToSerialize();
3453
}
3554

55+
UniquePlugin::UniquePlugin(const std::string & name, const std::int64_t max_num_elements) noexcept
56+
: layer_name_(name)
57+
{
58+
updateMaxNumElements(max_num_elements);
59+
initFieldsToSerialize();
60+
}
61+
62+
void UniquePlugin::updateMaxNumElements(const std::int64_t max_num_elements)
63+
{
64+
PLUGIN_ASSERT(max_num_elements >= 0);
65+
66+
max_num_elements_ = std::max(max_num_elements_, max_num_elements);
67+
max_temp_storage_size_ = std::max(
68+
max_temp_storage_size_,
69+
get_unique_temp_storage_size(static_cast<std::size_t>(max_num_elements_)));
70+
}
71+
3672
void UniquePlugin::initFieldsToSerialize()
3773
{
3874
data_to_serialize_.clear();
75+
data_to_serialize_.emplace_back(
76+
"max_num_elements", &max_num_elements_, PluginFieldType::kINT64, 1);
3977
fc_to_serialize_.nbFields = data_to_serialize_.size();
4078
fc_to_serialize_.fields = data_to_serialize_.data();
4179
}
@@ -60,7 +98,7 @@ IPluginCapability * UniquePlugin::getCapabilityInterface(PluginCapabilityType ty
6098
IPluginV3 * UniquePlugin::clone() noexcept
6199
{
62100
try {
63-
IPluginV3 * const plugin{new UniquePlugin{layer_name_}};
101+
IPluginV3 * const plugin{new UniquePlugin{layer_name_, max_num_elements_}};
64102
return plugin;
65103
} catch (std::exception const & e) {
66104
caughtError(e);
@@ -105,6 +143,7 @@ std::int32_t UniquePlugin::configurePlugin(
105143
PLUGIN_ASSERT(out[1].desc.type == in[0].desc.type);
106144
PLUGIN_ASSERT(out[2].desc.type == in[0].desc.type);
107145
PLUGIN_ASSERT(out[3].desc.type == in[0].desc.type);
146+
updateMaxNumElements(in[0].max.d[0]);
108147

109148
return 0;
110149
}
@@ -164,11 +203,16 @@ std::int32_t UniquePlugin::enqueue(
164203
cudaStream_t stream) noexcept
165204
{
166205
std::int64_t num_elements = input_desc[0].dims.d[0];
167-
const auto workspace_size = get_unique_workspace_size(static_cast<std::size_t>(num_elements));
206+
PLUGIN_ASSERT(num_elements <= max_num_elements_);
207+
const auto temp_storage_size =
208+
max_temp_storage_size_ == 0U
209+
? get_unique_temp_storage_size(static_cast<std::size_t>(num_elements))
210+
: max_temp_storage_size_;
168211
return unique(
169212
reinterpret_cast<const std::int64_t *>(inputs[0]), reinterpret_cast<std::int64_t *>(outputs[0]),
170213
reinterpret_cast<std::int64_t *>(outputs[1]), reinterpret_cast<std::int64_t *>(outputs[2]),
171-
reinterpret_cast<std::int64_t *>(outputs[3]), workspace, num_elements, workspace_size, stream);
214+
reinterpret_cast<std::int64_t *>(outputs[3]), workspace, num_elements, temp_storage_size,
215+
stream);
172216
}
173217

174218
std::int32_t UniquePlugin::onShapeChange(
@@ -194,7 +238,13 @@ std::size_t UniquePlugin::getWorkspaceSize(
194238
[[maybe_unused]] DynamicPluginTensorDesc const * outputs,
195239
[[maybe_unused]] std::int32_t num_outputs) const noexcept
196240
{
197-
return get_unique_workspace_size(inputs[0].max.d[0]);
241+
const auto max_num_elements =
242+
std::max(max_num_elements_, static_cast<std::int64_t>(inputs[0].max.d[0]));
243+
const auto temp_storage_size =
244+
max_temp_storage_size_ == 0U
245+
? get_unique_temp_storage_size(static_cast<std::size_t>(max_num_elements))
246+
: max_temp_storage_size_;
247+
return getTotalWorkspaceSize(temp_storage_size, static_cast<std::size_t>(max_num_elements));
198248
}
199249

200250
} // namespace nvinfer1::plugin

0 commit comments

Comments
 (0)