diff --git a/common/cuda_hip/matrix/csr_kernels.template.cpp b/common/cuda_hip/matrix/csr_kernels.template.cpp index 7d0d9fd4fe9..3146c3756bb 100644 --- a/common/cuda_hip/matrix/csr_kernels.template.cpp +++ b/common/cuda_hip/matrix/csr_kernels.template.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include "accessor/cuda_hip_helper.hpp" diff --git a/common/cuda_hip/matrix/dense_kernels.cpp b/common/cuda_hip/matrix/dense_kernels.cpp index d40682cdc58..83ad77f8ab7 100644 --- a/common/cuda_hip/matrix/dense_kernels.cpp +++ b/common/cuda_hip/matrix/dense_kernels.cpp @@ -9,9 +9,7 @@ #include #include #include -#include #include -#include #include #include @@ -571,19 +569,19 @@ template void convert_to_hybrid(std::shared_ptr exec, matrix::view::dense source, const int64* coo_row_ptrs, - matrix::Hybrid* result) + matrix::view::hybrid result) { - const auto num_rows = result->get_size()[0]; - const auto num_cols = result->get_size()[1]; + const auto num_rows = result.size[0]; + const auto num_cols = result.size[1]; const auto ell_max_nnz_per_row = - result->get_ell_num_stored_elements_per_row(); + result.ell_part.num_stored_elements_per_row; const auto source_stride = source.stride; - const auto ell_stride = result->get_ell_stride(); - auto ell_col_idxs = result->get_ell_col_idxs(); - auto ell_values = result->get_ell_values(); - auto coo_row_idxs = result->get_coo_row_idxs(); - auto coo_col_idxs = result->get_coo_col_idxs(); - auto coo_values = result->get_coo_values(); + const auto ell_stride = result.ell_part.stride; + auto ell_col_idxs = result.ell_part.col_idxs; + auto ell_values = result.ell_part.values; + auto coo_row_idxs = result.coo_part.row_idxs; + auto coo_col_idxs = result.coo_part.col_idxs; + auto coo_values = result.coo_part.values; auto grid_dim = ceildiv(num_rows, default_block_size / config::warp_size); if (grid_dim > 0) { diff --git a/common/unified/matrix/csr_kernels.cpp b/common/unified/matrix/csr_kernels.cpp index 90de179b678..90c6a78d91c 100644 --- a/common/unified/matrix/csr_kernels.cpp +++ b/common/unified/matrix/csr_kernels.cpp @@ -188,7 +188,7 @@ template void convert_to_hybrid(std::shared_ptr exec, const matrix::Csr* source, const int64* coo_row_ptrs, - matrix::Hybrid* result) + matrix::view::hybrid result) { run_kernel( exec, @@ -218,10 +218,10 @@ void convert_to_hybrid(std::shared_ptr exec, }, source->get_size()[0], source->get_const_row_ptrs(), source->get_const_col_idxs(), source->get_const_values(), - result->get_ell_stride(), result->get_ell_num_stored_elements_per_row(), - result->get_ell_col_idxs(), result->get_ell_values(), coo_row_ptrs, - result->get_coo_row_idxs(), result->get_coo_col_idxs(), - result->get_coo_values()); + result.ell_part.stride, result.ell_part.num_stored_elements_per_row, + result.ell_part.col_idxs, result.ell_part.values, coo_row_ptrs, + result.coo_part.row_idxs, result.coo_part.col_idxs, + result.coo_part.values); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/common/unified/matrix/hybrid_kernels.cpp b/common/unified/matrix/hybrid_kernels.cpp index 8a21a2415f7..7b7d8faebc6 100644 --- a/common/unified/matrix/hybrid_kernels.cpp +++ b/common/unified/matrix/hybrid_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -51,7 +51,7 @@ template void fill_in_matrix_data(std::shared_ptr exec, const device_matrix_data& data, const int64* row_ptrs, const int64* coo_row_ptrs, - matrix::Hybrid* result) + matrix::view::hybrid result) { using device_value_type = device_type; run_kernel( @@ -83,10 +83,10 @@ void fill_in_matrix_data(std::shared_ptr exec, }, data.get_size()[0], row_ptrs, data.get_const_values(), data.get_const_row_idxs(), data.get_const_col_idxs(), - result->get_ell_stride(), result->get_ell_num_stored_elements_per_row(), - result->get_ell_col_idxs(), result->get_ell_values(), coo_row_ptrs, - result->get_coo_row_idxs(), result->get_coo_col_idxs(), - result->get_coo_values()); + result.ell_part.stride, result.ell_part.num_stored_elements_per_row, + result.ell_part.col_idxs, result.ell_part.values, coo_row_ptrs, + result.coo_part.row_idxs, result.coo_part.col_idxs, + result.coo_part.values); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -94,14 +94,14 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( template -void convert_to_csr(std::shared_ptr exec, - const matrix::Hybrid* source, - const IndexType* ell_row_ptrs, - const IndexType* coo_row_ptrs, - matrix::Csr* result) +void convert_to_csr( + std::shared_ptr exec, + matrix::view::hybrid source, + const IndexType* ell_row_ptrs, const IndexType* coo_row_ptrs, + matrix::Csr* result) { - const auto ell = source->get_ell(); - const auto coo = source->get_coo(); + const auto ell = source.ell_part; + const auto coo = source.coo_part; // ELL is stored in column-major, so we swap row and column parameters run_kernel( exec, @@ -117,18 +117,16 @@ void convert_to_csr(std::shared_ptr exec, out_vals[out_idx] = in_vals[ell_idx]; } }, - dim<2>{ell->get_num_stored_elements_per_row(), ell->get_size()[0]}, - static_cast(ell->get_stride()), ell->get_const_col_idxs(), - ell->get_const_values(), ell_row_ptrs, coo_row_ptrs, - result->get_col_idxs(), result->get_values()); + dim<2>{ell.num_stored_elements_per_row, ell.size[0]}, + static_cast(ell.stride), ell.col_idxs, ell.values, ell_row_ptrs, + coo_row_ptrs, result->get_col_idxs(), result->get_values()); run_kernel( exec, [] GKO_KERNEL(auto idx, auto ell_row_ptrs, auto coo_row_ptrs, auto out_row_ptrs) { out_row_ptrs[idx] = ell_row_ptrs[idx] + coo_row_ptrs[idx]; }, - source->get_size()[0] + 1, ell_row_ptrs, coo_row_ptrs, - result->get_row_ptrs()); + source.size[0] + 1, ell_row_ptrs, coo_row_ptrs, result->get_row_ptrs()); run_kernel( exec, [] GKO_KERNEL(auto idx, auto in_rows, auto in_cols, auto in_vals, @@ -145,9 +143,9 @@ void convert_to_csr(std::shared_ptr exec, out_cols[out_idx] = col; out_vals[out_idx] = val; }, - coo->get_num_stored_elements(), coo->get_const_row_idxs(), - coo->get_const_col_idxs(), coo->get_const_values(), ell_row_ptrs, - coo_row_ptrs, result->get_col_idxs(), result->get_values()); + coo.num_stored_elements, coo.row_idxs, coo.col_idxs, coo.values, + ell_row_ptrs, coo_row_ptrs, result->get_col_idxs(), + result->get_values()); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index fbb7772f485..2ff447477a6 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -442,7 +443,7 @@ void Csr::convert_to( auto tmp = make_temporary_clone(exec, result); tmp->resize(this->get_size(), ell_lim, coo_nnz); exec->run(csr::make_convert_to_hybrid(this, coo_row_ptrs.get_const_data(), - tmp.get())); + tmp->get_device_view())); } diff --git a/core/matrix/csr_kernels.hpp b/core/matrix/csr_kernels.hpp index e6ad4f513ee..ed4f5385364 100644 --- a/core/matrix/csr_kernels.hpp +++ b/core/matrix/csr_kernels.hpp @@ -9,13 +9,9 @@ #include #include #include -#include #include -#include #include #include -#include -#include #include #include @@ -113,7 +109,7 @@ namespace kernels { void convert_to_hybrid(std::shared_ptr exec, \ const matrix::Csr* source, \ const int64* coo_row_ptrs, \ - matrix::Hybrid* result) + matrix::view::hybrid result) #define GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL(ValueType, IndexType) \ void convert_to_sellp(std::shared_ptr exec, \ diff --git a/core/matrix/dense.cpp b/core/matrix/dense.cpp index 43dbbefb7ba..da9e95a9426 100644 --- a/core/matrix/dense.cpp +++ b/core/matrix/dense.cpp @@ -912,7 +912,7 @@ void Dense::convert_impl(Hybrid* result) const tmp->resize(this->get_size(), ell_lim, coo_nnz); exec->run(dense::make_convert_to_hybrid(this->get_const_device_view(), coo_row_ptrs.get_const_data(), - tmp.get())); + tmp->get_device_view())); } diff --git a/core/matrix/dense_kernels.hpp b/core/matrix/dense_kernels.hpp index 31eb9f0f6c3..1520fa57e03 100644 --- a/core/matrix/dense_kernels.hpp +++ b/core/matrix/dense_kernels.hpp @@ -170,7 +170,7 @@ namespace kernels { void convert_to_hybrid(std::shared_ptr exec, \ matrix::view::dense source, \ const int64* coo_row_ptrs, \ - matrix::Hybrid* other) + matrix::view::hybrid other) #define GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL(ValueType, IndexType) \ void convert_to_sellp(std::shared_ptr exec, \ diff --git a/core/matrix/hybrid.cpp b/core/matrix/hybrid.cpp index 68c2f5b1faf..675d48899ad 100644 --- a/core/matrix/hybrid.cpp +++ b/core/matrix/hybrid.cpp @@ -113,6 +113,23 @@ Hybrid::Hybrid(std::shared_ptr exec, {} +template +typename Hybrid::device_view +Hybrid::get_device_view() +{ + return device_view{ell_->get_device_view(), coo_->get_device_view()}; +} + + +template +typename Hybrid::const_device_view +Hybrid::get_const_device_view() const +{ + return const_device_view{ell_->get_const_device_view(), + coo_->get_const_device_view()}; +} + + template std::unique_ptr> Hybrid::create(std::shared_ptr exec, @@ -316,8 +333,8 @@ void Hybrid::convert_to( tmp->values_.resize_and_reset(nnz); tmp->set_size(this->get_size()); exec->run(hybrid::make_convert_to_csr( - this, ell_row_ptrs.get_const_data(), coo_row_ptrs.get_const_data(), - tmp.get())); + this->get_const_device_view(), ell_row_ptrs.get_const_data(), + coo_row_ptrs.get_const_data(), tmp.get())); } result->make_srow(); } @@ -367,9 +384,9 @@ void Hybrid::read(const device_mat_data& data) coo_row_ptrs.get_data())); coo_nnz = get_element(coo_row_ptrs, num_rows); this->resize(data.get_size(), ell_max_nnz, coo_nnz); - exec->run( - hybrid::make_fill_in_matrix_data(*local_data, row_ptrs.get_const_data(), - coo_row_ptrs.get_const_data(), this)); + exec->run(hybrid::make_fill_in_matrix_data( + *local_data, row_ptrs.get_const_data(), coo_row_ptrs.get_const_data(), + this->get_device_view())); } diff --git a/core/matrix/hybrid_kernels.hpp b/core/matrix/hybrid_kernels.hpp index 85ff74bfab5..eaf8ad5606b 100644 --- a/core/matrix/hybrid_kernels.hpp +++ b/core/matrix/hybrid_kernels.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -6,8 +6,8 @@ #define GKO_CORE_MATRIX_HYBRID_KERNELS_HPP_ +#include #include -#include #include "core/base/kernel_declaration.hpp" @@ -30,14 +30,14 @@ namespace kernels { std::shared_ptr exec, \ const device_matrix_data& data, \ const int64* row_ptrs, const int64* coo_row_ptrs, \ - matrix::Hybrid* result) - -#define GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL(ValueType, IndexType) \ - void convert_to_csr(std::shared_ptr exec, \ - const matrix::Hybrid* source, \ - const IndexType* ell_row_ptrs, \ - const IndexType* coo_row_ptrs, \ - matrix::Csr* result) + matrix::view::hybrid result) + +#define GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL(ValueType, IndexType) \ + void convert_to_csr( \ + std::shared_ptr exec, \ + matrix::view::hybrid source, \ + const IndexType* ell_row_ptrs, const IndexType* coo_row_ptrs, \ + matrix::Csr* result) #define GKO_DECLARE_ALL_AS_TEMPLATES \ diff --git a/core/reorder/rcm.cpp b/core/reorder/rcm.cpp index 1acf4d97f1f..5f832f690f0 100644 --- a/core/reorder/rcm.cpp +++ b/core/reorder/rcm.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -13,6 +13,7 @@ #include #include #include +#include #include #include diff --git a/core/test/matrix/hybrid.cpp b/core/test/matrix/hybrid.cpp index d1a69312755..700cbc432af 100644 --- a/core/test/matrix/hybrid.cpp +++ b/core/test/matrix/hybrid.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -457,3 +457,59 @@ TYPED_TEST(Hybrid, GetCorrectAutomatic) ASSERT_NO_THROW(gko::as(mtx->template get_strategy())); } + + +TYPED_TEST(Hybrid, CanCreateDeviceView) +{ + auto view = this->mtx->get_device_view(); + + EXPECT_EQ(view.size, this->mtx->get_size()); + EXPECT_EQ(view.ell_part.num_stored_elements_per_row, + this->mtx->get_ell_num_stored_elements_per_row()); + EXPECT_EQ(view.ell_part.stride, this->mtx->get_ell_stride()); + EXPECT_EQ(view.ell_part.values, this->mtx->get_ell_values()); + EXPECT_EQ(view.ell_part.col_idxs, this->mtx->get_ell_col_idxs()); + EXPECT_EQ(view.coo_part.num_stored_elements, + this->mtx->get_coo_num_stored_elements()); + EXPECT_EQ(view.coo_part.row_idxs, this->mtx->get_coo_row_idxs()); + EXPECT_EQ(view.coo_part.col_idxs, this->mtx->get_coo_col_idxs()); + EXPECT_EQ(view.coo_part.values, this->mtx->get_coo_values()); +} + + +TYPED_TEST(Hybrid, CanCreateConstDeviceView) +{ + auto view = this->mtx->get_const_device_view(); + + EXPECT_EQ(view.size, this->mtx->get_size()); + EXPECT_EQ(view.ell_part.num_stored_elements_per_row, + this->mtx->get_ell_num_stored_elements_per_row()); + EXPECT_EQ(view.ell_part.stride, this->mtx->get_ell_stride()); + EXPECT_EQ(view.ell_part.values, this->mtx->get_ell_values()); + EXPECT_EQ(view.ell_part.col_idxs, this->mtx->get_ell_col_idxs()); + EXPECT_EQ(view.coo_part.num_stored_elements, + this->mtx->get_coo_num_stored_elements()); + EXPECT_EQ(view.coo_part.row_idxs, this->mtx->get_coo_row_idxs()); + EXPECT_EQ(view.coo_part.col_idxs, this->mtx->get_coo_col_idxs()); + EXPECT_EQ(view.coo_part.values, this->mtx->get_coo_values()); +} + + +TEST(HybridView, CreateFailsWithNonMatchingSizes) +{ +#ifdef NDEBUG + GTEST_SKIP() << "Assertion is only enabled in debug mode"; +#endif + + using value_type = double; + using index_type = int; + auto exec = gko::ReferenceExecutor::create(); + auto coo = gko::matrix::Coo::create( + exec, gko::dim<2>(1, 2)); + auto ell = gko::matrix::Ell::create( + exec, gko::dim<2>(2, 1)); + + using view_t = gko::matrix::view::hybrid; + EXPECT_EXIT(view_t(ell->get_device_view(), coo->get_device_view()), + check_assertion_exit_code, ""); +} diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index 7fa78c71485..401d26768a0 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -15,10 +15,6 @@ #include #include #include -#include -#include -#include -#include #include #include "accessor/sycl_helper.hpp" diff --git a/dpcpp/matrix/dense_kernels.dp.cpp b/dpcpp/matrix/dense_kernels.dp.cpp index fc3facb87eb..219d416aa84 100644 --- a/dpcpp/matrix/dense_kernels.dp.cpp +++ b/dpcpp/matrix/dense_kernels.dp.cpp @@ -10,11 +10,8 @@ #include #include -#include #include #include -#include -#include #include #include @@ -408,19 +405,19 @@ template void convert_to_hybrid(std::shared_ptr exec, matrix::view::dense source, const int64* coo_row_ptrs, - matrix::Hybrid* result) + matrix::view::hybrid result) { - const auto num_rows = result->get_size()[0]; - const auto num_cols = result->get_size()[1]; - const auto ell_lim = result->get_ell_num_stored_elements_per_row(); + const auto num_rows = result.size[0]; + const auto num_cols = result.size[1]; + const auto ell_lim = result.ell_part.num_stored_elements_per_row; const auto in_vals = as_device_type(source.values); const auto in_stride = source.stride; - const auto ell_stride = result->get_ell_stride(); - auto ell_cols = result->get_ell_col_idxs(); - auto ell_vals = as_device_type(result->get_ell_values()); - auto coo_rows = result->get_coo_row_idxs(); - auto coo_cols = result->get_coo_col_idxs(); - auto coo_vals = as_device_type(result->get_coo_values()); + const auto ell_stride = result.ell_part.stride; + auto ell_cols = result.ell_part.col_idxs; + auto ell_vals = as_device_type(result.ell_part.values); + auto coo_rows = result.coo_part.row_idxs; + auto coo_cols = result.coo_part.col_idxs; + auto coo_vals = as_device_type(result.coo_part.values); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(num_rows, [=](sycl::item<1> item) { diff --git a/include/ginkgo/core/matrix/device_views.hpp b/include/ginkgo/core/matrix/device_views.hpp index 427631c3495..683bb66e0fa 100644 --- a/include/ginkgo/core/matrix/device_views.hpp +++ b/include/ginkgo/core/matrix/device_views.hpp @@ -217,6 +217,38 @@ struct sellp { }; +/** + * Non-owning view of a matrix::Hybrid to be used inside device kernels. + * This type is used to provide a simple and stable ABI for passing data between + * libraries. + * + * @tparam ValueType the value type used to store matrix values. + * @tparam IndexType the index type used to store matrix columns. + */ +template +struct hybrid { + static_assert(std::is_const_v == std::is_const_v, + "ValueType and IndexType must share the same constness"); + dim<2> size; + ell ell_part; + coo coo_part; + + /** Constructs a hybrid view */ + constexpr hybrid(ell ell_, + coo coo_) + : size(ell_.size), ell_part(ell_), coo_part(coo_) + { + assert(ell_part.size == coo_part.size); + } + + /** Returns a const view of the same values */ + constexpr hybrid as_const() const + { + return {ell_part.as_const(), coo_part.as_const()}; + } +}; + + } // namespace view } // namespace matrix } // namespace gko diff --git a/include/ginkgo/core/matrix/hybrid.hpp b/include/ginkgo/core/matrix/hybrid.hpp index 2a1e731a400..43a4e208459 100644 --- a/include/ginkgo/core/matrix/hybrid.hpp +++ b/include/ginkgo/core/matrix/hybrid.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -77,6 +77,8 @@ class Hybrid using index_type = IndexType; using mat_data = matrix_data; using device_mat_data = device_matrix_data; + using device_view = view::hybrid; + using const_device_view = view::hybrid; using coo_type = Coo; using ell_type = Ell; using absolute_type = remove_complex; @@ -642,6 +644,20 @@ class Hybrid template std::shared_ptr get_strategy() const; + /** + * Returns a non-owning device view of this matrix. + * + * @return a device view of this matrix. + */ + device_view get_device_view(); + + /** + * Returns a non-owning const device view of this matrix. + * + * @return a const device view of this matrix. + */ + const_device_view get_const_device_view() const; + /** * Creates an uninitialized Hybrid matrix of specified method. * (ell_num_stored_elements_per_row is set to the number of cols of the diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index e9280608c09..4b3b51d51a0 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -15,9 +15,6 @@ #include #include #include -#include -#include -#include #include "core/base/allocator.hpp" #include "core/base/index_range.hpp" diff --git a/omp/matrix/dense_kernels.cpp b/omp/matrix/dense_kernels.cpp index fac82313f1d..b9efadad57e 100644 --- a/omp/matrix/dense_kernels.cpp +++ b/omp/matrix/dense_kernels.cpp @@ -11,12 +11,9 @@ #include #include #include -#include #include #include -#include #include -#include #include #include @@ -290,14 +287,14 @@ template void convert_to_hybrid(std::shared_ptr exec, matrix::view::dense source, const int64* coo_row_ptrs, - matrix::Hybrid* result) + matrix::view::hybrid result) { - auto num_rows = result->get_size()[0]; - auto num_cols = result->get_size()[1]; - auto ell_lim = result->get_ell_num_stored_elements_per_row(); - auto coo_val = result->get_coo_values(); - auto coo_col = result->get_coo_col_idxs(); - auto coo_row = result->get_coo_row_idxs(); + auto num_rows = result.size[0]; + auto num_cols = result.size[1]; + auto ell_lim = result.ell_part.num_stored_elements_per_row; + auto coo_val = result.coo_part.values; + auto coo_col = result.coo_part.col_idxs; + auto coo_row = result.coo_part.row_idxs; #pragma omp parallel for for (size_type row = 0; row < num_rows; row++) { @@ -306,14 +303,14 @@ void convert_to_hybrid(std::shared_ptr exec, for (; col < num_cols && ell_count < ell_lim; col++) { auto val = source(row, col); if (is_nonzero(val)) { - result->ell_val_at(row, ell_count) = val; - result->ell_col_at(row, ell_count) = col; + result.ell_part.val_at(row, ell_count) = val; + result.ell_part.col_at(row, ell_count) = col; ell_count++; } } for (; ell_count < ell_lim; ell_count++) { - result->ell_val_at(row, ell_count) = zero(); - result->ell_col_at(row, ell_count) = invalid_index(); + result.ell_part.val_at(row, ell_count) = zero(); + result.ell_part.col_at(row, ell_count) = invalid_index(); } auto coo_idx = coo_row_ptrs[row]; for (; col < num_cols; col++) { diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index 22455007b84..1e992728d1e 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -13,10 +13,6 @@ #include #include #include -#include -#include -#include -#include #include #include "core/base/allocator.hpp" @@ -910,22 +906,20 @@ template void convert_to_hybrid(std::shared_ptr exec, const matrix::Csr* source, const int64*, - matrix::Hybrid* result) + matrix::view::hybrid result) { - auto num_rows = result->get_size()[0]; - auto num_cols = result->get_size()[1]; - auto strategy = result->get_strategy(); - auto ell_lim = result->get_ell_num_stored_elements_per_row(); - auto coo_val = result->get_coo_values(); - auto coo_col = result->get_coo_col_idxs(); - auto coo_row = result->get_coo_row_idxs(); + auto num_rows = result.size[0]; + auto num_cols = result.size[1]; + auto ell_lim = result.ell_part.num_stored_elements_per_row; + auto coo_val = result.coo_part.values; + auto coo_col = result.coo_part.col_idxs; + auto coo_row = result.coo_part.row_idxs; // Initial Hybrid Matrix - for (size_type i = 0; i < result->get_ell_num_stored_elements_per_row(); - i++) { - for (size_type j = 0; j < result->get_ell_stride(); j++) { - result->ell_val_at(j, i) = zero(); - result->ell_col_at(j, i) = invalid_index(); + for (size_type i = 0; i < ell_lim; i++) { + for (size_type j = 0; j < result.ell_part.stride; j++) { + result.ell_part.val_at(j, i) = zero(); + result.ell_part.col_at(j, i) = invalid_index(); } } @@ -938,8 +932,8 @@ void convert_to_hybrid(std::shared_ptr exec, while (csr_idx < csr_row_ptrs[row + 1]) { const auto val = csr_vals[csr_idx]; if (ell_idx < ell_lim) { - result->ell_val_at(row, ell_idx) = val; - result->ell_col_at(row, ell_idx) = + result.ell_part.val_at(row, ell_idx) = val; + result.ell_part.col_at(row, ell_idx) = source->get_const_col_idxs()[csr_idx]; ell_idx++; } else { diff --git a/reference/matrix/dense_kernels.cpp b/reference/matrix/dense_kernels.cpp index 0c07f03b4b7..4520ced55f0 100644 --- a/reference/matrix/dense_kernels.cpp +++ b/reference/matrix/dense_kernels.cpp @@ -9,12 +9,9 @@ #include #include #include -#include #include #include -#include #include -#include #include #include @@ -601,23 +598,23 @@ template void convert_to_hybrid(std::shared_ptr exec, matrix::view::dense source, const int64*, - matrix::Hybrid* result) + matrix::view::hybrid result) { - auto num_rows = result->get_size()[0]; - auto num_cols = result->get_size()[1]; - auto strategy = result->get_strategy(); - auto ell_lim = strategy->get_ell_num_stored_elements_per_row(); - auto coo_lim = strategy->get_coo_nnz(); - auto coo_val = result->get_coo_values(); - auto coo_col = result->get_coo_col_idxs(); - auto coo_row = result->get_coo_row_idxs(); - for (size_type i = 0; i < result->get_ell_num_stored_elements_per_row(); - i++) { - for (size_type j = 0; j < result->get_ell_stride(); j++) { - result->ell_val_at(j, i) = zero(); - result->ell_col_at(j, i) = invalid_index(); - } - } + auto num_rows = result.size[0]; + auto num_cols = result.size[1]; + auto ell_lim = result.ell_part.num_stored_elements_per_row; + auto coo_lim = result.coo_part.num_stored_elements; + auto coo_val = result.coo_part.values; + auto coo_col = result.coo_part.col_idxs; + auto coo_row = result.coo_part.row_idxs; + std::fill_n( + result.ell_part.values, + result.ell_part.stride * result.ell_part.num_stored_elements_per_row, + zero()); + std::fill_n( + result.ell_part.col_idxs, + result.ell_part.stride * result.ell_part.num_stored_elements_per_row, + invalid_index()); size_type coo_idx = 0; for (size_type row = 0; row < num_rows; row++) { @@ -626,8 +623,8 @@ void convert_to_hybrid(std::shared_ptr exec, col++) { auto val = source(row, col); if (is_nonzero(val)) { - result->ell_val_at(row, col_idx) = val; - result->ell_col_at(row, col_idx) = col; + result.ell_part.val_at(row, col_idx) = val; + result.ell_part.col_at(row, col_idx) = col; col_idx++; } } diff --git a/reference/matrix/hybrid_kernels.cpp b/reference/matrix/hybrid_kernels.cpp index f2a06c321f2..399d609cf5b 100644 --- a/reference/matrix/hybrid_kernels.cpp +++ b/reference/matrix/hybrid_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -6,10 +6,8 @@ #include #include -#include #include #include -#include #include "core/components/format_conversion_kernels.hpp" #include "core/components/prefix_sum_kernels.hpp" @@ -57,10 +55,10 @@ template void fill_in_matrix_data(std::shared_ptr exec, const device_matrix_data& data, const int64* row_ptrs, const int64*, - matrix::Hybrid* result) + matrix::view::hybrid result) { - const auto num_rows = result->get_size()[0]; - const auto ell_max_nnz = result->get_ell_num_stored_elements_per_row(); + const auto num_rows = result.size[0]; + const auto ell_max_nnz = result.ell_part.num_stored_elements_per_row; const auto values = data.get_const_values(); const auto row_idxs = data.get_const_row_idxs(); const auto col_idxs = data.get_const_col_idxs(); @@ -69,19 +67,19 @@ void fill_in_matrix_data(std::shared_ptr exec, size_type ell_nz{}; for (auto nz = row_ptrs[row]; nz < row_ptrs[row + 1]; nz++) { if (ell_nz < ell_max_nnz) { - result->ell_col_at(row, ell_nz) = col_idxs[nz]; - result->ell_val_at(row, ell_nz) = values[nz]; + result.ell_part.col_at(row, ell_nz) = col_idxs[nz]; + result.ell_part.val_at(row, ell_nz) = values[nz]; ell_nz++; } else { - result->get_coo_row_idxs()[coo_nz] = row_idxs[nz]; - result->get_coo_col_idxs()[coo_nz] = col_idxs[nz]; - result->get_coo_values()[coo_nz] = values[nz]; + result.coo_part.row_idxs[coo_nz] = row_idxs[nz]; + result.coo_part.col_idxs[coo_nz] = col_idxs[nz]; + result.coo_part.values[coo_nz] = values[nz]; coo_nz++; } } for (; ell_nz < ell_max_nnz; ell_nz++) { - result->ell_col_at(row, ell_nz) = invalid_index(); - result->ell_val_at(row, ell_nz) = zero(); + result.ell_part.col_at(row, ell_nz) = invalid_index(); + result.ell_part.val_at(row, ell_nz) = zero(); } } } @@ -91,28 +89,29 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( template -void convert_to_csr(std::shared_ptr exec, - const matrix::Hybrid* source, - const IndexType*, const IndexType*, - matrix::Csr* result) +void convert_to_csr( + std::shared_ptr exec, + matrix::view::hybrid source, + const IndexType*, const IndexType*, + matrix::Csr* result) { auto csr_val = result->get_values(); auto csr_col_idxs = result->get_col_idxs(); auto csr_row_ptrs = result->get_row_ptrs(); - const auto ell = source->get_ell(); - const auto max_nnz_per_row = ell->get_num_stored_elements_per_row(); - const auto coo_val = source->get_const_coo_values(); - const auto coo_col = source->get_const_coo_col_idxs(); - const auto coo_row = source->get_const_coo_row_idxs(); - const auto coo_nnz = source->get_coo_num_stored_elements(); + const auto ell = source.ell_part; + const auto max_nnz_per_row = ell.num_stored_elements_per_row; + const auto coo_val = source.coo_part.values; + const auto coo_col = source.coo_part.col_idxs; + const auto coo_row = source.coo_part.row_idxs; + const auto coo_nnz = source.coo_part.num_stored_elements; csr_row_ptrs[0] = 0; size_type csr_idx = 0; size_type coo_idx = 0; - for (IndexType row = 0; row < source->get_size()[0]; row++) { + for (IndexType row = 0; row < source.size[0]; row++) { // Ell part for (IndexType i = 0; i < max_nnz_per_row; i++) { - const auto val = ell->val_at(row, i); - const auto col = ell->col_at(row, i); + const auto val = ell.val_at(row, i); + const auto col = ell.col_at(row, i); if (col != invalid_index()) { csr_val[csr_idx] = val; csr_col_idxs[csr_idx] = col; diff --git a/test/matrix/hybrid_kernels.cpp b/test/matrix/hybrid_kernels.cpp index b208364551b..ab73190aa7d 100644 --- a/test/matrix/hybrid_kernels.cpp +++ b/test/matrix/hybrid_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -13,6 +13,7 @@ #include #include #include +#include #include "core/test/utils.hpp" #include "test/utils/common_fixture.hpp"