Skip to content

Commit 2d1113a

Browse files
committed
improve CSR-to-COO conversion
1 parent f0c5d97 commit 2d1113a

8 files changed

Lines changed: 178 additions & 23 deletions

File tree

common/cuda_hip/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ set(CUDA_HIP_SOURCES
33
base/batch_multi_vector_kernels.cpp
44
base/device_matrix_data_kernels.cpp
55
base/index_set_kernels.cpp
6+
components/format_conversion_kernels.cpp
67
components/prefix_sum_kernels.cpp
78
distributed/assembly_kernels.cpp
89
distributed/index_map_kernels.cpp
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "core/components/format_conversion_kernels.hpp"
6+
7+
#include <thrust/for_each.h>
8+
#include <thrust/iterator/counting_iterator.h>
9+
#include <thrust/iterator/transform_iterator.h>
10+
11+
#include <ginkgo/core/base/types.hpp>
12+
13+
#include "common/cuda_hip/base/thrust.hpp"
14+
#include "common/cuda_hip/components/bitvector.hpp"
15+
16+
17+
namespace gko {
18+
namespace kernels {
19+
namespace GKO_DEVICE_NAMESPACE {
20+
namespace components {
21+
22+
23+
// Expands a block-pointer array into one block index per element
// (e.g. CSR row pointers -> COO row indices) on the CUDA/HIP backend.
//
// exec        executor used for the host copy, the bitvector and the thrust
//             policy
// ptrs        device array with num_blocks + 1 entries; ptrs[num_blocks] is
//             read as the total number of elements
// num_blocks  number of blocks (rows)
// idxs        device output array, one entry per element
template <typename IndexType, typename RowPtrType>
24+
void convert_ptrs_to_idxs(std::shared_ptr<const DefaultExecutor> exec,
25+
const RowPtrType* ptrs, size_type num_blocks,
26+
IndexType* idxs)
27+
{
28+
const auto policy = thrust_policy(exec);
29+
// the last pointer entry is the total element count; fetch it to the host
const auto num_elements = exec->copy_val_to_host(ptrs + num_blocks);
30+
// transform the ptrs to a bitvector in unary delta encoding, i.e.
31+
// every row with n elements is encoded as 1 0 ... n times ... 0
32+
auto it = thrust::make_transform_iterator(
33+
thrust::make_counting_iterator(IndexType{}),
34+
[ptrs] __device__(IndexType i) -> RowPtrType { return ptrs[i] + i; });
35+
auto bv = bitvector::from_sorted_indices(exec, it, num_blocks,
36+
num_blocks + num_elements);
37+
auto device_bv = bv.device_view();
38+
// The counter runs over all num_blocks + num_elements bit positions, so it
// must use the row-pointer type: with 64-bit row pointers and a 32-bit
// IndexType, an IndexType counter would overflow for large inputs.
thrust::for_each_n(policy, thrust::make_counting_iterator(RowPtrType{}),
39+
num_blocks + num_elements,
40+
[device_bv, idxs] __device__(RowPtrType i) {
41+
// only unset bits correspond to elements
if (!device_bv.get(i)) {
42+
// NOTE(review): rank(i) is presumably the number of set bits
// before position i, making i - rank the element position and
// rank - 1 its block — confirm against bitvector.hpp
auto rank = device_bv.rank(i);
43+
idxs[i - rank] = rank - 1;
44+
}
45+
});
46+
}
47+
48+
// NOTE(review): presumably instantiates the kernel for every index type with
// 32-bit and 64-bit row pointers — confirm against the macro definitions
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32);
49+
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS64);
50+
51+
52+
} // namespace components
53+
} // namespace GKO_DEVICE_NAMESPACE
54+
} // namespace kernels
55+
} // namespace gko

