2424#include " autoware/scatter_ops/utils.cuh"
2525
2626#include < algorithm>
27- #include < numeric>
2827#include < string>
29- #include < vector>
3028
// Threads per block used for the kernel launches in this file.
#define THREADS 256
// Number of blocks needed to run TB threads for each of N items, i.e.
// ceil(TB * N / THREADS). Both arguments and the whole expansion are
// parenthesized so that expression arguments (e.g. `a + b`) and operators
// surrounding the macro call cannot change the computed value.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
// 32-bit all-ones mask. NOTE(review): presumably the participation mask for
// warp-level *_sync intrinsics; its use is not visible in this excerpt.
#define FULL_MASK 0xffffffff
34- #define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR (T, R ) \
35- template int32_t segment_csr_launch<T, R>( \
36- const T * src_in, const std::vector<int32_t > & src_size_in, const int64_t * indptr_in, \
37- const std::vector<int32_t > & indptr_size_in, T * reduced_values_out, \
38- int64_t * arg_indices_out, cudaStream_t stream_in); \
39- template int32_t segment_csr_launch<T, R>( \
40- const T * src_in, const std::vector<int32_t > & src_size_in, const int64_t * indptr_in, \
41- const std::vector<int32_t > & indptr_size_in, const T * base_values_in, T * reduced_values_out, \
42- int64_t * arg_indices_out, cudaStream_t stream_in);
32+ #define SEGMENT_CSR_LAUNCH_INSTANTIATION_TR (T, R ) \
33+ template int32_t segment_csr_launch<T, R>( \
34+ const T * src_in, int32_t num_rows_in, int32_t num_cols_in, const int64_t * indptr_in, \
35+ int32_t indptr_size_in, T * reduced_values_out, int64_t * arg_indices_out, \
36+ cudaStream_t stream_in);
4337#define SEGMENT_CSR_LAUNCH_INSTANTIATION (T ) \
4438 SEGMENT_CSR_LAUNCH_INSTANTIATION_TR (T, ReductionType::SUM ) \
4539 SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::MEAN ) \
4842 SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::MIN ) \
4943 SEGMENT_CSR_LAUNCH_INSTANTIATION_TR(T, ReductionType::MAX )
5044
51- namespace
52- {
53- size_t get_output_numel (
54- const std::vector<int32_t > & src_size, const std::vector<int32_t > & indptr_size, size_t dim)
55- {
56- size_t out_numel = static_cast <size_t >(std::max<int32_t >(indptr_size[dim] - 1 , 0 ));
57- for (size_t i = 0 ; i < src_size.size (); ++i) {
58- if (i != dim) out_numel *= static_cast <size_t >(src_size[i]);
59- }
60- return out_numel;
61- }
62- } // namespace
63-
6445template <typename scalar_t, ReductionType REDUCE, int TB>
6546__global__ void segment_csr_kernel(
66- const scalar_t * src , const int64_t * indptr, const int64_t * indptr_size, int32_t indptr_dim ,
67- scalar_t * out, int64_t * arg_out , size_t N, size_t E )
47+ const scalar_t * src_in , const int64_t * indptr_in, scalar_t * reduced_values_out ,
48+ int64_t * arg_indices_out , size_t num_segments_in )
6849{
6950 // Each warp processes exactly `32/TB` rows and aggregates all row values
7051 // via a parallel reduction.
7152
7253 int thread_idx = blockIdx .x * blockDim .x + threadIdx .x ;
7354 int row_idx = thread_idx / TB ;
7455 int lane_idx = thread_idx & (TB - 1 );
75- if (row_idx >= N ) return ;
56+ if (row_idx >= num_segments_in ) return ;
7657
77- int offset = indptr_to_offset (indptr_size, indptr_dim, row_idx);
78- int64_t row_start = __ldg (indptr + offset);
79- int64_t row_end = __ldg (indptr + offset + 1 );
58+ int64_t row_start = __ldg (indptr_in + row_idx);
59+ int64_t row_end = __ldg (indptr_in + row_idx + 1 );
8060
8161 scalar_t val = Reducer<scalar_t , REDUCE >::init ();
82- int64_t arg, arg_tmp;
62+ int64_t arg{ 0 } , arg_tmp{ 0 } ;
8363
84- offset = (row_idx / (indptr_size[indptr_dim - 1 ] - 1 )) * E;
8564 for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB )
86- Reducer<scalar_t , REDUCE >::update (&val, src[offset + src_idx], &arg, src_idx);
65+ Reducer<scalar_t , REDUCE >::update (&val, src_in[ src_idx], &arg, src_idx);
8766
8867#pragma unroll
8968 for (int i = TB / 2 ; i > 0 ; i /= 2 ) {
@@ -94,124 +73,78 @@ __global__ void segment_csr_kernel(
9473 }
9574
9675 if (lane_idx == 0 )
97- if (arg_out != nullptr )
76+ if (arg_indices_out != nullptr )
9877 Reducer<scalar_t , REDUCE >::write (
99- out + row_idx, val, arg_out + row_idx, arg, row_end - row_start);
78+ reduced_values_out + row_idx, val, arg_indices_out + row_idx, arg, row_end - row_start);
10079 else
101- Reducer<scalar_t , REDUCE >::write (out + row_idx, val, row_end - row_start);
80+ Reducer<scalar_t , REDUCE >::write (reduced_values_out + row_idx, val, row_end - row_start);
10281}
10382
/**
 * @brief Segment (CSR) reduction for a source with `num_cols_in` columns.
 *
 * Every thread produces exactly one output element. The flat thread index is
 * split into a segment index (`thread / num_cols_in`) and a column index
 * (`thread % num_cols_in`). The thread then walks the source rows
 * `[indptr_in[segment], indptr_in[segment + 1])` sequentially, reading
 * `src_in[num_cols_in * row + column]`, and folds them with
 * `Reducer<scalar_t, REDUCE>`.
 *
 * Expected launch: 1D grid of 1D blocks providing at least
 * `num_segments_in * num_cols_in` threads; surplus threads exit immediately.
 * No shared memory is used.
 *
 * @tparam scalar_t Element type of the source and output values.
 * @tparam REDUCE   Reduction selected through `Reducer<scalar_t, REDUCE>`.
 * @param src_in             Source values, indexed as `num_cols_in * row + column`.
 * @param indptr_in          Segment boundaries; entries `segment` and `segment + 1` are read.
 * @param reduced_values_out Output values, one per (segment, column), written at the flat
 *                           thread index.
 * @param arg_indices_out    Optional. When non-null, the index tracked by the reducer is
 *                           handed to `Reducer::write` for the same flat position.
 * @param num_segments_in    Number of segments to reduce.
 * @param num_cols_in        Number of columns per source row.
 */
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
  const scalar_t * src_in, const int64_t * indptr_in, scalar_t * reduced_values_out,
  int64_t * arg_indices_out, size_t num_segments_in, size_t num_cols_in)
{
  // Each thread processes exactly one row. It turned out that is more
  // efficient than using shared memory due to avoiding synchronization
  // barriers.

  // Build the flat index in 64-bit arithmetic: the element count is a size_t,
  // so a 32-bit `int` index would wrap for large outputs.
  const size_t thread_idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  // Bounds check first, so surplus threads never evaluate the division below.
  if (thread_idx >= num_segments_in * num_cols_in) {
    return;
  }

  const size_t row_idx = thread_idx / num_cols_in;   // segment handled by this thread
  const size_t lane_idx = thread_idx % num_cols_in;  // column handled by this thread

  const int64_t row_start = __ldg(indptr_in + row_idx);
  const int64_t row_end = __ldg(indptr_in + row_idx + 1);

  scalar_t val = Reducer<scalar_t, REDUCE>::init();
  int64_t arg{0};

  // Sequentially fold every source row of the segment for this column.
  for (int64_t src_idx = row_start; src_idx < row_end; ++src_idx) {
    Reducer<scalar_t, REDUCE>::update(
      &val, src_in[num_cols_in * src_idx + lane_idx], &arg, src_idx);
  }

  // The segment length (row_end - row_start) is forwarded to the reducer's write step.
  if (arg_indices_out != nullptr) {
    Reducer<scalar_t, REDUCE>::write(
      reduced_values_out + thread_idx, val, arg_indices_out + thread_idx, arg,
      row_end - row_start);
  } else {
    Reducer<scalar_t, REDUCE>::write(reduced_values_out + thread_idx, val, row_end - row_start);
  }
}
135113
136114// ! \todo test different devices (cudaSetDevice(src.get_device());)
137115// ! \todo expand index
138116template <typename scalar_t , ReductionType REDUCE >
139117int32_t segment_csr_launch (
140- const scalar_t * src_in, const std::vector< int32_t > & src_size_in , const int64_t * indptr_in,
141- const std::vector< int32_t > & indptr_size_in, const scalar_t * base_values_in ,
142- scalar_t * reduced_values_out, int64_t * arg_indices_out, cudaStream_t stream_in)
118+ const scalar_t * src_in, int32_t num_rows_in, int32_t num_cols_in , const int64_t * indptr_in,
119+ int32_t indptr_size_in, scalar_t * reduced_values_out, int64_t * arg_indices_out ,
120+ cudaStream_t stream_in)
143121{
144- if (indptr_size_in.empty () || src_size_in.size () < indptr_size_in.size ()) return -1 ;
145-
146- if (!std::equal (indptr_size_in.begin (), indptr_size_in.end () - 1 , src_size_in.begin ())) return -1 ;
122+ if (num_rows_in < 0 || num_cols_in < 0 || indptr_size_in < 0 ) return -1 ;
147123
148- auto dim = indptr_size_in.size () - 1 ;
149-
150- auto _mul = [](int a, int b) { return a * b; };
151- auto src_numel = std::accumulate (src_size_in.begin (), src_size_in.end (), 1 , _mul);
152- auto indptr_numel = std::accumulate (indptr_size_in.begin (), indptr_size_in.end (), 1 , _mul);
153- auto out_numel = get_output_numel (src_size_in, indptr_size_in, dim);
124+ auto num_segments = std::max<int32_t >(indptr_size_in - 1 , 0 );
125+ auto out_numel = static_cast <size_t >(num_segments) * static_cast <size_t >(num_cols_in);
154126
155127 if (out_numel == 0 ) return 0 ;
156128
157- cudaMemcpyAsync (
158- reduced_values_out, base_values_in, sizeof (scalar_t ) * out_numel, cudaMemcpyDeviceToDevice,
159- stream_in);
160-
161129 if ((REDUCE == ReductionType::MIN || REDUCE == ReductionType::MAX ) && arg_indices_out != nullptr )
162- fill_kernel<int64_t ><<<BLOCKS (1 , out_numel), THREADS , 0 , stream_in>>> (
163- arg_indices_out, out_numel, src_size_in[dim]);
164-
165- if (src_numel == 0 ) return 0 ;
166-
167- auto N = max (indptr_size_in[dim] - 1 , 0 ) * (indptr_numel / indptr_size_in[dim]);
168- auto K = out_numel / N;
169- auto E = src_size_in[dim];
170- int64_t * indptr_size_dev;
171- cudaMallocAsync (&indptr_size_dev, sizeof (int64_t ) * indptr_size_in.size (), stream_in);
172- cudaMemcpyAsync (
173- indptr_size_dev, indptr_size_in.data (), sizeof (int64_t ) * indptr_size_in.size (),
174- cudaMemcpyHostToDevice, stream_in);
175-
176- if (K == 1 )
177- segment_csr_kernel<scalar_t , REDUCE , 1 ><<<BLOCKS (32 , N), THREADS , 0 , stream_in>>> (
178- src_in, indptr_in, indptr_size_dev, indptr_size_in.size (), reduced_values_out,
179- arg_indices_out, N, E);
180- else
181- segment_csr_broadcast_kernel<scalar_t , REDUCE ><<<BLOCKS (1 , N * K), THREADS , 0 , stream_in>>> (
182- src_in, indptr_in, indptr_size_dev, indptr_size_in.size (), reduced_values_out,
183- arg_indices_out, N, K, E);
130+ fill_kernel<int64_t >
131+ <<<BLOCKS (1 , out_numel), THREADS , 0 , stream_in>>> (arg_indices_out, out_numel, num_rows_in);
184132
185- cudaFreeAsync (indptr_size_dev, stream_in);
186- return 0 ;
187- }
188-
189- template <typename scalar_t , ReductionType REDUCE >
190- int32_t segment_csr_launch (
191- const scalar_t * src_in, const std::vector<int32_t > & src_size_in, const int64_t * indptr_in,
192- const std::vector<int32_t > & indptr_size_in, scalar_t * reduced_values_out,
193- int64_t * arg_indices_out, cudaStream_t stream_in)
194- {
195- if (indptr_size_in.empty () || src_size_in.size () < indptr_size_in.size ()) return -1 ;
196-
197- if (!std::equal (indptr_size_in.begin (), indptr_size_in.end () - 1 , src_size_in.begin ())) return -1 ;
198-
199- auto dim = indptr_size_in.size () - 1 ;
200- auto out_numel = get_output_numel (src_size_in, indptr_size_in, dim);
201- if (out_numel == 0 ) return 0 ;
202-
203- scalar_t * base_values;
133+ scalar_t * base_values{nullptr };
204134 cudaMallocAsync (&base_values, sizeof (scalar_t ) * out_numel, stream_in);
205135 fill_kernel<scalar_t ><<<BLOCKS (1 , out_numel), THREADS , 0 , stream_in>>> (
206136 base_values, out_numel, static_cast <scalar_t >(0 ));
137+ cudaMemcpyAsync (
138+ reduced_values_out, base_values, sizeof (scalar_t ) * out_numel, cudaMemcpyDeviceToDevice,
139+ stream_in);
207140
208- auto status = segment_csr_launch< scalar_t , REDUCE >(
209- src_in, src_size_in, indptr_in, indptr_size_in, base_values, reduced_values_out,
210- arg_indices_out, stream_in );
211- if (status != 0 ) {
212- cudaFreeAsync (base_values, stream_in);
213- return status;
214- }
141+ if (num_cols_in == 1 )
142+ segment_csr_kernel< scalar_t , REDUCE , 1 > <<< BLOCKS ( 32 , num_segments), THREADS , 0 , stream_in>>> (
143+ src_in, indptr_in, reduced_values_out, arg_indices_out, num_segments );
144+ else
145+ segment_csr_broadcast_kernel< scalar_t , REDUCE >
146+ <<< BLOCKS ( 1 , num_segments * num_cols_in), THREADS , 0 , stream_in>>> (
147+ src_in, indptr_in, reduced_values_out, arg_indices_out, num_segments, num_cols_in);
215148
216149 cudaFreeAsync (base_values, stream_in);
217150 return 0 ;
0 commit comments