
Commit b64440c

Merge pull request #480 from nR3D/sycl
[SYCL] Improved SYCL particle sorting
2 parents c1c3f47 + 8ea1507 commit b64440c

3 files changed: +212, -165 lines

src/shared/particle_neighborhood/neighborhood.cpp
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ Neighborhood& Neighborhood::operator=(const NeighborhoodDevice &device) {
 }
 
 NeighborhoodDevice::NeighborhoodDevice() : current_size_(allocateDeviceData<size_t>(1)),
-                                           allocated_size_(Dimensions == 2 ? 28 : 68),
+                                           allocated_size_(Dimensions == 2 ? 28 : 82),
                                            j_(allocateDeviceData<size_t>(allocated_size_)),
                                            W_ij_(allocateDeviceData<DeviceReal>(allocated_size_)),
                                            dW_ijV_j_(allocateDeviceData<DeviceReal>(allocated_size_)),
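Note: allocated_size_ is the fixed capacity of the device-side neighbor arrays, and the same value sizes the j_, W_ij_ and dW_ijV_j_ allocations just below it, so raising the 3-D value from 68 to 82 enlarges all three arrays together. A minimal, self-contained sketch of that allocation pattern; it assumes allocateDeviceData is a thin wrapper over sycl::malloc_device (the real helper uses the library's global queue rather than an explicit argument) and takes DeviceReal as float, both of which are assumptions for illustration:

#include <sycl/sycl.hpp>
#include <cstddef>

// Assumed stand-in for allocateDeviceData: a thin USM wrapper; the queue
// argument is explicit here only to keep the sketch self-contained.
template <class T>
T *allocateDeviceData(sycl::queue &q, std::size_t n) { return sycl::malloc_device<T>(n, q); }

int main()
{
    sycl::queue q;
    constexpr int Dimensions = 3; // assumed 3-D build
    const std::size_t allocated_size = Dimensions == 2 ? 28 : 82;

    auto *j = allocateDeviceData<std::size_t>(q, allocated_size); // neighbor index slots
    auto *W_ij = allocateDeviceData<float>(q, allocated_size);    // DeviceReal taken as float here
    auto *dW_ijV_j = allocateDeviceData<float>(q, allocated_size);

    sycl::free(j, q);
    sycl::free(W_ij, q);
    sycl::free(dW_ijV_j, q);
    return 0;
}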

src/shared/particles/particle_sorting.cpp
Lines changed: 182 additions & 5 deletions

@@ -92,15 +92,15 @@ void ParticleSorting::sortingParticleData(size_t *begin, size_t size, execution:
     if (!index_sorting_device_variables_)
         index_sorting_device_variables_ = allocateDeviceData<size_t>(size);
 
-    sort_by_key(begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue(), 256, 4, [](size_t *data, size_t idx)
-                { return idx; }).wait();
+    device_radix_sorting.sort_by_key(begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue(), 512, 4).wait();
 
     move_sortable_particle_device_data_(index_sorting_device_variables_, size);
 
     updateSortedDeviceId();
 }
 //=================================================================================================//
-size_t split_count(bool bit, sycl::nd_item<1> &item)
+template <class ValueType>
+SYCL_EXTERNAL size_t DeviceRadixSort<ValueType>::split_count(bool bit, sycl::nd_item<1> &item)
 {
     const auto group_range = item.get_local_range().size();
     const size_t id = item.get_local_id();
@@ -117,14 +117,191 @@ size_t split_count(bool bit, sycl::nd_item<1> &item)
     return bit ? true_before - 1 + false_totals : id - true_before;
 }
 //=================================================================================================//
-size_t get_digit(size_t key, size_t d, size_t radix_bits)
+template <class ValueType>
+size_t DeviceRadixSort<ValueType>::get_digit(size_t key, size_t d, size_t radix_bits)
 {
     return (key >> d * radix_bits) & ((1ul << radix_bits) - 1);
 }
 //=================================================================================================//
-size_t get_bit(size_t key, size_t b)
+template <class ValueType>
+size_t DeviceRadixSort<ValueType>::get_bit(size_t key, size_t b)
 {
     return (key >> b) & 1;
 }
 //=================================================================================================//
+template <class ValueType>
+size_t DeviceRadixSort<ValueType>::find_max_element(const size_t *data, size_t size, size_t identity)
+{
+    size_t result = identity;
+    auto &sycl_queue = execution::executionQueue.getQueue();
+    {
+        sycl::buffer<size_t> buffer_result(&result, 1);
+        sycl_queue.submit([&](sycl::handler &cgh)
+                          {
+                auto reduction_operator = sycl::reduction(buffer_result, cgh, sycl::maximum<>());
+                cgh.parallel_for(execution::executionQueue.getUniformNdRange(size), reduction_operator,
+                                 [=](sycl::nd_item<1> item, auto& reduction) {
+                                     if(item.get_global_id() < size)
+                                         reduction.combine(data[item.get_global_linear_id()]);
+                                 }); })
+            .wait_and_throw();
+    }
+    return result;
+}
+//=================================================================================================//
+template <class ValueType>
+void DeviceRadixSort<ValueType>::resize(size_t data_size, size_t radix_bits, size_t workgroup_size)
+{
+    data_size_ = data_size;
+    radix_bits_ = radix_bits;
+    workgroup_size_ = workgroup_size;
+    uniform_case_masking_ = data_size % workgroup_size;
+    uniform_global_size_ = uniform_case_masking_ ? (data_size / workgroup_size + 1) * workgroup_size : data_size;
+    kernel_range_ = {uniform_global_size_, workgroup_size};
+    workgroups_ = kernel_range_.get_group_range().size();
+
+    radix_ = 1ul << radix_bits; // radix = 2^b
+
+    sycl::range<2> buckets_column_major_range = {radix_, workgroups_}, buckets_row_major_range = {workgroups_, radix_};
+    // Each entry contains global number of digits with the same value
+    // Column-major, so buckets offsets can be computed by just applying a scan over it
+    global_buckets_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_column_major_range);
+    // Each entry contains global number of digits with the same and lower values
+    global_buckets_offsets_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_column_major_range);
+    local_buckets_offsets_buffer_ = std::make_unique<sycl::buffer<size_t, 2>>(buckets_row_major_range); // save state of local accessor
+    data_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_); // temporary memory for swapping
+    // Keep extra values to be swapped when kernel range has been made uniform
+    uniform_extra_swap_buffer_ = std::make_unique<sycl::buffer<SortablePair>>(uniform_global_size_ - data_size);
+}
+//=================================================================================================//
+template <class ValueType>
+sycl::event DeviceRadixSort<ValueType>::sort_by_key(size_t *keys, ValueType *data, size_t data_size, sycl::queue &queue, size_t workgroup_size, size_t radix_bits)
+{
+    if(data_size_ != data_size || radix_bits_ != radix_bits || workgroup_size_ != workgroup_size)
+        resize(data_size, radix_bits, workgroup_size);
+
+    // Largest key, increased by 1 if the workgroup is not homogeneous with the data vector,
+    // the new maximum will be used for those work-items out of data range, that will then be excluded once sorted
+    const size_t max_key = find_max_element(keys, data_size, 0ul) + (uniform_case_masking_ ? 1 : 0);
+    const size_t bits_max_key = std::floor(std::log2(max_key)) + 1.0; // bits needed to represent max_key
+    const size_t length = max_key ? bits_max_key / radix_bits + (bits_max_key % radix_bits ? 1 : 0) : 1; // max number of radix digits
+
+    sycl::event sort_event{};
+    for (int digit = 0; digit < length; ++digit)
+    {
+
+        auto buckets_event = queue.submit([&](sycl::handler &cgh)
+                                          {
+            cgh.depends_on(sort_event);
+            auto data_swap_acc = data_swap_buffer_->get_access(cgh, sycl::write_only, sycl::no_init);
+            auto local_buckets = sycl::local_accessor<size_t>(radix_, cgh);
+            auto local_output = sycl::local_accessor<SortablePair>(kernel_range_.get_local_range(), cgh);
+            auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_write, sycl::no_init);
+            auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access(cgh, sycl::write_only,
+                                                                                            sycl::no_init);
+
+            cgh.parallel_for(kernel_range_, [=, radix=radix_](sycl::nd_item<1> item) {
+                const size_t workgroup = item.get_group_linear_id(),
+                             global_id = item.get_global_id();
+
+                SortablePair number;
+                // Initialize key-data pair, with masking in case of non-homogeneous data_size/workgroup_size
+                if(global_id < data_size)
+                    number = {keys[global_id],
+                              // Give possibility to initialize data here to avoid calling
+                              // another kernel before sort_by_key in order to initialize it
+                              digit ? data[global_id] : global_id};
+                else // masking extra indexes
+                    // Initialize exceeding values to the largest key considered
+                    number.first = (1 << bits_max_key) - 1; // max key for given number of bits
+
+                // Locally sort digit with split primitive
+                auto radix_digit = get_digit(number.first, digit, radix_bits);
+                auto rank = split_count(get_bit(radix_digit, 0), item); // sorting first bit
+                local_output[rank] = number;
+                for (size_t b = 1; b < radix_bits; ++b) { // sorting remaining bits
+                    item.barrier(sycl::access::fence_space::local_space);
+                    number = local_output[item.get_local_id()];
+                    radix_digit = get_digit(number.first, digit, radix_bits);
+
+                    rank = split_count(get_bit(radix_digit, b), item);
+                    local_output[rank] = number;
+                }
+
+                // Initialize local buckets to zero, since they are uninitialized by default
+                for (size_t r = 0; r < radix; ++r)
+                    local_buckets[r] = 0;
+
+                item.barrier(sycl::access::fence_space::local_space);
+                {
+                    sycl::atomic_ref<size_t, sycl::memory_order_relaxed, sycl::memory_scope_work_group,
+                                     sycl::access::address_space::local_space> bucket_r{local_buckets[radix_digit]};
+                    ++bucket_r;
+                    item.barrier(sycl::access::fence_space::local_space);
+                }
+
+                // Save local buckets to global memory, with one row per work-group (in column-major order)
+                for (size_t r = 0; r < radix; ++r)
+                    global_buckets_accessor[r][workgroup] = local_buckets[r];
+
+                if(global_id < data_size)
+                    data_swap_acc[workgroup_size * workgroup + rank] = number; // save local sorting back to data
+
+                // Compute local buckets offsets
+                size_t *begin = local_buckets.get_pointer(), *end = begin + radix,
+                       *outBegin = local_buckets_offsets_accessor.get_pointer().get() + workgroup * radix;
+                sycl::joint_exclusive_scan(item.get_group(), begin, end, outBegin, sycl::plus<size_t>{});
+            }); });
+
+        // Global synchronization to make sure that all locally computed buckets have been copied to global memory
+
+        sycl::event scan_event = queue.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(buckets_event);
+            auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_only);
+            auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access(cgh, sycl::write_only);
+            cgh.parallel_for(kernel_range_, [=](sycl::nd_item<1> item) {
+                // Compute global buckets offsets
+                if(item.get_group_linear_id() == 0) {
+                    size_t *begin = global_buckets_accessor.get_pointer(), *end = begin + global_buckets_accessor.size();
+                    sycl::joint_exclusive_scan(item.get_group(), begin, end,
+                                               global_buckets_offsets_accessor.get_pointer(), sycl::plus<size_t>{});
+                }
+            });
+        });
+
+        sort_event = queue.submit([&](sycl::handler &cgh)
+                                  {
+            cgh.depends_on(scan_event);
+            auto data_swap_acc = data_swap_buffer_->get_access(cgh, sycl::read_only);
+            auto global_buckets_accessor = global_buckets_->get_access(cgh, sycl::read_only);
+            auto global_buckets_offsets_accessor = global_buckets_offsets_->get_access(cgh, sycl::read_write);
+            auto local_buckets_offsets_accessor = local_buckets_offsets_buffer_->get_access(cgh, sycl::read_only);
+            cgh.parallel_for(kernel_range_, [=](sycl::nd_item<1> item) {
+                // Compute global buckets offsets
+                size_t *begin = global_buckets_accessor.get_pointer(), *end = begin + global_buckets_accessor.size();
+                sycl::joint_exclusive_scan(item.get_group(), begin, end,
+                                           global_buckets_offsets_accessor.get_pointer(), sycl::plus<size_t>{});
+
+                // Mask only relevant indexes. All max_keys added to homogenize the computations
+                // should be owned by work-items with global_id >= data_size
+                if(item.get_global_id() < data_size) {
+                    // Retrieve position and sorted data from swap memory
+                    const size_t rank = item.get_local_id(), workgroup = item.get_group_linear_id();
+                    const SortablePair number = data_swap_acc[workgroup_size * workgroup + rank];
+                    const size_t radix_digit = get_digit(number.first, digit, radix_bits);
+
+                    // Compute sorted position based on global and local buckets
+                    const size_t data_offset = global_buckets_offsets_accessor[radix_digit][workgroup] + rank -
+                                               local_buckets_offsets_accessor[workgroup][radix_digit];
+
+                    // Copy to original data pointers
+                    keys[data_offset] = number.first;
+                    data[data_offset] = number.second;
+                }
+            }); });
+    }
+    return sort_event;
+}
+//=================================================================================================//
 } // namespace SPH
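Taken together, the caller now invokes device_radix_sorting.sort_by_key(begin, index_sorting_device_variables_, size, execution::executionQueue.getQueue(), 512, 4).wait(): a radix sort over 4-bit digits with work-groups of 512, where each pass extracts one digit via get_digit/get_bit and the number of passes is derived from the largest key. A minimal host-side sketch of that digit arithmetic (plain C++, no SYCL; the key values below are purely illustrative):

#include <cmath>
#include <cstddef>
#include <cstdio>

// Same digit extraction as DeviceRadixSort::get_digit in the diff above.
static std::size_t get_digit(std::size_t key, std::size_t d, std::size_t radix_bits)
{
    return (key >> d * radix_bits) & ((1ul << radix_bits) - 1);
}

int main()
{
    const std::size_t radix_bits = 4;               // matches the radix_bits passed by sortingParticleData
    const std::size_t keys[] = {830, 17, 2047, 64}; // illustrative keys only
    const std::size_t max_key = 2047;

    // Bits needed for max_key, then the number of radix-16 passes,
    // mirroring bits_max_key and length in DeviceRadixSort::sort_by_key.
    const std::size_t bits_max_key = static_cast<std::size_t>(std::floor(std::log2(static_cast<double>(max_key)))) + 1; // 11
    const std::size_t passes = bits_max_key / radix_bits + (bits_max_key % radix_bits ? 1 : 0);                         // 3

    for (std::size_t d = 0; d < passes; ++d)
        for (std::size_t k : keys)
            std::printf("key %4zu, digit %zu -> %zu\n", k, d, get_digit(k, d, radix_bits));
    return 0;
}

Each pass then maps onto the three kernels in the diff: a work-group-local split sort that also fills the per-group buckets, an exclusive scan that turns the column-major global buckets into offsets, and a scatter that combines global and local bucket offsets into each element's final position.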
