 #include "autoware/scatter_ops/segment_csr.h"
 #include "autoware/scatter_ops/utils.cuh"

-#include <numeric>
 #include <string>
 #include <tuple>
-#include <vector>

 #define THREADS 256
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
 #define FULL_MASK 0xffffffff
-#define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, R)                                              \
-  template int32_t segment_csr_launch<T, R>(                                                   \
-    const T * src, const std::vector<int32_t> & src_size, const int64_t * indptr,              \
-    const std::vector<int32_t> & indptr_size, std::tuple<T *, int64_t *> out,                  \
-    cudaStream_t stream);                                                                      \
-  template int32_t segment_csr_launch<T, R>(                                                   \
-    const T * src, const std::vector<int32_t> & src_size, const int64_t * indptr,              \
-    const std::vector<int32_t> & indptr_size, const T * base, std::tuple<T *, int64_t *> out,  \
-    cudaStream_t stream);
+#define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, R)                                \
+  template int32_t segment_csr_launch<T, R>(                                     \
+    const T * src, int32_t num_rows, int32_t num_cols, const int64_t * indptr,   \
+    int32_t indptr_size, std::tuple<T *, int64_t *> out, cudaStream_t stream);
 #define SEGMENT_CSR_LAUNCH_INSTANTIATION(T)                      \
   SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::SUM)     \
   SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::MEAN)    \

 template <typename scalar_t, ReductionType REDUCE, int TB>
 __global__ void segment_csr_kernel(
-  const scalar_t * src, const int64_t * indptr, const int64_t * indptr_size, int32_t indptr_dim,
-  scalar_t * out, int64_t * arg_out, size_t N, size_t E)
+  const scalar_t * src, const int64_t * indptr, scalar_t * out, int64_t * arg_out,
+  size_t num_segments)
 {
   // Each warp processes exactly `32/TB` rows and aggregates all row values
   // via a parallel reduction.

   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);
-  if (row_idx >= N) return;
+  if (row_idx >= num_segments) return;

-  int offset = indptr_to_offset(indptr_size, indptr_dim, row_idx);
-  int64_t row_start = __ldg(indptr + offset);
-  int64_t row_end = __ldg(indptr + offset + 1);
+  int64_t row_start = __ldg(indptr + row_idx);
+  int64_t row_end = __ldg(indptr + row_idx + 1);

   scalar_t val = Reducer<scalar_t, REDUCE>::init();
-  int64_t arg, arg_tmp;
+  int64_t arg{0}, arg_tmp{0};

-  offset = (row_idx / (indptr_size[indptr_dim - 1] - 1)) * E;
   for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB)
-    Reducer<scalar_t, REDUCE>::update(&val, src[offset + src_idx], &arg, src_idx);
+    Reducer<scalar_t, REDUCE>::update(&val, src[src_idx], &arg, src_idx);

 #pragma unroll
   for (int i = TB / 2; i > 0; i /= 2) {
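The `#pragma unroll` loop that begins here is the per-segment warp reduction: each of the TB lanes holds a partial result, and successive shuffle steps halve the stride until one lane holds the reduced value. For reference, a self-contained sketch of that pattern for a plain sum (illustrative only, not the file's code; the kernel itself goes through the `Reducer` helper and carries the `arg`/`arg_tmp` pair for argmin/argmax):

```cpp
// Illustrative warp-shuffle tree reduction: TB lane-local partials are folded
// so that lane 0 of each TB-lane group ends up with the group's sum.
template <int TB>
__device__ float warp_segment_sum(float val)
{
#pragma unroll
  for (int i = TB / 2; i > 0; i /= 2)
    // width = TB confines the shuffle to the current TB-lane group.
    val += __shfl_down_sync(0xffffffff, val, i, TB);
  return val;
}
```

Passing `width = TB` to `__shfl_down_sync` keeps each exchange inside its TB-lane sub-group, which matches the `lane_idx = thread_idx & (TB - 1)` indexing above.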
@@ -90,28 +81,26 @@ __global__ void segment_csr_kernel(

 template <typename scalar_t, ReductionType REDUCE>
 __global__ void segment_csr_broadcast_kernel(
-  const scalar_t * src, const int64_t * indptr, const int64_t * indptr_size, int32_t indptr_dim,
-  scalar_t * out, int64_t * arg_out, size_t N, size_t K, size_t E)
+  const scalar_t * src, const int64_t * indptr, scalar_t * out, int64_t * arg_out,
+  size_t num_segments, size_t num_cols)
 {
   // Each thread processes exactly one row. It turned out that this is more
   // efficient than using shared memory due to avoiding synchronization
   // barriers.

   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int row_idx = thread_idx / K;
-  int lane_idx = thread_idx % K;
-  if (thread_idx >= N * K) return;
+  int row_idx = thread_idx / num_cols;
+  int lane_idx = thread_idx % num_cols;
+  if (thread_idx >= num_segments * num_cols) return;

-  int offset = indptr_to_offset(indptr_size, indptr_dim, row_idx);
-  int64_t row_start = __ldg(indptr + offset);
-  int64_t row_end = __ldg(indptr + offset + 1);
+  int64_t row_start = __ldg(indptr + row_idx);
+  int64_t row_end = __ldg(indptr + row_idx + 1);

   scalar_t val = Reducer<scalar_t, REDUCE>::init();
-  int64_t arg;
+  int64_t arg{0};

-  offset = (row_idx / (indptr_size[indptr_dim - 1] - 1)) * E * K;
   for (int64_t src_idx = row_start; src_idx < row_end; src_idx++)
-    Reducer<scalar_t, REDUCE>::update(&val, src[offset + K * src_idx + lane_idx], &arg, src_idx);
+    Reducer<scalar_t, REDUCE>::update(&val, src[num_cols * src_idx + lane_idx], &arg, src_idx);

   if (arg_out != nullptr)
     Reducer<scalar_t, REDUCE>::write(
@@ -124,71 +113,29 @@ __global__ void segment_csr_broadcast_kernel(
 //! \todo expand index
 template <typename scalar_t, ReductionType REDUCE>
 int32_t segment_csr_launch(
-  const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
-  const std::vector<int32_t> & indptr_size, const scalar_t * base,
-  std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream)
+  const scalar_t * src, int32_t num_rows, int32_t num_cols, const int64_t * indptr,
+  int32_t indptr_size, std::tuple<scalar_t *, int64_t *> out, cudaStream_t stream)
 {
-  if (src_size.size() < indptr_size.size()) return -1;
-
-  if (!std::equal(indptr_size.begin(), indptr_size.end() - 1, src_size.begin())) return -1;
-
-  auto dim = indptr_size.size() - 1;
-
-  auto _mul = [](int a, int b) { return a * b; };
-  auto src_numel = std::accumulate(src_size.begin(), src_size.end(), 1, _mul);
-  auto indptr_numel = std::accumulate(indptr_size.begin(), indptr_size.end(), 1, _mul);
-  auto out_numel = src_numel / src_size[dim] * std::max<int32_t>(indptr_size[dim] - 1, 0);
-
-  cudaMemcpyAsync(
-    std::get<0>(out), base, sizeof(scalar_t) * out_numel, cudaMemcpyDeviceToDevice, stream);
+  if (num_rows < 0 || num_cols < 0 || indptr_size < 0) return -1;

+  auto num_segments = std::max<int32_t>(indptr_size - 1, 0);
+  auto out_numel = static_cast<size_t>(num_segments) * static_cast<size_t>(num_cols);
   if ((REDUCE == ReductionType::MIN || REDUCE == ReductionType::MAX) && std::get<1>(out) != nullptr)
     fill_kernel<int64_t>
-      <<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(std::get<1>(out), out_numel, src_size[dim]);
-
-  if (src_numel == 0) return 0;
-
-  auto N = max(indptr_size[dim] - 1, 0) * (indptr_numel / indptr_size[dim]);
-  auto K = out_numel / N;
-  auto E = src_size[dim];
-  int64_t * indptr_size_dev;
-  cudaMallocAsync(&indptr_size_dev, sizeof(int64_t) * indptr_size.size(), stream);
-  cudaMemcpyAsync(
-    indptr_size_dev, indptr_size.data(), sizeof(int64_t) * indptr_size.size(),
-    cudaMemcpyHostToDevice, stream);
-
-  if (K == 1)
-    segment_csr_kernel<scalar_t, REDUCE, 1><<<BLOCKS(32, N), THREADS, 0, stream>>>(
-      src, indptr, indptr_size_dev, indptr_size.size(), std::get<0>(out), std::get<1>(out), N, E);
-  else
-    segment_csr_broadcast_kernel<scalar_t, REDUCE><<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
-      src, indptr, indptr_size_dev, indptr_size.size(), std::get<0>(out), std::get<1>(out), N, K,
-      E);
+      <<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(std::get<1>(out), out_numel, num_rows);

-  cudaFreeAsync(indptr_size_dev, stream);
-  return 0;
-}
+  if (num_segments == 0 || num_cols == 0) return 0;

-template <typename scalar_t, ReductionType REDUCE>
-int32_t segment_csr_launch(
-  const scalar_t * src, const std::vector<int32_t> & src_size, const int64_t * indptr,
-  const std::vector<int32_t> & indptr_size, std::tuple<scalar_t *, int64_t *> out,
-  cudaStream_t stream)
-{
-  auto dim = indptr_size.size() - 1;
-  auto src_numel =
-    std::accumulate(src_size.begin(), src_size.end(), 1, [](int a, int b) { return a * b; });
-  auto out_numel = src_numel / src_size[dim] * std::max<int32_t>(indptr_size[dim] - 1, 0);
+  fill_kernel<scalar_t><<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(
+    std::get<0>(out), out_numel, static_cast<scalar_t>(0));

-  scalar_t * base;
-  cudaMallocAsync(&base, sizeof(scalar_t) * out_numel, stream);
-  fill_kernel<scalar_t><<<BLOCKS(1, out_numel), THREADS, 0, stream>>>(base, out_numel, (scalar_t)0);
-
-  auto status =
-    segment_csr_launch<scalar_t, REDUCE>(src, src_size, indptr, indptr_size, base, out, stream);
-  if (status != 0) return status;
-
-  cudaFreeAsync(base, stream);
+  if (num_cols == 1)
+    segment_csr_kernel<scalar_t, REDUCE, 1><<<BLOCKS(32, num_segments), THREADS, 0, stream>>>(
+      src, indptr, std::get<0>(out), std::get<1>(out), num_segments);
+  else
+    segment_csr_broadcast_kernel<scalar_t, REDUCE>
+      <<<BLOCKS(1, num_segments * num_cols), THREADS, 0, stream>>>(
+        src, indptr, std::get<0>(out), std::get<1>(out), num_segments, num_cols);
   return 0;
 }

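For orientation, here is a minimal host-side sketch of how the simplified `segment_csr_launch` signature could be driven. The wrapper function and the buffer names (`d_src`, `d_indptr`, `d_out`) are hypothetical; only the template signature instantiated by `SEGMENT_CSR_LAUNCH_INSTANTIATION_TR` is taken from the diff, and a `float` instantiation is assumed to be provided via `SEGMENT_CSR_LAUNCH_INSTANTIATION`.

```cpp
// Illustrative driver for the new launch signature (names are hypothetical).
// Assumes: d_src is a dense [num_rows, num_cols] device buffer, d_indptr is a
// CSR offset array with indptr_size entries, and d_out holds
// (indptr_size - 1) * num_cols elements, all already on the device.
#include <cuda_runtime.h>
#include <tuple>

#include "autoware/scatter_ops/segment_csr.h"

int32_t segment_sum_rows(
  const float * d_src, int32_t num_rows, int32_t num_cols, const int64_t * d_indptr,
  int32_t indptr_size, float * d_out, cudaStream_t stream)
{
  // SUM needs no argmin/argmax output, so the second tuple element stays null;
  // the launch function only touches std::get<1>(out) for MIN/MAX.
  return segment_csr_launch<float, ReductionType::SUM>(
    d_src, num_rows, num_cols, d_indptr, indptr_size,
    std::make_tuple(d_out, static_cast<int64_t *>(nullptr)), stream);
}
```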