#include <cub/cub.cuh>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>

107109namespace
// Rounds `size` up to the next multiple of `alignment`.
// Precondition: alignment != 0 (callers in this file pass alignof(std::int64_t)).
// Works for any non-zero alignment, not only powers of two.
std::size_t align_up(const std::size_t size, const std::size_t alignment)
{
  return ((size + alignment - 1U) / alignment) * alignment;
}
116118
// Pointer map over the single caller-provided workspace buffer used by `unique`.
struct UniqueWorkspaceLayout
{
  void* cub_temp_storage;
  std::size_t cub_temp_storage_size;

  // One int64 scratch block, then one int32 scratch block:
  //
  // workspace
  // +-----------------------------+ 0
  // | CUB temp storage            | cub_temp_storage_size bytes
  // +-----------------------------+ align_up(cub_temp_storage_size, alignof(int64))
  // | input_positions             | num_input_elements int64
  // | sorted_input                | num_input_elements int64
  // | sorted_input_positions      | num_input_elements + 1 int64
  // | num_unique                  | 1 int32
  // | run_ids                     | num_input_elements int32
  // +-----------------------------+
  //
  // `sorted_input_positions` is reused as `unique_offsets_inout` after inverse indices are
  // scattered. Its extra slot stores a sentinel end offset at index `num_input_elements`.
  // Using the fixed last slot keeps the sentinel outside CUB's compact `[0, num_unique)`
  // output range and separate from the `num_unique` scalar.
  std::int64_t* input_positions;
  std::int64_t* sorted_input;
  std::int64_t* sorted_input_positions;
  std::int32_t* num_unique;
  std::int32_t* run_ids;
};
148+
// Returns the CUB temp-storage byte requirement shared by the three device passes that
// `unique` runs (radix sort, inclusive scan, unique-by-key): the max of the individual
// size queries, so one buffer can serve all passes sequentially.
std::size_t query_unique_temp_storage_size(const std::size_t num_elements_in)
{
  std::size_t sort_temp_size = 0;
  std::size_t scan_temp_size = 0;
  std::size_t unique_temp_size = 0;
  // Typed null pointers so the CUB size-query overloads resolve; nothing is dereferenced
  // while the temp-storage argument is nullptr.
  // NOTE(review): these three declarations sit inside a fold of the source diff and were
  // reconstructed from the calls below — confirm against the original file.
  std::int64_t* const int64_nullptr = nullptr;
  std::int32_t* const int32_nullptr = nullptr;

  cub::DeviceRadixSort::SortPairs(
    nullptr, sort_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
    num_elements_in, 0, 64, nullptr);
  cub::DeviceScan::InclusiveSum(
    nullptr, scan_temp_size, int32_nullptr, int32_nullptr, num_elements_in, nullptr);
  // Selected-count output is int32 to match `UniqueWorkspaceLayout::num_unique`.
  cub::DeviceSelect::UniqueByKey(
    nullptr, unique_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
    int32_nullptr, num_elements_in, nullptr);

  return std::max(sort_temp_size, std::max(scan_temp_size, unique_temp_size));
}
137169
// Carves `workspace_inout` into the regions described by `UniqueWorkspaceLayout`.
// CUB temp storage occupies offset 0; the int64 scratch region starts at the next
// alignof(std::int64_t) boundary, followed immediately by the int32 region.
UniqueWorkspaceLayout make_unique_workspace_layout(
  void* workspace_inout, const std::size_t num_input_elements_in,
  const std::size_t cub_temp_storage_size_in)
{
  const auto scratch_offset = align_up(cub_temp_storage_size_in, alignof(std::int64_t));
  auto* scratch = reinterpret_cast<char*>(workspace_inout) + scratch_offset;

  auto* input_positions = reinterpret_cast<std::int64_t*>(scratch);
  auto* sorted_input = input_positions + num_input_elements_in;
  // `sorted_input_positions` carries one extra slot for the end-offset sentinel.
  auto* sorted_input_positions = sorted_input + num_input_elements_in;
  auto* num_unique =
    reinterpret_cast<std::int32_t*>(sorted_input_positions + num_input_elements_in + 1U);
  // Plain pointer arithmetic suffices here: `num_unique` is already std::int32_t*.
  auto* run_ids = num_unique + 1;

  return UniqueWorkspaceLayout{workspace_inout, cub_temp_storage_size_in, input_positions,
                               sorted_input,    sorted_input_positions,   num_unique,
                               run_ids};
}
188+
// Writes run_ids_out[i] = 1 where a new run of equal keys begins in `sorted_input_in`
// (i == 0 or the key differs from its predecessor), else 0.
// Launch: 1D grid covering num_input_elements_in threads; excess threads exit at the guard.
__global__ void mark_run_starts(
  const std::int64_t* sorted_input_in, std::int32_t* run_ids_out,
  const std::size_t num_input_elements_in)
{
  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index >= num_input_elements_in) {
    return;
  }

  const bool is_run_start =
    index == 0U || sorted_input_in[index] != sorted_input_in[index - 1U];
  run_ids_out[index] = is_run_start ? 1 : 0;
}
148201
// Fills output_out[i] = i for i in [0, num_input_elements_in).
// Launch: 1D grid covering num_input_elements_in threads; excess threads exit at the guard.
__global__ void fill_iota(std::int64_t* output_out, const std::size_t num_input_elements_in)
{
  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index >= num_input_elements_in) {
    return;
  }

  output_out[index] = static_cast<std::int64_t>(index);
}
158211
// Scatters unique-group ids back to original input order:
// `run_ids_in[i]` is the 1-based inclusive-scan run id of sorted position i, and
// `sorted_idx_in[i]` is that element's original input position, so each input slot
// receives its 0-based index into the unique-values array.
// Launch: 1D grid covering num_input_elements_in threads; excess threads exit at the guard.
__global__ void scatter_inverse_indices(
  const std::int64_t* sorted_idx_in, const std::int32_t* run_ids_in,
  std::int64_t* inverse_indices_out, const std::size_t num_input_elements_in)
{
  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index >= num_input_elements_in) {
    return;
  }

  inverse_indices_out[sorted_idx_in[index]] = static_cast<std::int64_t>(run_ids_in[index] - 1);
}
170223
// Single-thread kernel (<<<1, 1>>>): stores the end-offset sentinel — the input length —
// in the fixed last slot (index num_input_elements_in) of the offsets buffer, past any
// offsets CUB compacted into [0, num_unique).
__global__ void write_unique_offset_sentinel(
  std::int64_t* unique_offsets_inout, const std::size_t num_input_elements_in)
{
  unique_offsets_inout[num_input_elements_in] =
    static_cast<std::int64_t>(num_input_elements_in);
}
177229
// unique_counts_out[i] = start offset of run i+1 minus start offset of run i.
// The final run has no successor among the compacted offsets, so it reads its end
// boundary from the sentinel slot at index num_input_elements_in instead.
// Launch: 1D grid covering at least *num_unique_in threads; excess threads exit early.
__global__ void write_unique_counts(
  const std::int64_t* unique_offsets_in, const std::int32_t* num_unique_in,
  std::int64_t* unique_counts_out, const std::size_t num_input_elements_in)
{
  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const auto num_unique = static_cast<std::size_t>(*num_unique_in);
  if (index >= num_unique) {
    return;
  }

  const auto next_offset_index =
    (index + 1U == num_unique) ? num_input_elements_in : index + 1U;
  unique_counts_out[index] = unique_offsets_in[next_offset_index] - unique_offsets_in[index];
}
189243
190244} // namespace
191245
// Deduplicates the int64 keys in `input_in` on `stream_in`.
// Outputs (all device pointers):
//   unique_values_out   - sorted unique values (first num-unique slots written)
//   inverse_indices_out - per input element, its index into unique_values_out
//   unique_counts_out   - per unique value, its occurrence count
// `workspace_inout` must hold at least get_unique_workspace_size(num_input_elements_in)
// bytes (asserted below). Synchronizes `stream_in` before returning the unique count.
std::int64_t unique(
  const std::int64_t* input_in, std::int64_t* unique_values_out,
  std::int64_t* inverse_indices_out, std::int64_t* unique_counts_out, void* workspace_inout,
  std::size_t num_input_elements_in, std::size_t workspace_size_in, cudaStream_t stream_in)
{
  if (num_input_elements_in == 0U) {
    return 0;
  }

  assert(workspace_size_in >= get_unique_workspace_size(num_input_elements_in));
  (void)workspace_size_in;  // only consumed by the assert in release builds

  const auto cub_temp_storage_size = get_unique_temp_storage_size(num_input_elements_in);
  auto layout =
    make_unique_workspace_layout(workspace_inout, num_input_elements_in, cub_temp_storage_size);

  const auto num_blocks =
    static_cast<unsigned int>((num_input_elements_in + kThreadsPerBlock - 1U) / kThreadsPerBlock);

  // 1. Sort values while carrying their original input positions.
  fill_iota<<<num_blocks, kThreadsPerBlock, 0, stream_in>>>(
    layout.input_positions, num_input_elements_in);
  cub::DeviceRadixSort::SortPairs(
    layout.cub_temp_storage, layout.cub_temp_storage_size, input_in, layout.sorted_input,
    layout.input_positions, layout.sorted_input_positions, num_input_elements_in, 0, 64,
    stream_in);

  // 2. Convert sorted run starts into sorted-position -> unique-index ids.
  mark_run_starts<<<num_blocks, kThreadsPerBlock, 0, stream_in>>>(
    layout.sorted_input, layout.run_ids, num_input_elements_in);
  cub::DeviceScan::InclusiveSum(
    layout.cub_temp_storage, layout.cub_temp_storage_size, layout.run_ids, layout.run_ids,
    num_input_elements_in, stream_in);

  // 3. Scatter unique ids back to original input order.
  scatter_inverse_indices<<<num_blocks, kThreadsPerBlock, 0, stream_in>>>(
    layout.sorted_input_positions, layout.run_ids, inverse_indices_out, num_input_elements_in);

  // 4. Compact sorted runs into unique values and each run's start offset.
  //    `input_positions` still holds the iota from step 1 (the sort reads it out-of-place),
  //    so the "value" UniqueByKey selects per run is that run's start offset in sorted order.
  //    `sorted_input_positions` was consumed by step 3 and is reused as the offsets output.
  auto* unique_offsets_inout = layout.sorted_input_positions;
  cub::DeviceSelect::UniqueByKey(
    layout.cub_temp_storage, layout.cub_temp_storage_size, layout.sorted_input,
    layout.input_positions, unique_values_out, unique_offsets_inout, layout.num_unique,
    num_input_elements_in, stream_in);

  // 5. Turn run start offsets into counts.
  //    CUB writes each run's start offset, but not the final end offset. Store the sentinel
  //    at the fixed extra slot so the final count can use the input length as its end
  //    boundary without overwriting compact offsets or the selected-count scalar.
  write_unique_offset_sentinel<<<1, 1, 0, stream_in>>>(
    unique_offsets_inout, num_input_elements_in);
  write_unique_counts<<<num_blocks, kThreadsPerBlock, 0, stream_in>>>(
    unique_offsets_inout, layout.num_unique, unique_counts_out, num_input_elements_in);

  // Blocking read of the selected count; the stream sync also makes all outputs visible.
  // NOTE(review): CUDA API and kernel-launch status codes are unchecked throughout this
  // function — confirm against the project's error-handling convention.
  std::int32_t num_out = 0;
  cudaMemcpyAsync(
    &num_out, layout.num_unique, sizeof(std::int32_t), cudaMemcpyDeviceToHost, stream_in);
  cudaStreamSynchronize(stream_in);
  return static_cast<std::int64_t>(num_out);
}
241304
// Public wrapper: CUB temp-storage bytes `unique` needs for `num_elements_in` inputs.
std::size_t get_unique_temp_storage_size(std::size_t num_elements_in)
{
  return query_unique_temp_storage_size(num_elements_in);
}
246309
// Total workspace bytes `unique` needs: CUB temp storage padded to int64 alignment, then
// 3 * n + 1 int64 slots (input_positions, sorted_input, sorted_input_positions plus its
// sentinel slot) and n + 1 int32 slots (the num_unique scalar plus run_ids) — matching
// the regions carved out by make_unique_workspace_layout.
std::size_t get_unique_workspace_size(std::size_t num_elements_in)
{
  const auto temp_size = query_unique_temp_storage_size(num_elements_in);
  const auto scratch_offset = align_up(temp_size, alignof(std::int64_t));
  return scratch_offset + (3U * num_elements_in + 1U) * sizeof(std::int64_t) +
         (num_elements_in + 1U) * sizeof(std::int32_t);
}
0 commit comments