@@ -82,18 +82,20 @@ __global__ void reconstructPartialKernel(
8282 float * __restrict__ output_probs, std::size_t num_classes, std::size_t num_cropped_points,
8383 std::size_t num_voxels)
8484{
85- const auto idx = static_cast <std:: uint32_t >( blockIdx .x * blockDim .x + threadIdx .x ) ;
86- if (idx >= num_cropped_points) {
87- return ;
88- }
85+ const auto point_idx = blockIdx .x * blockDim .y + threadIdx .y ;
86+ const auto class_idx = blockIdx . y * blockDim . x + threadIdx . x ;
87+
88+ if (point_idx >= num_cropped_points || class_idx >= num_classes) return ;
8989
90- const auto voxel_idx = inverse_map[idx ];
90+ const auto voxel_idx = inverse_map[point_idx ];
9191 const bool has_valid_voxel = voxel_idx >= 0 && static_cast <std::size_t >(voxel_idx) < num_voxels;
92- output_labels[idx] = has_valid_voxel ? voxel_labels[voxel_idx] : 255 ;
93- for (std::size_t class_idx = 0 ; class_idx < num_classes; ++class_idx) {
94- output_probs[idx * num_classes + class_idx] =
95- has_valid_voxel ? voxel_probs[voxel_idx * num_classes + class_idx] : 0 .0f ;
92+ if (class_idx == 0 ) {
93+ output_labels[point_idx] = has_valid_voxel ? voxel_labels[voxel_idx] : 255 ;
9694 }
95+
96+ output_probs[point_idx * num_classes + class_idx] =
97+ has_valid_voxel ? voxel_probs[static_cast <std::size_t >(voxel_idx) * num_classes + class_idx]
98+ : 0 .0f ;
9799}
98100
99101__global__ void reconstructFullKernel (
@@ -103,27 +105,30 @@ __global__ void reconstructFullKernel(
103105 float * __restrict__ output_probs, std::size_t num_classes, std::size_t num_points,
104106 std::size_t num_voxels)
105107{
106- const auto idx = static_cast <std:: uint32_t >( blockIdx .x * blockDim .x + threadIdx .x ) ;
107- if (idx >= num_points) {
108- return ;
109- }
108+ const auto point_idx = blockIdx .x * blockDim .y + threadIdx .y ;
109+ const auto class_idx = blockIdx . y * blockDim . x + threadIdx . x ;
110+
111+ if (point_idx >= num_points || class_idx >= num_classes) return ;
110112
111- if (crop_mask[idx] == 0 ) {
112- output_labels[idx] = 255 ;
113+ const auto mask = crop_mask[point_idx];
114+ if (mask == 0 ) {
115+ if (class_idx == 0 ) output_labels[point_idx] = 255 ;
113116 return ;
114117 }
115118
116- const auto cropped_idx = crop_indices[idx ] - 1 ;
119+ const auto cropped_idx = crop_indices[point_idx ] - 1 ;
117120 const auto voxel_idx = inverse_map[cropped_idx];
118121 if (voxel_idx < 0 || static_cast <std::size_t >(voxel_idx) >= num_voxels) {
119- output_labels[idx ] = 255 ;
122+ if (class_idx == 0 ) output_labels[point_idx ] = 255 ;
120123 return ;
121124 }
122125
123- output_labels[idx] = voxel_labels[voxel_idx];
124- for (std::size_t class_idx = 0 ; class_idx < num_classes; ++class_idx) {
125- output_probs[idx * num_classes + class_idx] = voxel_probs[voxel_idx * num_classes + class_idx];
126+ if (class_idx == 0 ) {
127+ output_labels[point_idx] = voxel_labels[voxel_idx];
126128 }
129+
130+ output_probs[point_idx * num_classes + class_idx] =
131+ voxel_probs[static_cast <std::size_t >(voxel_idx) * num_classes + class_idx];
127132}
128133
129134template <typename OutputPointT>
@@ -306,9 +311,10 @@ void PostprocessCuda::reconstructPartial(
306311 std::int64_t * output_labels, float * output_probs, std::size_t num_classes,
307312 std::size_t num_cropped_points, std::size_t num_voxels)
308313{
309- auto num_blocks = divup (num_cropped_points, config_.threads_per_block_ );
314+ auto block = dim3 (32 , 8 );
315+ auto grid = dim3 (divup (num_cropped_points, block.y ), divup (num_classes, block.x ));
310316
311- reconstructPartialKernel<<<num_blocks, config_.threads_per_block_ , 0 , stream_>>> (
317+ reconstructPartialKernel<<<grid, block , 0 , stream_>>> (
312318 inverse_map, voxel_labels, voxel_probs, output_labels, output_probs, num_classes,
313319 num_cropped_points, num_voxels);
314320
@@ -321,9 +327,10 @@ void PostprocessCuda::reconstructFull(
321327 std::int64_t * output_labels, float * output_probs, std::size_t num_classes,
322328 std::size_t num_points, std::size_t num_voxels)
323329{
324- auto num_blocks = divup (num_points, config_.threads_per_block_ );
330+ auto block = dim3 (32 , 8 );
331+ auto grid = dim3 (divup (num_points, block.y ), divup (num_classes, block.x ));
325332
326- reconstructFullKernel<<<num_blocks, config_.threads_per_block_ , 0 , stream_>>> (
333+ reconstructFullKernel<<<grid, block , 0 , stream_>>> (
327334 crop_mask, crop_indices, inverse_map, voxel_labels, voxel_probs, output_labels, output_probs,
328335 num_classes, num_points, num_voxels);
329336
0 commit comments