 
 #include <cub/cub.cuh>
 
-#include <thrust/adjacent_difference.h>
-#include <thrust/device_ptr.h>
-#include <thrust/execution_policy.h>
-#include <thrust/scan.h>
-#include <thrust/scatter.h>
-#include <thrust/sequence.h>
-#include <thrust/sort.h>
-#include <thrust/unique.h>
+#include <algorithm>
+#include <cstdint>
 
-std::int64_t unique(
-  const std::int64_t * input, std::int64_t * unique, std::int64_t * inverse_indices,
-  std::int64_t * unique_counts, void * workspace, std::size_t num_input_elements,
-  std::size_t unique_workspace_size, cudaStream_t stream)
+namespace
 {
-  auto policy = thrust::cuda::par.on(stream);
-
-  thrust::device_ptr<std::int64_t> idx_ptr(reinterpret_cast<std::int64_t *>(workspace));
 
-  thrust::sequence(policy, idx_ptr, idx_ptr + num_input_elements + 1, 0);
+constexpr int kThreadsPerBlock = 256;
 
-  std::int64_t * sorted_input = unique;
-  std::int64_t * sorted_idx = thrust::raw_pointer_cast(idx_ptr) + 2 * num_input_elements + 1;
-  std::int64_t * inv_loc_ptr = thrust::raw_pointer_cast(idx_ptr) + 3 * num_input_elements + 1;
+std::size_t align_up(const std::size_t size, const std::size_t alignment)
+{
+  return ((size + alignment - 1U) / alignment) * alignment;
+}
 
-  void * sort_workspace_ptr =
-    reinterpret_cast<void *>(thrust::raw_pointer_cast(idx_ptr) + 4 * num_input_elements + 1);
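+// Query the CUB temporary-storage requirements with null device pointers.
+// The three device-wide calls run sequentially on one stream, so a single
+// buffer sized to the maximum serves all of them.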
+std::size_t query_unique_temp_storage_size(const std::size_t num_elements)
+{
+  std::size_t sort_temp_size = 0;
+  std::size_t scan_temp_size = 0;
+  std::size_t unique_temp_size = 0;
 
-  auto sort_workspace_size =
-    unique_workspace_size - (4 * num_input_elements + 1) * sizeof(std::int64_t);
+  std::int64_t * int64_nullptr = nullptr;
+  std::int32_t * int32_nullptr = nullptr;
 
   cub::DeviceRadixSort::SortPairs(
-    sort_workspace_ptr, sort_workspace_size, input, sorted_input, thrust::raw_pointer_cast(idx_ptr),
-    sorted_idx, num_input_elements, 0, 64, stream);
-
-  auto equal = [] __device__(const std::int64_t a, const std::int64_t b) { return a == b; };
+    nullptr, sort_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
+    num_elements, 0, 64, nullptr);
+  cub::DeviceScan::InclusiveSum(
+    nullptr, scan_temp_size, int32_nullptr, int32_nullptr, num_elements, nullptr);
+  cub::DeviceSelect::UniqueByKey(
+    nullptr, unique_temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr,
+    int64_nullptr, num_elements, nullptr);
+
+  return std::max(sort_temp_size, std::max(scan_temp_size, unique_temp_size));
+}
 
-  auto not_equal = [] __device__(const std::int64_t a, const std::int64_t b) { return a != b; };
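+// Writes 1 where a new run of equal keys begins in the sorted input and 0
+// elsewhere; an inclusive scan over these flags then yields 1-based run ids.
+// Run ids are 32-bit, so this assumes the run count fits in an int32.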
+__global__ void mark_run_starts(
+  const std::int64_t * sorted_input, std::int32_t * run_ids, const std::size_t num_input_elements)
+{
+  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (index >= num_input_elements) {
+    return;
+  }
 
-  thrust::adjacent_difference(
-    policy, sorted_input, sorted_input + num_input_elements, inv_loc_ptr, not_equal);
+  run_ids[index] = (index == 0U || sorted_input[index] != sorted_input[index - 1U]) ? 1 : 0;
+}
 
-  cudaMemsetAsync(inv_loc_ptr, 0, sizeof(int64_t), stream);
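+// Fills output with the sequence 0, 1, ..., num_input_elements - 1.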
+__global__ void fill_iota(std::int64_t * output, const std::size_t num_input_elements)
+{
+  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (index >= num_input_elements) {
+    return;
+  }
 
-  thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_input_elements, inv_loc_ptr);
-  thrust::scatter(
-    policy, inv_loc_ptr, inv_loc_ptr + num_input_elements, sorted_idx, inverse_indices);
+  output[index] = static_cast<std::int64_t>(index);
+}
 
-  std::int64_t num_out;
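+// The element at sorted position `index` originally lived at sorted_idx[index];
+// run_ids[index] - 1 is its 0-based position among the unique values.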
+__global__ void scatter_inverse_indices(
+  const std::int64_t * sorted_idx, const std::int32_t * run_ids, std::int64_t * inverse_indices,
+  const std::size_t num_input_elements)
+{
+  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (index >= num_input_elements) {
+    return;
+  }
 
-  std::int64_t * range_ptr = idx_ptr.get();
-  num_out =
-    thrust::unique_by_key(policy, sorted_input, sorted_input + num_input_elements, range_ptr, equal)
-      .first -
-    sorted_input;
+  inverse_indices[sorted_idx[index]] = static_cast<std::int64_t>(run_ids[index] - 1);
+}
 
-  cudaMemcpyAsync(
-    range_ptr + num_out * sizeof(int64_t), &num_input_elements, sizeof(std::int64_t),
-    cudaMemcpyHostToDevice, stream);
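+// Appends num_input_elements after the last run-start offset so that counts
+// can be computed as differences of adjacent offsets.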
+__global__ void write_unique_offset_sentinel(
+  std::int64_t * unique_offsets, const std::int64_t * num_unique,
+  const std::size_t num_input_elements)
+{
+  unique_offsets[*num_unique] = static_cast<std::int64_t>(num_input_elements);
+}
 
-  thrust::adjacent_difference(policy, range_ptr + 1, range_ptr + num_out + 1, unique_counts);
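+// unique_counts[i] = number of occurrences of the i-th unique value.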
+__global__ void write_unique_counts(
+  const std::int64_t * unique_offsets, const std::int64_t * num_unique, std::int64_t * unique_counts)
+{
+  const auto index = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (index >= static_cast<std::size_t>(*num_unique)) {
+    return;
+  }
 
-  return num_out;
+  unique_counts[index] = unique_offsets[index + 1U] - unique_offsets[index];
 }
 
-std::size_t get_unique_workspace_size(std::size_t num_elements)
+}  // namespace
+
+std::int64_t unique(
+  const std::int64_t * input, std::int64_t * unique, std::int64_t * inverse_indices,
+  std::int64_t * unique_counts, void * workspace, std::size_t num_input_elements,
+  std::size_t unique_workspace_size, cudaStream_t stream)
 {
-  std::size_t temp_size = 0;
-  std::int64_t * int64_nullptr = nullptr;
+  if (num_input_elements == 0U) {
+    return 0;
+  }
 
+  // CUB takes temp_storage_bytes by non-const reference, so this must stay
+  // mutable; call the internal helper directly rather than the public wrapper.
+  auto temp_storage_size = query_unique_temp_storage_size(num_input_elements);
+  const auto scratch_offset = align_up(temp_storage_size, alignof(std::int64_t));
+  auto * scratch = reinterpret_cast<char *>(workspace) + scratch_offset;
+
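+  // Scratch layout after the CUB temp storage: input_positions (n int64),
+  // sorted_input (n int64), unique_offsets (n + 1 int64), num_unique_d
+  // (1 int64), run_ids (n int32).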
+  auto * input_positions = reinterpret_cast<std::int64_t *>(scratch);
+  auto * sorted_input = input_positions + num_input_elements;
+  auto * unique_offsets = sorted_input + num_input_elements;
+  auto * num_unique_d = unique_offsets + num_input_elements + 1U;
+  auto * run_ids = reinterpret_cast<std::int32_t *>(num_unique_d + 1U);
+
+  const auto num_blocks =
+    static_cast<unsigned int>((num_input_elements + kThreadsPerBlock - 1U) / kThreadsPerBlock);
+
+  fill_iota<<<num_blocks, kThreadsPerBlock, 0, stream>>>(input_positions, num_input_elements);
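+  // Sort (value, original position) pairs by value. unique_offsets doubles as
+  // scratch for the sorted positions until UniqueByKey overwrites it below.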
   cub::DeviceRadixSort::SortPairs(
-    nullptr, temp_size, int64_nullptr, int64_nullptr, int64_nullptr, int64_nullptr, num_elements, 0,
-    64, nullptr);
+    workspace, temp_storage_size, input, sorted_input, input_positions, unique_offsets,
+    num_input_elements, 0, 64, stream);
+
+  mark_run_starts<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
+    sorted_input, run_ids, num_input_elements);
+  cub::DeviceScan::InclusiveSum(
+    workspace, temp_storage_size, run_ids, run_ids, num_input_elements, stream);
+
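+  // unique_offsets still holds the sorted-order permutation written by
+  // SortPairs, so it is passed as the sorted_idx argument here.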
+  scatter_inverse_indices<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
+    unique_offsets, run_ids, inverse_indices, num_input_elements);
 
-  return temp_size + (4 * num_elements + 1) * sizeof(std::int64_t);
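+  // Keys are the sorted values and values are the iota positions, so the
+  // selected value for each run is the offset at which that run starts.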
+  cub::DeviceSelect::UniqueByKey(
+    workspace, temp_storage_size, sorted_input, input_positions, unique, unique_offsets,
+    num_unique_d, num_input_elements, stream);
+
+  write_unique_offset_sentinel<<<1, 1, 0, stream>>>(
+    unique_offsets, num_unique_d, num_input_elements);
+  write_unique_counts<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
+    unique_offsets, num_unique_d, unique_counts);
+
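+  // The unique count lives in device memory; copy it back and synchronize so
+  // the caller can use it immediately.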
+  std::int64_t num_out = 0;
+  cudaMemcpyAsync(&num_out, num_unique_d, sizeof(std::int64_t), cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+  return num_out;
+}
+
+std::size_t get_unique_temp_storage_size(std::size_t num_elements)
+{
+  return query_unique_temp_storage_size(num_elements);
+}
+
+std::size_t get_unique_workspace_size(std::size_t num_elements)
+{
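+  // CUB temp storage comes first, then the int64 scratch arrays (3 * n + 2
+  // elements) followed by the int32 run ids, matching the layout in unique().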
+  const auto temp_size = query_unique_temp_storage_size(num_elements);
+  const auto scratch_offset = align_up(temp_size, alignof(std::int64_t));
+  return scratch_offset + (3 * num_elements + 2U) * sizeof(std::int64_t) +
+         num_elements * sizeof(std::int32_t);
 }