Skip to content

Commit 98bba9b

Browse files
committed
Add license headers
1 parent 730f264 commit 98bba9b

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

tests/test_sparsevector.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// SPDX-FileCopyrightText: 2025 VTT Technical Research Centre of Finland Ltd
2+
// SPDX-License-Identifier: AGPL-3.0-or-later
3+
4+
#include <cassert>
5+
#include <iostream>
6+
#include <numeric>
7+
#include <stdexcept>
8+
#include <string>
9+
#include <vector>
10+
11+
using namespace pfc;
12+
13+
TEST_CASE("Construct empty SparseVector", "[SparseVector (Host)]") {
14+
auto vector = sparsevector::create<double, HostTag>(3);
15+
REQUIRE(get_size(vector) == 3);
16+
REQUIRE(on_host(vector));
17+
}
18+
19+
TEST_CASE("Update SparseVector", "[SparseVector (Host)]") {
20+
auto vector = sparsevector::create<double>(3); // defaults to HostTag
21+
set_index(vector, {2, 3, 4});
22+
set_data(vector, {1.0, 2.0, 3.0});
23+
auto index = get_index(vector);
24+
auto data = get_data(vector);
25+
REQUIRE(index[0] == 2);
26+
REQUIRE(data[0] == 1.0);
27+
}
28+
29+
TEST_CASE("Get single data from SparseVector", "[SparseVector (Host)]") {
30+
auto indices = std::vector<size_t>({2, 4, 6});
31+
auto data = std::vector<double>({1.0, 2.0, 3.0});
32+
auto vector = sparsevector::create(indices, data);
33+
// Check that the buffer has the expected properties
34+
REQUIRE(get_size(vector) == 3);
35+
REQUIRE(get_index(vector, 0) == 2);
36+
REQUIRE(get_data(vector, 0) == 1.0);
37+
REQUIRE(get_entry(vector, 0) == {2, 1.0});
38+
}
39+
40+
TEST_CASE("Gather data from source", "[SparseVector (Host)]") {
41+
auto vector = sparsevector::create<double>({0, 1, 3});
42+
auto big_data = {1.0, 2.0, 3.0, 4.0};
43+
gather(vector, big_data);
44+
REQUIRE(get_entry(vector, 1) == {1, 2.0});
45+
}
46+
47+
TEST_CASE("Scatter data to destination", "[SparseVector (Host)]") {
48+
auto vector = sparsevector::create<double>({0, 1, 3});
49+
auto big_data = {1.0, 2.0, 3.0, 4.0};
50+
scatter(vector, big_data);
51+
REQUIRE(get_entry(vector, 1) == {1, 2.0});
52+
}

tests/test_sparsevector_cuda.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// SPDX-FileCopyrightText: 2025 VTT Technical Research Centre of Finland Ltd
2+
// SPDX-License-Identifier: AGPL-3.0-or-later
3+
4+
#include <cassert>
5+
#include <iostream>
6+
#include <numeric>
7+
#include <stdexcept>
8+
#include <string>
9+
#include <thrust/device_vector.h>
10+
#include <vector>
11+
12+
using namespace pfc;
13+
14+
TEST_CASE("Construct empty SparseVector", "[SparseVector (CUDA)]") {
15+
auto vector = sparsevector::create<double, CUDATag>(3);
16+
REQUIRE(get_size(vector) == 3);
17+
REQUIRE(!on_host(vector));
18+
}
19+
20+
TEST_CASE("Update SparseVector", "[SparseVector (CUDA)]") {
21+
auto vector = sparsevector::create<double, CUDATag>(3);
22+
set_index(vector, {2, 3, 4});
23+
set_data(vector, {1.0, 2.0, 3.0});
24+
auto index = get_index(vector);
25+
auto data = get_data(vector);
26+
REQUIRE(index[0] == 2);
27+
REQUIRE(data[0] == 1.0);
28+
}
29+
30+
TEST_CASE("Construct filled SparseVector", "[SparseVector (CUDA)]") {
31+
thrust::device_vector<size_t> index = {2, 4, 6};
32+
thrust::device_vector<double> data = {1.0, 2.0, 3.0};
33+
auto vector = sparsevector::create(index, data);
34+
REQUIRE(get_size(vector) == 3);
35+
}
36+
37+
TEST_CASE("Gather data from source", "[SparseVector (CUDA)]") {
38+
thrust::device_vector<size_t> index = {0, 1, 3};
39+
auto vector = sparsevector::create<double>(index);
40+
thrust::device_vector<double> big_data = {1.0, 2.0, 3.0, 4.0};
41+
gather(vector, big_data);
42+
REQUIRE(get_data(vector) == {1.0, 2.0, 4.0});
43+
}
44+
45+
TEST_CASE("Scatter data to destination", "[SparseVector (CUDA)]") {
46+
thrust::device_vector<size_t> indices = {0, 1, 3};
47+
auto vector = sparsevector::create<double, CUDATag>({0, 1, 3}, {1.0, 2.0, 4.0});
48+
thrust::device_vector<double> big_data = {0.0, 0.0, 0.0, 0.0};
49+
scatter(vector, big_data);
50+
REQUIRE(big_data == thrust::device_vector<double>({1.0, 2.0, 0.0, 4.0}));
51+
}

0 commit comments

Comments
 (0)