common/unified/components/format_conversion_kernels.cpp

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -16,27 +16,6 @@ namespace GKO_DEVICE_NAMESPACE {
1616
namespace components {
1717

1818

19-
// Previous backend-agnostic implementation (removed by this commit):
// launches one kernel invocation per block, and each invocation writes its
// block index to every element position in [ptrs[block], ptrs[block + 1]).
template <typename IndexType, typename RowPtrType>
20-
void convert_ptrs_to_idxs(std::shared_ptr<const DefaultExecutor> exec,
21-
const RowPtrType* ptrs, size_type num_blocks,
22-
IndexType* idxs)
23-
{
24-
run_kernel(
25-
exec,
26-
[] GKO_KERNEL(auto block, auto ptrs, auto idxs) {
27-
auto begin = ptrs[block];
28-
auto end = ptrs[block + 1];
29-
// the work per invocation equals the block's size, so blocks of very
// different sizes give the invocations unequal amounts of work
for (auto i = begin; i < end; i++) {
30-
idxs[i] = block;
31-
}
32-
},
33-
num_blocks, ptrs, idxs);
34-
}
35-
36-
// NOTE(review): presumably instantiates the kernel for every index type with
// 32-bit and 64-bit row pointers — confirm against the macro definitions
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32);
37-
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS64);
38-
39-
4019
template <typename IndexType, typename RowPtrType>
4120
void convert_idxs_to_ptrs(std::shared_ptr<const DefaultExecutor> exec,
4221
const IndexType* idxs, size_type num_idxs,

dpcpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ target_sources(
2020
base/scoped_device_id.dp.cpp
2121
base/timer.dp.cpp
2222
base/version.dp.cpp
23+
components/format_conversion_kernels.dp.cpp
2324
components/prefix_sum_kernels.dp.cpp
2425
distributed/assembly_kernels.dp.cpp
2526
distributed/index_map_kernels.dp.cpp
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <oneapi/dpl/iterator>
6+
7+
#include "core/components/format_conversion_kernels.hpp"
8+
9+
#include <ginkgo/core/base/types.hpp>
10+
11+
#include "common/unified/base/kernel_launch.hpp"
12+
#include "core/components/fill_array_kernels.hpp"
13+
#include "dpcpp/base/onedpl.hpp"
14+
#include "dpcpp/components/bitvector.dp.hpp"
15+
16+
17+
namespace gko {
18+
namespace kernels {
19+
namespace dpcpp {
20+
namespace components {
21+
22+
23+
// Expands a block-pointer array into one block index per element
// (e.g. CSR row pointers -> COO row indices) on the DPC++ backend.
//
// exec        executor used for the host copy, the bitvector and the queue
// ptrs        device array with num_blocks + 1 entries; ptrs[num_blocks] is
//             read as the total number of elements
// num_blocks  number of blocks (rows)
// idxs        device output array, one entry per element
template <typename IndexType, typename RowPtrType>
24+
void convert_ptrs_to_idxs(std::shared_ptr<const DefaultExecutor> exec,
25+
const RowPtrType* ptrs, size_type num_blocks,
26+
IndexType* idxs)
27+
{
28+
// NOTE(review): `policy` is not referenced again in this function
auto policy = onedpl_policy(exec);
29+
// the last pointer entry is the total element count; fetch it to the host
const auto num_elements = exec->copy_val_to_host(ptrs + num_blocks);
30+
// transform the ptrs to a bitvector in unary delta encoding, i.e.
31+
// every row with n elements is encoded as 1 0 ... n times ... 0
32+
auto it = oneapi::dpl::make_transform_iterator(
33+
oneapi::dpl::counting_iterator<IndexType>{0},
34+
[ptrs](IndexType i) -> RowPtrType { return ptrs[i] + i; });
35+
auto bv = bitvector::from_sorted_indices(exec, it, num_blocks,
36+
num_blocks + num_elements);
37+
auto device_bv = bv.device_view();
38+
// one work item per bit position of the bitvector
exec->get_queue()->submit([&](sycl::handler& cgh) {
39+
cgh.parallel_for(num_blocks + num_elements, [=](sycl::id<1> i) {
40+
// only unset bits correspond to elements
if (!device_bv.get(i)) {
41+
// NOTE(review): rank(i) is presumably the number of set bits
// before position i, making i - rank the element position and
// rank - 1 its block — confirm against bitvector.dp.hpp
auto rank = device_bv.rank(i);
42+
idxs[i - rank] = rank - 1;
43+
}
44+
});
45+
});
46+
}
47+
48+
// NOTE(review): presumably instantiates the kernel for every index type with
// 32-bit and 64-bit row pointers — confirm against the macro definitions
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32);
49+
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS64);
50+
51+
52+
} // namespace components
53+
} // namespace dpcpp
54+
} // namespace kernels
55+
} // namespace gko

