Skip to content

Commit 2765b75

Browse files
committed
use view::hybrid in other kernels
1 parent f0d75e4 commit 2765b75

14 files changed

Lines changed: 74 additions & 101 deletions

File tree

common/cuda_hip/matrix/csr_kernels.template.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include <ginkgo/core/matrix/coo.hpp>
2424
#include <ginkgo/core/matrix/dense.hpp>
2525
#include <ginkgo/core/matrix/ell.hpp>
26-
#include <ginkgo/core/matrix/hybrid.hpp>
2726
#include <ginkgo/core/matrix/sellp.hpp>
2827

2928
#include "accessor/cuda_hip_helper.hpp"

common/cuda_hip/matrix/dense_kernels.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
#include <ginkgo/core/matrix/coo.hpp>
1010
#include <ginkgo/core/matrix/csr.hpp>
1111
#include <ginkgo/core/matrix/diagonal.hpp>
12-
#include <ginkgo/core/matrix/ell.hpp>
1312
#include <ginkgo/core/matrix/fbcsr.hpp>
14-
#include <ginkgo/core/matrix/hybrid.hpp>
1513
#include <ginkgo/core/matrix/sellp.hpp>
1614
#include <ginkgo/core/matrix/sparsity_csr.hpp>
1715

@@ -571,19 +569,19 @@ template <typename ValueType, typename IndexType>
571569
void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec,
572570
matrix::view::dense<const ValueType> source,
573571
const int64* coo_row_ptrs,
574-
matrix::Hybrid<ValueType, IndexType>* result)
572+
matrix::view::hybrid<ValueType, IndexType> result)
575573
{
576-
const auto num_rows = result->get_size()[0];
577-
const auto num_cols = result->get_size()[1];
574+
const auto num_rows = result.size[0];
575+
const auto num_cols = result.size[1];
578576
const auto ell_max_nnz_per_row =
579-
result->get_ell_num_stored_elements_per_row();
577+
result.ell_part.num_stored_elements_per_row;
580578
const auto source_stride = source.stride;
581-
const auto ell_stride = result->get_ell_stride();
582-
auto ell_col_idxs = result->get_ell_col_idxs();
583-
auto ell_values = result->get_ell_values();
584-
auto coo_row_idxs = result->get_coo_row_idxs();
585-
auto coo_col_idxs = result->get_coo_col_idxs();
586-
auto coo_values = result->get_coo_values();
579+
const auto ell_stride = result.ell_part.stride;
580+
auto ell_col_idxs = result.ell_part.col_idxs;
581+
auto ell_values = result.ell_part.values;
582+
auto coo_row_idxs = result.coo_part.row_idxs;
583+
auto coo_col_idxs = result.coo_part.col_idxs;
584+
auto coo_values = result.coo_part.values;
587585

588586
auto grid_dim = ceildiv(num_rows, default_block_size / config::warp_size);
589587
if (grid_dim > 0) {

common/unified/matrix/csr_kernels.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ template <typename ValueType, typename IndexType>
189189
void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec,
190190
const matrix::Csr<ValueType, IndexType>* source,
191191
const int64* coo_row_ptrs,
192-
matrix::Hybrid<ValueType, IndexType>* result)
192+
matrix::view::hybrid<ValueType, IndexType> result)
193193
{
194194
run_kernel(
195195
exec,
@@ -219,10 +219,10 @@ void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec,
219219
},
220220
source->get_size()[0], source->get_const_row_ptrs(),
221221
source->get_const_col_idxs(), source->get_const_values(),
222-
result->get_ell_stride(), result->get_ell_num_stored_elements_per_row(),
223-
result->get_ell_col_idxs(), result->get_ell_values(), coo_row_ptrs,
224-
result->get_coo_row_idxs(), result->get_coo_col_idxs(),
225-
result->get_coo_values());
222+
result.ell_part.stride, result.ell_part.num_stored_elements_per_row,
223+
result.ell_part.col_idxs, result.ell_part.values, coo_row_ptrs,
224+
result.coo_part.row_idxs, result.coo_part.col_idxs,
225+
result.coo_part.values);
226226
}
227227

228228
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(

core/matrix/csr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <ginkgo/core/matrix/dense.hpp>
1818
#include <ginkgo/core/matrix/ell.hpp>
1919
#include <ginkgo/core/matrix/fbcsr.hpp>
20+
#include <ginkgo/core/matrix/hybrid.hpp>
2021
#include <ginkgo/core/matrix/identity.hpp>
2122
#include <ginkgo/core/matrix/permutation.hpp>
2223
#include <ginkgo/core/matrix/scaled_permutation.hpp>
@@ -442,7 +443,7 @@ void Csr<ValueType, IndexType>::convert_to(
442443
auto tmp = make_temporary_clone(exec, result);
443444
tmp->resize(this->get_size(), ell_lim, coo_nnz);
444445
exec->run(csr::make_convert_to_hybrid(this, coo_row_ptrs.get_const_data(),
445-
tmp.get()));
446+
tmp->get_device_view()));
446447
}
447448

448449

