|
| 1 | +// SPDX-FileCopyrightText: 2025 VTT Technical Research Centre of Finland Ltd |
| 2 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| 3 | + |
| 4 | +#include <cassert> |
| 5 | +#include <iostream> |
| 6 | +#include <numeric> |
| 7 | +#include <stdexcept> |
| 8 | +#include <string> |
| 9 | +#include <thrust/device_vector.h> |
| 10 | +#include <vector> |
| 11 | + |
| 12 | +using namespace pfc; |
| 13 | + |
| 14 | +TEST_CASE("Construct empty SparseVector", "[SparseVector (CUDA)]") { |
| 15 | + auto vector = sparsevector::create<double, CUDATag>(3); |
| 16 | + REQUIRE(get_size(vector) == 3); |
| 17 | + REQUIRE(!on_host(vector)); |
| 18 | +} |
| 19 | + |
| 20 | +TEST_CASE("Update SparseVector", "[SparseVector (CUDA)]") { |
| 21 | + auto vector = sparsevector::create<double, CUDATag>(3); |
| 22 | + set_index(vector, {2, 3, 4}); |
| 23 | + set_data(vector, {1.0, 2.0, 3.0}); |
| 24 | + auto index = get_index(vector); |
| 25 | + auto data = get_data(vector); |
| 26 | + REQUIRE(index[0] == 2); |
| 27 | + REQUIRE(data[0] == 1.0); |
| 28 | +} |
| 29 | + |
| 30 | +TEST_CASE("Construct filled SparseVector", "[SparseVector (CUDA)]") { |
| 31 | + thrust::device_vector<size_t> index = {2, 4, 6}; |
| 32 | + thrust::device_vector<double> data = {1.0, 2.0, 3.0}; |
| 33 | + auto vector = sparsevector::create(index, data); |
| 34 | + REQUIRE(get_size(vector) == 3); |
| 35 | +} |
| 36 | + |
| 37 | +TEST_CASE("Gather data from source", "[SparseVector (CUDA)]") { |
| 38 | + thrust::device_vector<size_t> index = {0, 1, 3}; |
| 39 | + auto vector = sparsevector::create<double>(index); |
| 40 | + thrust::device_vector<double> big_data = {1.0, 2.0, 3.0, 4.0}; |
| 41 | + gather(vector, big_data); |
| 42 | + REQUIRE(get_data(vector) == {1.0, 2.0, 4.0}); |
| 43 | +} |
| 44 | + |
| 45 | +TEST_CASE("Scatter data to destination", "[SparseVector (CUDA)]") { |
| 46 | + thrust::device_vector<size_t> indices = {0, 1, 3}; |
| 47 | + auto vector = sparsevector::create<double, CUDATag>({0, 1, 3}, {1.0, 2.0, 4.0}); |
| 48 | + thrust::device_vector<double> big_data = {0.0, 0.0, 0.0, 0.0}; |
| 49 | + scatter(vector, big_data); |
| 50 | + REQUIRE(big_data == thrust::device_vector<double>({1.0, 2.0, 0.0, 4.0})); |
| 51 | +} |
0 commit comments