Skip to content

Commit 3c4661e

Browse files
committed
add OpenMP segmented prefix sum
1 parent c959035 commit 3c4661e

3 files changed

Lines changed: 188 additions & 0 deletions

File tree

omp/components/prefix_sum.hpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_
6+
#define GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_
7+
8+
#include <algorithm>
9+
#include <iterator>
10+
#include <limits>
11+
#include <string>
12+
13+
#include <omp.h>
14+
15+
#include "core/base/allocator.hpp"
16+
#include "core/base/iterator_factory.hpp"
17+
18+
19+
namespace gko {
20+
namespace kernels {
21+
namespace omp {
22+
namespace components {
23+
24+
25+
/*
26+
* Similar to prefix_sum, only reduces within runs of the same key value (each
27+
* key run must only occur once, otherwise the scan operation is not necessarily
28+
* associaive). It also doesn't ignore the last value!
29+
* Similar to thrust::exclusive_scan_by_key
30+
*/
31+
template <typename KeyIterator, typename Iterator,
32+
typename ScanOp =
33+
std::plus<typename std::iterator_traits<Iterator>::value_type>>
34+
void segmented_prefix_sum(
35+
std::shared_ptr<const OmpExecutor> exec, KeyIterator key, Iterator it,
36+
const size_type num_entries,
37+
typename std::iterator_traits<KeyIterator>::value_type key_init = {},
38+
typename std::iterator_traits<Iterator>::value_type init = {},
39+
ScanOp op = {})
40+
{
41+
using key_type = typename std::iterator_traits<KeyIterator>::value_type;
42+
using value_type = typename std::iterator_traits<Iterator>::value_type;
43+
// the operation only makes sense for arrays of size at least 2
44+
if (num_entries < 2) {
45+
if (num_entries == 0) {
46+
return;
47+
} else {
48+
*it = init;
49+
return;
50+
}
51+
}
52+
53+
const int nthreads = omp_get_max_threads();
54+
vector<value_type> proc_sums(nthreads, init, {exec});
55+
vector<key_type> proc_first_key(nthreads, key_init, {exec});
56+
vector<key_type> proc_last_key(nthreads, key_init, {exec});
57+
const size_type def_num_witems = (num_entries - 1) / nthreads + 1;
58+
59+
#pragma omp parallel
60+
{
61+
const int thread_id = omp_get_thread_num();
62+
const size_type startidx = thread_id * def_num_witems;
63+
const size_type endidx =
64+
std::min(num_entries, (thread_id + 1) * def_num_witems);
65+
66+
auto partial_sum = init;
67+
auto cur_key = startidx < num_entries ? key[startidx] : key_init;
68+
proc_first_key[thread_id] = cur_key;
69+
for (size_type i = startidx; i < endidx; ++i) {
70+
auto value = it[i];
71+
auto new_key = key[i];
72+
if (cur_key != new_key) {
73+
partial_sum = init;
74+
cur_key = new_key;
75+
}
76+
it[i] = partial_sum;
77+
partial_sum = op(partial_sum, value);
78+
}
79+
80+
proc_sums[thread_id] = partial_sum;
81+
proc_last_key[thread_id] = cur_key;
82+
83+
#pragma omp barrier
84+
85+
#pragma omp single
86+
{
87+
for (int i = 0; i < nthreads - 1; i++) {
88+
// the next block carries over the previous partial sum
89+
// if it starts and ends with the same key as the next one
90+
if (proc_last_key[i] == proc_first_key[i + 1] &&
91+
proc_first_key[i + 1] == proc_last_key[i + 1]) {
92+
proc_sums[i + 1] = op(proc_sums[i], proc_sums[i + 1]);
93+
}
94+
}
95+
}
96+
97+
if (thread_id > 0) {
98+
for (size_type i = startidx; i < endidx; i++) {
99+
if (key[i] == proc_last_key[thread_id - 1]) {
100+
it[i] = op(it[i], proc_sums[thread_id - 1]);
101+
}
102+
}
103+
}
104+
}
105+
}
106+
107+
108+
} // namespace components
109+
} // namespace omp
110+
} // namespace kernels
111+
} // namespace gko
112+
113+
#endif // GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_

omp/test/components/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Registers the OpenMP unit test named prefix_sum (presumably built from
# prefix_sum.cpp in this directory -- confirm against ginkgo_create_omp_test).
ginkgo_create_omp_test(prefix_sum)

omp/test/components/prefix_sum.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "omp/components/prefix_sum.hpp"
6+
7+
#include <algorithm>
8+
#include <iterator>
9+
#include <limits>
10+
#include <memory>
11+
#include <random>
12+
#include <type_traits>
13+
#include <vector>
14+
15+
#include <gtest/gtest.h>
16+
17+
#include <ginkgo/core/base/executor.hpp>
18+
19+
#include "core/base/index_range.hpp"
20+
#include "core/test/utils.hpp"
21+
22+
23+
template <typename T>
24+
class PrefixSum : public ::testing::Test {
25+
protected:
26+
using index_type = T;
27+
28+
PrefixSum() : exec{gko::OmpExecutor::create()}, rand(293) {}
29+
30+
std::shared_ptr<const gko::OmpExecutor> exec;
31+
std::default_random_engine rand;
32+
gko::size_type total_size;
33+
};
34+
35+
// Runs every PrefixSum test once per type listed in gko::test::IndexTypes;
// TypenameNameGenerator supplies the per-type test name suffix.
TYPED_TEST_SUITE(PrefixSum, gko::test::IndexTypes, TypenameNameGenerator);
36+
37+
38+
// Compares segmented_prefix_sum against a per-segment std::exclusive_scan
// reference for every thread count from 1 up to the configured maximum.
TYPED_TEST(PrefixSum, SegmentedPrefixSumWorks)
{
    using index_type = typename TestFixture::index_type;
    const auto max_threads = omp_get_max_threads();
    // a failing ASSERT_* returns from the test body early; the guard makes
    // sure the original thread count is restored for subsequent tests
    struct thread_count_guard {
        int saved;
        ~thread_count_guard() { omp_set_num_threads(saved); }
    } guard{max_threads};
    for (int num_threads = 1; num_threads <= max_threads; num_threads++) {
        SCOPED_TRACE(num_threads);
        omp_set_num_threads(num_threads);
        for (int num_ranges : {10, 100, 1000}) {
            SCOPED_TRACE(num_ranges);
            // repeat multiple times to cover different random inputs
            for ([[maybe_unused]] int repetition : gko::irange{10}) {
                std::uniform_int_distribution<int> count_dist{0, 100};
                std::uniform_int_distribution<index_type> value_dist{-200, 200};
                std::vector<index_type> ref_result;
                std::vector<int> keys;
                std::vector<index_type> input;
                for (int i = 0; i < num_ranges; i++) {
                    // segment i consists of new_count entries with key i;
                    // a count of zero leaves the key out entirely
                    const auto start = keys.size();
                    const auto new_count = count_dist(this->rand);
                    keys.insert(keys.end(), new_count, i);
                    std::generate_n(std::back_inserter(input), new_count,
                                    [&] { return value_dist(this->rand); });
                    // the reference scans each segment on its own
                    std::exclusive_scan(input.begin() + start, input.end(),
                                        std::back_inserter(ref_result),
                                        index_type{});
                }

                gko::kernels::omp::components::segmented_prefix_sum(
                    this->exec, keys.cbegin(), input.begin(), keys.size());

                ASSERT_EQ(input, ref_result);
            }
        }
    }
}

0 commit comments

Comments
 (0)