core/matrix/csr_kernels.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,9 @@
99
#include <ginkgo/core/base/array.hpp>
1010
#include <ginkgo/core/base/index_set.hpp>
1111
#include <ginkgo/core/base/types.hpp>
12-
#include <ginkgo/core/matrix/coo.hpp>
1312
#include <ginkgo/core/matrix/csr.hpp>
14-
#include <ginkgo/core/matrix/dense.hpp>
1513
#include <ginkgo/core/matrix/device_views.hpp>
1614
#include <ginkgo/core/matrix/diagonal.hpp>
17-
#include <ginkgo/core/matrix/ell.hpp>
18-
#include <ginkgo/core/matrix/hybrid.hpp>
1915
#include <ginkgo/core/matrix/sellp.hpp>
2016
#include <ginkgo/core/matrix/sparsity_csr.hpp>
2117

@@ -113,7 +109,7 @@ namespace kernels {
113109
void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec, \
114110
const matrix::Csr<ValueType, IndexType>* source, \
115111
const int64* coo_row_ptrs, \
116-
matrix::Hybrid<ValueType, IndexType>* result)
112+
matrix::view::hybrid<ValueType, IndexType> result)
117113

118114
#define GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL(ValueType, IndexType) \
119115
void convert_to_sellp(std::shared_ptr<const DefaultExecutor> exec, \

core/matrix/dense.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ void Dense<ValueType>::convert_impl(Hybrid<ValueType, IndexType>* result) const
912912
tmp->resize(this->get_size(), ell_lim, coo_nnz);
913913
exec->run(dense::make_convert_to_hybrid(this->get_const_device_view(),
914914
coo_row_ptrs.get_const_data(),
915-
tmp.get()));
915+
tmp->get_device_view()));
916916
}
917917

918918

core/matrix/dense_kernels.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ namespace kernels {
170170
void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec, \
171171
matrix::view::dense<const ValueType> source, \
172172
const int64* coo_row_ptrs, \
173-
matrix::Hybrid<ValueType, IndexType>* other)
173+
matrix::view::hybrid<ValueType, IndexType> other)
174174

175175
#define GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL(ValueType, IndexType) \
176176
void convert_to_sellp(std::shared_ptr<const DefaultExecutor> exec, \

core/reorder/rcm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -13,6 +13,7 @@
1313
#include <ginkgo/core/base/types.hpp>
1414
#include <ginkgo/core/base/utils.hpp>
1515
#include <ginkgo/core/matrix/csr.hpp>
16+
#include <ginkgo/core/matrix/dense.hpp>
1617
#include <ginkgo/core/matrix/permutation.hpp>
1718
#include <ginkgo/core/matrix/sparsity_csr.hpp>
1819

dpcpp/matrix/csr_kernels.dp.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
#include <ginkgo/core/base/exception_helpers.hpp>
1616
#include <ginkgo/core/base/math.hpp>
1717
#include <ginkgo/core/base/std_extensions.hpp>
18-
#include <ginkgo/core/matrix/coo.hpp>
19-
#include <ginkgo/core/matrix/dense.hpp>
20-
#include <ginkgo/core/matrix/ell.hpp>
21-
#include <ginkgo/core/matrix/hybrid.hpp>
2218
#include <ginkgo/core/matrix/sellp.hpp>
2319

2420
#include "accessor/sycl_helper.hpp"

dpcpp/matrix/dense_kernels.dp.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,8 @@
1010

1111
#include <ginkgo/core/base/math.hpp>
1212
#include <ginkgo/core/base/range_accessors.hpp>
13-
#include <ginkgo/core/matrix/coo.hpp>
1413
#include <ginkgo/core/matrix/csr.hpp>
1514
#include <ginkgo/core/matrix/diagonal.hpp>
16-
#include <ginkgo/core/matrix/ell.hpp>
17-
#include <ginkgo/core/matrix/hybrid.hpp>
1815
#include <ginkgo/core/matrix/sellp.hpp>
1916
#include <ginkgo/core/matrix/sparsity_csr.hpp>
2017

@@ -408,19 +405,19 @@ template <typename ValueType, typename IndexType>
408405
void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec,
409406
matrix::view::dense<const ValueType> source,
410407
const int64* coo_row_ptrs,
411-
matrix::Hybrid<ValueType, IndexType>* result)
408+
matrix::view::hybrid<ValueType, IndexType> result)
412409
{
413-
const auto num_rows = result->get_size()[0];
414-
const auto num_cols = result->get_size()[1];
415-
const auto ell_lim = result->get_ell_num_stored_elements_per_row();
410+
const auto num_rows = result.size[0];
411+
const auto num_cols = result.size[1];
412+
const auto ell_lim = result.ell_part.num_stored_elements_per_row;
416413
const auto in_vals = as_device_type(source.values);
417414
const auto in_stride = source.stride;
418-
const auto ell_stride = result->get_ell_stride();
419-
auto ell_cols = result->get_ell_col_idxs();
420-
auto ell_vals = as_device_type(result->get_ell_values());
421-
auto coo_rows = result->get_coo_row_idxs();
422-
auto coo_cols = result->get_coo_col_idxs();
423-
auto coo_vals = as_device_type(result->get_coo_values());
415+
const auto ell_stride = result.ell_part.stride;
416+
auto ell_cols = result.ell_part.col_idxs;
417+
auto ell_vals = as_device_type(result.ell_part.values);
418+
auto coo_rows = result.coo_part.row_idxs;
419+
auto coo_cols = result.coo_part.col_idxs;
420+
auto coo_vals = as_device_type(result.coo_part.values);
424421

425422
exec->get_queue()->submit([&](sycl::handler& cgh) {
426423
cgh.parallel_for(num_rows, [=](sycl::item<1> item) {

0 commit comments

Comments
 (0)