omp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ target_sources(
1414
base/index_set_kernels.cpp
1515
base/scoped_device_id.cpp
1616
base/version.cpp
17+
components/format_conversion_kernels.cpp
1718
components/prefix_sum_kernels.cpp
1819
distributed/assembly_kernels.cpp
1920
distributed/index_map_kernels.cpp
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "core/components/format_conversion_kernels.hpp"
6+
7+
#include <ginkgo/core/base/types.hpp>
8+
9+
10+
namespace gko {
11+
namespace kernels {
12+
namespace omp {
13+
namespace components {
14+
15+
16+
// Expands a block-pointer array into one block index per element
// (e.g. CSR row pointers -> COO row indices) on the OpenMP backend.
//
// exec        executor (not used by this implementation)
// ptrs        array with num_blocks + 1 entries delimiting each block's
//             element range
// num_blocks  number of blocks (rows)
// idxs        output array, one entry per element
template <typename IndexType, typename RowPtrType>
17+
void convert_ptrs_to_idxs(std::shared_ptr<const DefaultExecutor> exec,
18+
const RowPtrType* ptrs, size_type num_blocks,
19+
IndexType* idxs)
20+
{
21+
// blocks are distributed over the threads; each block writes only to its own
// range [ptrs[block], ptrs[block + 1]) of idxs
#pragma omp parallel for
22+
for (size_type block = 0; block < num_blocks; block++) {
23+
auto begin = ptrs[block];
24+
auto end = ptrs[block + 1];
25+
// every element of the block gets the block's index
for (auto i = begin; i < end; i++) {
26+
idxs[i] = block;
27+
}
28+
}
29+
}
30+
31+
// NOTE(review): presumably instantiates the kernel for every index type with
// 32-bit and 64-bit row pointers — confirm against the macro definitions
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32);
32+
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS64);
33+
34+
35+
} // namespace components
36+
} // namespace omp
37+
} // namespace kernels
38+
} // namespace gko

test/components/format_conversion_kernels.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -10,6 +10,7 @@
1010

1111
#include <gtest/gtest.h>
1212

13+
#include "core/base/index_range.hpp"
1314
#include "core/test/utils.hpp"
1415
#include "test/utils/common_fixture.hpp"
1516

@@ -67,6 +68,30 @@ TYPED_TEST(FormatConversion, ConvertsEmptyPtrsToIdxs)
6768
}
6869

6970

71+
// Checks convert_ptrs_to_idxs on 10000 blocks whose sizes are drawn from a
// geometric distribution, i.e. blocks of strongly varying size.
TYPED_TEST(FormatConversion, ConvertPtrsToIdxsImbalanced)
72+
{
73+
using index_type = typename TestFixture::index_type;
74+
// ptrs starts with the leading 0; idxs collects the expected output
std::vector<index_type> ptrs{0};
75+
std::vector<index_type> idxs;
76+
std::geometric_distribution<int> size_dist{0.01};
77+
for (auto i : gko::irange{10000}) {
78+
// block i gets `count` elements: extend the pointers and append `count`
// copies of i to the reference indices
auto count = size_dist(this->rand);
79+
ptrs.push_back(ptrs.back() + count);
80+
idxs.insert(idxs.end(), count, i);
81+
}
82+
gko::array<index_type> ptr_array{this->exec, ptrs.begin(), ptrs.end()};
83+
gko::array<index_type> idx_array{this->exec, idxs.begin(), idxs.end()};
84+
// keep the expected result, then overwrite the output buffer with -1 so
// that entries the kernel does not write are detected by the comparison
auto ref_idx_array = idx_array;
85+
idx_array.fill(-1);
86+
87+
gko::kernels::GKO_DEVICE_NAMESPACE::components::convert_ptrs_to_idxs(
88+
this->exec, ptr_array.get_const_data(), ptrs.size() - 1,
89+
idx_array.get_data());
90+
91+
GKO_ASSERT_ARRAY_EQ(idx_array, ref_idx_array);
92+
}
93+
94+
7095
TYPED_TEST(FormatConversion, ConvertPtrsToIdxs)
7196
{
7297
auto ref_idxs = this->idxs;

0 commit comments

Comments
 (0)