@@ -92,15 +92,15 @@ void ParticleSorting::sortingParticleData(size_t *begin, size_t size, execution:
9292 if (!index_sorting_device_variables_)
9393 index_sorting_device_variables_ = allocateDeviceData<size_t >(size);
9494
95- sort_by_key (begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue (), 256 , 4 , [](size_t *data, size_t idx)
96- { return idx; }).wait ();
95+ device_radix_sorting.sort_by_key (begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue (), 512 , 4 ).wait ();
9796
9897 move_sortable_particle_device_data_ (index_sorting_device_variables_, size);
9998
10099 updateSortedDeviceId ();
101100}
102101// =================================================================================================//
103- size_t split_count (bool bit, sycl::nd_item<1 > &item)
102+ template <class ValueType >
103+ SYCL_EXTERNAL size_t DeviceRadixSort<ValueType>::split_count(bool bit, sycl::nd_item<1 > &item)
104104{
105105 const auto group_range = item.get_local_range ().size ();
106106 const size_t id = item.get_local_id ();
@@ -117,14 +117,191 @@ size_t split_count(bool bit, sycl::nd_item<1> &item)
117117 return bit ? true_before - 1 + false_totals : id - true_before;
118118}
119119// =================================================================================================//
120- size_t get_digit (size_t key, size_t d, size_t radix_bits)
120+ template <class ValueType >
121+ size_t DeviceRadixSort<ValueType>::get_digit(size_t key, size_t d, size_t radix_bits)
121122{
122123 return (key >> d * radix_bits) & ((1ul << radix_bits) - 1 );
123124}
124125// =================================================================================================//
125- size_t get_bit (size_t key, size_t b)
126+ template <class ValueType >
127+ size_t DeviceRadixSort<ValueType>::get_bit(size_t key, size_t b)
126128{
127129 return (key >> b) & 1 ;
128130}
129131// =================================================================================================//
132+ template <class ValueType >
133+ size_t DeviceRadixSort<ValueType>::find_max_element(const size_t *data, size_t size, size_t identity)
134+ {
135+ size_t result = identity;
136+ auto &sycl_queue = execution::executionQueue.getQueue ();
137+ {
138+ sycl::buffer<size_t > buffer_result (&result, 1 );
139+ sycl_queue.submit ([&](sycl::handler &cgh)
140+ {
141+ auto reduction_operator = sycl::reduction (buffer_result, cgh, sycl::maximum<>());
142+ cgh.parallel_for (execution::executionQueue.getUniformNdRange (size), reduction_operator,
143+ [=](sycl::nd_item<1 > item, auto & reduction) {
144+ if (item.get_global_id () < size)
145+ reduction.combine (data[item.get_global_linear_id ()]);
146+ }); })
147+ .wait_and_throw ();
148+ }
149+ return result;
150+ }
151+ // =================================================================================================//
152+ template <class ValueType >
153+ void DeviceRadixSort<ValueType>::resize(size_t data_size, size_t radix_bits, size_t workgroup_size)
154+ {
155+ data_size_ = data_size;
156+ radix_bits_ = radix_bits;
157+ workgroup_size_ = workgroup_size;
158+ uniform_case_masking_ = data_size % workgroup_size;
159+ uniform_global_size_ = uniform_case_masking_ ? (data_size / workgroup_size + 1 ) * workgroup_size : data_size;
160+ kernel_range_ = {uniform_global_size_, workgroup_size};
161+ workgroups_ = kernel_range_.get_group_range ().size ();
162+
163+ radix_ = 1ul << radix_bits; // radix = 2^b
164+
165+ sycl::range<2 > buckets_column_major_range = {radix_, workgroups_}, buckets_row_major_range = {workgroups_, radix_};
166+ // Each entry contains global number of digits with the same value
167+ // Column-major, so buckets offsets can be computed by just applying a scan over it
168+ global_buckets_ = std::make_unique<sycl::buffer<size_t , 2 >>(buckets_column_major_range);
169+ // Each entry contains global number of digits with the same and lower values
170+ global_buckets_offsets_ = std::make_unique<sycl::buffer<size_t , 2 >>(buckets_column_major_range);
171+ local_buckets_offsets_buffer_ = std::make_unique<sycl::buffer<size_t , 2 >>(buckets_row_major_range); // save state of local accessor
172+ data_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_); // temporary memory for swapping
173+ // Keep extra values to be swapped when kernel range has been made uniform
174+ uniform_extra_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_ - data_size);
175+ }
176+ // =================================================================================================//
177+ template <class ValueType >
178+ sycl::event DeviceRadixSort<ValueType>::sort_by_key(size_t *keys, ValueType *data, size_t data_size, sycl::queue &queue, size_t workgroup_size, size_t radix_bits)
179+ {
180+ if (data_size_ != data_size || radix_bits_ != radix_bits || workgroup_size_ != workgroup_size)
181+ resize (data_size, radix_bits, workgroup_size);
182+
183+ // Largest key, increased by 1 if the workgroup is not homogeneous with the data vector,
184+ // the new maximum will be used for those work-items out of data range, that will then be excluded once sorted
185+ const size_t max_key = find_max_element (keys, data_size, 0ul ) + (uniform_case_masking_ ? 1 : 0 );
186+ const size_t bits_max_key = std::floor (std::log2 (max_key)) + 1.0 ; // bits needed to represent max_key
187+ const size_t length = max_key ? bits_max_key / radix_bits + (bits_max_key % radix_bits ? 1 : 0 ) : 1 ; // max number of radix digits
188+
189+ sycl::event sort_event{};
190+ for (int digit = 0 ; digit < length; ++digit)
191+ {
192+
193+ auto buckets_event = queue.submit ([&](sycl::handler &cgh)
194+ {
195+ cgh.depends_on (sort_event);
196+ auto data_swap_acc = data_swap_buffer_->get_access (cgh, sycl::write_only, sycl::no_init);
197+ auto local_buckets = sycl::local_accessor<size_t >(radix_, cgh);
198+ auto local_output = sycl::local_accessor<SortablePair>(kernel_range_.get_local_range (), cgh);
199+ auto global_buckets_accessor = global_buckets_->get_access (cgh, sycl::read_write, sycl::no_init);
200+ auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access (cgh, sycl::write_only,
201+ sycl::no_init);
202+
203+ cgh.parallel_for (kernel_range_, [=, radix=radix_](sycl::nd_item<1 > item) {
204+ const size_t workgroup = item.get_group_linear_id (),
205+ global_id = item.get_global_id ();
206+
207+ SortablePair number;
208+ // Initialize key-data pair, with masking in case of non-homogeneous data_size/workgroup_size
209+ if (global_id < data_size)
210+ number = {keys[global_id],
211+ // Give possibility to initialize data here to avoid calling
212+ // another kernel before sort_by_key in order to initialize it
213+ digit ? data[global_id] : global_id};
214+ else // masking extra indexes
215+ // Initialize exceeding values to the largest key considered
216+ number.first = (1 << bits_max_key) - 1 ; // max key for given number of bits
217+
218+
219+ // Locally sort digit with split primitive
220+ auto radix_digit = get_digit (number.first , digit, radix_bits);
221+ auto rank = split_count (get_bit (radix_digit, 0 ), item); // sorting first bit
222+ local_output[rank] = number;
223+ for (size_t b = 1 ; b < radix_bits; ++b) { // sorting remaining bits
224+ item.barrier (sycl::access::fence_space::local_space);
225+ number = local_output[item.get_local_id ()];
226+ radix_digit = get_digit (number.first , digit, radix_bits);
227+
228+ rank = split_count (get_bit (radix_digit, b), item);
229+ local_output[rank] = number;
230+ }
231+
232+ // Initialize local buckets to zero, since they are uninitialized by default
233+ for (size_t r = 0 ; r < radix; ++r)
234+ local_buckets[r] = 0 ;
235+
236+ item.barrier (sycl::access::fence_space::local_space);
237+ {
238+ sycl::atomic_ref<size_t , sycl::memory_order_relaxed, sycl::memory_scope_work_group,
239+ sycl::access::address_space::local_space> bucket_r{local_buckets[radix_digit]};
240+ ++bucket_r;
241+ item.barrier (sycl::access::fence_space::local_space);
242+ }
243+
244+ // Save local buckets to global memory, with one row per work-group (in column-major order)
245+ for (size_t r = 0 ; r < radix; ++r)
246+ global_buckets_accessor[r][workgroup] = local_buckets[r];
247+
248+ if (global_id < data_size)
249+ data_swap_acc[workgroup_size * workgroup + rank] = number; // save local sorting back to data
250+
251+ // Compute local buckets offsets
252+ size_t *begin = local_buckets.get_pointer (), *end = begin + radix,
253+ *outBegin = local_buckets_offsets_accessor.get_pointer ().get () + workgroup * radix;
254+ sycl::joint_exclusive_scan (item.get_group (), begin, end, outBegin, sycl::plus<size_t >{});
255+ }); });
256+
257+ // Global synchronization to make sure that all locally computed buckets have been copied to global memory
258+
259+ sycl::event scan_event = queue.submit ([&](sycl::handler &cgh) {
260+ cgh.depends_on (buckets_event);
261+ auto global_buckets_accessor = global_buckets_->get_access (cgh, sycl::read_only);
262+ auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access (cgh, sycl::write_only);
263+ cgh.parallel_for (kernel_range_, [=](sycl::nd_item<1 > item) {
264+ // Compute global buckets offsets
265+ if (item.get_group_linear_id () == 0 ) {
266+ size_t *begin = global_buckets_accessor.get_pointer (), *end = begin + global_buckets_accessor.size ();
267+ sycl::joint_exclusive_scan (item.get_group (), begin, end,
268+ global_buckets_offsets_accessor.get_pointer (), sycl::plus<size_t >{});
269+ }
270+ });
271+ });
272+
273+ sort_event = queue.submit ([&](sycl::handler &cgh)
274+ {
275+ cgh.depends_on (scan_event);
276+ auto data_swap_acc = data_swap_buffer_->get_access (cgh, sycl::read_only);
277+ auto global_buckets_accessor = global_buckets_->get_access (cgh, sycl::read_only);
278+ auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access (cgh, sycl::read_write);
279+ auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access (cgh, sycl::read_only);
280+ cgh.parallel_for (kernel_range_, [=](sycl::nd_item<1 > item) {
281+ // Compute global buckets offsets
282+ size_t *begin = global_buckets_accessor.get_pointer (), *end = begin + global_buckets_accessor.size ();
283+ sycl::joint_exclusive_scan (item.get_group (), begin, end,
284+ global_buckets_offsets_accessor.get_pointer (), sycl::plus<size_t >{});
285+
286+ // Mask only relevant indexes. All max_keys added to homogenize the computations
287+ // should be owned by work-items with global_id >= data_size
288+ if (item.get_global_id () < data_size) {
289+ // Retrieve position and sorted data from swap memory
290+ const size_t rank = item.get_local_id (), workgroup = item.get_group_linear_id ();
291+ const SortablePair number = data_swap_acc[workgroup_size * workgroup + rank];
292+ const size_t radix_digit = get_digit (number.first , digit, radix_bits);
293+
294+ // Compute sorted position based on global and local buckets
295+ const size_t data_offset = global_buckets_offsets_accessor[radix_digit][workgroup] + rank -
296+ local_buckets_offsets_accessor[workgroup][radix_digit];
297+
298+ // Copy to original data pointers
299+ keys[data_offset] = number.first ;
300+ data[data_offset] = number.second ;
301+ }
302+ }); });
303+ }
304+ return sort_event;
305+ }
306+ // =================================================================================================//
130307} // namespace SPH
0 commit comments