Skip to content

Commit f0d75e4

Browse files
MarcelKochyhmtsai
andcommitted
add view::hybrid for kernels
Co-authored-by: Yu-Hsiang M. Tsai <yhmtsai@gmail.com>
1 parent 6a3abf8 commit f0d75e4

8 files changed

Lines changed: 183 additions & 66 deletions

File tree

common/unified/matrix/hybrid_kernels.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -51,7 +51,7 @@ template <typename ValueType, typename IndexType>
5151
void fill_in_matrix_data(std::shared_ptr<const DefaultExecutor> exec,
5252
const device_matrix_data<ValueType, IndexType>& data,
5353
const int64* row_ptrs, const int64* coo_row_ptrs,
54-
matrix::Hybrid<ValueType, IndexType>* result)
54+
matrix::view::hybrid<ValueType, IndexType> result)
5555
{
5656
using device_value_type = device_type<ValueType>;
5757
run_kernel(
@@ -83,25 +83,25 @@ void fill_in_matrix_data(std::shared_ptr<const DefaultExecutor> exec,
8383
},
8484
data.get_size()[0], row_ptrs, data.get_const_values(),
8585
data.get_const_row_idxs(), data.get_const_col_idxs(),
86-
result->get_ell_stride(), result->get_ell_num_stored_elements_per_row(),
87-
result->get_ell_col_idxs(), result->get_ell_values(), coo_row_ptrs,
88-
result->get_coo_row_idxs(), result->get_coo_col_idxs(),
89-
result->get_coo_values());
86+
result.ell_part.stride, result.ell_part.num_stored_elements_per_row,
87+
result.ell_part.col_idxs, result.ell_part.values, coo_row_ptrs,
88+
result.coo_part.row_idxs, result.coo_part.col_idxs,
89+
result.coo_part.values);
9090
}
9191

9292
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
9393
GKO_DECLARE_HYBRID_FILL_IN_MATRIX_DATA_KERNEL);
9494

9595

9696
template <typename ValueType, typename IndexType>
97-
void convert_to_csr(std::shared_ptr<const DefaultExecutor> exec,
98-
const matrix::Hybrid<ValueType, IndexType>* source,
99-
const IndexType* ell_row_ptrs,
100-
const IndexType* coo_row_ptrs,
101-
matrix::Csr<ValueType, IndexType>* result)
97+
void convert_to_csr(
98+
std::shared_ptr<const DefaultExecutor> exec,
99+
matrix::view::hybrid<const ValueType, const IndexType> source,
100+
const IndexType* ell_row_ptrs, const IndexType* coo_row_ptrs,
101+
matrix::Csr<ValueType, IndexType>* result)
102102
{
103-
const auto ell = source->get_ell();
104-
const auto coo = source->get_coo();
103+
const auto ell = source.ell_part;
104+
const auto coo = source.coo_part;
105105
// ELL is stored in column-major, so we swap row and column parameters
106106
run_kernel(
107107
exec,
@@ -117,18 +117,16 @@ void convert_to_csr(std::shared_ptr<const DefaultExecutor> exec,
117117
out_vals[out_idx] = in_vals[ell_idx];
118118
}
119119
},
120-
dim<2>{ell->get_num_stored_elements_per_row(), ell->get_size()[0]},
121-
static_cast<int64>(ell->get_stride()), ell->get_const_col_idxs(),
122-
ell->get_const_values(), ell_row_ptrs, coo_row_ptrs,
123-
result->get_col_idxs(), result->get_values());
120+
dim<2>{ell.num_stored_elements_per_row, ell.size[0]},
121+
static_cast<int64>(ell.stride), ell.col_idxs, ell.values, ell_row_ptrs,
122+
coo_row_ptrs, result->get_col_idxs(), result->get_values());
124123
run_kernel(
125124
exec,
126125
[] GKO_KERNEL(auto idx, auto ell_row_ptrs, auto coo_row_ptrs,
127126
auto out_row_ptrs) {
128127
out_row_ptrs[idx] = ell_row_ptrs[idx] + coo_row_ptrs[idx];
129128
},
130-
source->get_size()[0] + 1, ell_row_ptrs, coo_row_ptrs,
131-
result->get_row_ptrs());
129+
source.size[0] + 1, ell_row_ptrs, coo_row_ptrs, result->get_row_ptrs());
132130
run_kernel(
133131
exec,
134132
[] GKO_KERNEL(auto idx, auto in_rows, auto in_cols, auto in_vals,
@@ -145,9 +143,9 @@ void convert_to_csr(std::shared_ptr<const DefaultExecutor> exec,
145143
out_cols[out_idx] = col;
146144
out_vals[out_idx] = val;
147145
},
148-
coo->get_num_stored_elements(), coo->get_const_row_idxs(),
149-
coo->get_const_col_idxs(), coo->get_const_values(), ell_row_ptrs,
150-
coo_row_ptrs, result->get_col_idxs(), result->get_values());
146+
coo.num_stored_elements, coo.row_idxs, coo.col_idxs, coo.values,
147+
ell_row_ptrs, coo_row_ptrs, result->get_col_idxs(),
148+
result->get_values());
151149
}
152150

153151
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(

core/matrix/hybrid.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,23 @@ Hybrid<ValueType, IndexType>::Hybrid(std::shared_ptr<const Executor> exec,
113113
{}
114114

115115

116+
template <typename ValueType, typename IndexType>
117+
typename Hybrid<ValueType, IndexType>::device_view
118+
Hybrid<ValueType, IndexType>::get_device_view()
119+
{
120+
return device_view{ell_->get_device_view(), coo_->get_device_view()};
121+
}
122+
123+
124+
template <typename ValueType, typename IndexType>
125+
typename Hybrid<ValueType, IndexType>::const_device_view
126+
Hybrid<ValueType, IndexType>::get_const_device_view() const
127+
{
128+
return const_device_view{ell_->get_const_device_view(),
129+
coo_->get_const_device_view()};
130+
}
131+
132+
116133
template <typename ValueType, typename IndexType>
117134
std::unique_ptr<Hybrid<ValueType, IndexType>>
118135
Hybrid<ValueType, IndexType>::create(std::shared_ptr<const Executor> exec,
@@ -316,8 +333,8 @@ void Hybrid<ValueType, IndexType>::convert_to(
316333
tmp->values_.resize_and_reset(nnz);
317334
tmp->set_size(this->get_size());
318335
exec->run(hybrid::make_convert_to_csr(
319-
this, ell_row_ptrs.get_const_data(), coo_row_ptrs.get_const_data(),
320-
tmp.get()));
336+
this->get_const_device_view(), ell_row_ptrs.get_const_data(),
337+
coo_row_ptrs.get_const_data(), tmp.get()));
321338
}
322339
result->make_srow();
323340
}
@@ -367,9 +384,9 @@ void Hybrid<ValueType, IndexType>::read(const device_mat_data& data)
367384
coo_row_ptrs.get_data()));
368385
coo_nnz = get_element(coo_row_ptrs, num_rows);
369386
this->resize(data.get_size(), ell_max_nnz, coo_nnz);
370-
exec->run(
371-
hybrid::make_fill_in_matrix_data(*local_data, row_ptrs.get_const_data(),
372-
coo_row_ptrs.get_const_data(), this));
387+
exec->run(hybrid::make_fill_in_matrix_data(
388+
*local_data, row_ptrs.get_const_data(), coo_row_ptrs.get_const_data(),
389+
this->get_device_view()));
373390
}
374391

375392

core/matrix/hybrid_kernels.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
#ifndef GKO_CORE_MATRIX_HYBRID_KERNELS_HPP_
66
#define GKO_CORE_MATRIX_HYBRID_KERNELS_HPP_
77

88

9+
#include <ginkgo/core/matrix/csr.hpp>
910
#include <ginkgo/core/matrix/dense.hpp>
10-
#include <ginkgo/core/matrix/hybrid.hpp>
1111

1212
#include "core/base/kernel_declaration.hpp"
1313

@@ -30,14 +30,14 @@ namespace kernels {
3030
std::shared_ptr<const DefaultExecutor> exec, \
3131
const device_matrix_data<ValueType, IndexType>& data, \
3232
const int64* row_ptrs, const int64* coo_row_ptrs, \
33-
matrix::Hybrid<ValueType, IndexType>* result)
34-
35-
#define GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL(ValueType, IndexType) \
36-
void convert_to_csr(std::shared_ptr<const DefaultExecutor> exec, \
37-
const matrix::Hybrid<ValueType, IndexType>* source, \
38-
const IndexType* ell_row_ptrs, \
39-
const IndexType* coo_row_ptrs, \
40-
matrix::Csr<ValueType, IndexType>* result)
33+
matrix::view::hybrid<ValueType, IndexType> result)
34+
35+
#define GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL(ValueType, IndexType) \
36+
void convert_to_csr( \
37+
std::shared_ptr<const DefaultExecutor> exec, \
38+
matrix::view::hybrid<const ValueType, const IndexType> source, \
39+
const IndexType* ell_row_ptrs, const IndexType* coo_row_ptrs, \
40+
matrix::Csr<ValueType, IndexType>* result)
4141

4242

4343
#define GKO_DECLARE_ALL_AS_TEMPLATES \

core/test/matrix/hybrid.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -457,3 +457,59 @@ TYPED_TEST(Hybrid, GetCorrectAutomatic)
457457

458458
ASSERT_NO_THROW(gko::as<strategy2>(mtx->template get_strategy<Mtx2>()));
459459
}
460+
461+
462+
TYPED_TEST(Hybrid, CanCreateDeviceView)
463+
{
464+
auto view = this->mtx->get_device_view();
465+
466+
EXPECT_EQ(view.size, this->mtx->get_size());
467+
EXPECT_EQ(view.ell_part.num_stored_elements_per_row,
468+
this->mtx->get_ell_num_stored_elements_per_row());
469+
EXPECT_EQ(view.ell_part.stride, this->mtx->get_ell_stride());
470+
EXPECT_EQ(view.ell_part.values, this->mtx->get_ell_values());
471+
EXPECT_EQ(view.ell_part.col_idxs, this->mtx->get_ell_col_idxs());
472+
EXPECT_EQ(view.coo_part.num_stored_elements,
473+
this->mtx->get_coo_num_stored_elements());
474+
EXPECT_EQ(view.coo_part.row_idxs, this->mtx->get_coo_row_idxs());
475+
EXPECT_EQ(view.coo_part.col_idxs, this->mtx->get_coo_col_idxs());
476+
EXPECT_EQ(view.coo_part.values, this->mtx->get_coo_values());
477+
}
478+
479+
480+
TYPED_TEST(Hybrid, CanCreateConstDeviceView)
481+
{
482+
auto view = this->mtx->get_const_device_view();
483+
484+
EXPECT_EQ(view.size, this->mtx->get_size());
485+
EXPECT_EQ(view.ell_part.num_stored_elements_per_row,
486+
this->mtx->get_ell_num_stored_elements_per_row());
487+
EXPECT_EQ(view.ell_part.stride, this->mtx->get_ell_stride());
488+
EXPECT_EQ(view.ell_part.values, this->mtx->get_ell_values());
489+
EXPECT_EQ(view.ell_part.col_idxs, this->mtx->get_ell_col_idxs());
490+
EXPECT_EQ(view.coo_part.num_stored_elements,
491+
this->mtx->get_coo_num_stored_elements());
492+
EXPECT_EQ(view.coo_part.row_idxs, this->mtx->get_coo_row_idxs());
493+
EXPECT_EQ(view.coo_part.col_idxs, this->mtx->get_coo_col_idxs());
494+
EXPECT_EQ(view.coo_part.values, this->mtx->get_coo_values());
495+
}
496+
497+
498+
TEST(HybridView, CreateFailsWithNonMatchingSizes)
499+
{
500+
#ifdef NDEBUG
501+
GTEST_SKIP() << "Assertion is only enabled in debug mode";
502+
#endif
503+
504+
using value_type = double;
505+
using index_type = int;
506+
auto exec = gko::ReferenceExecutor::create();
507+
auto coo = gko::matrix::Coo<value_type, index_type>::create(
508+
exec, gko::dim<2>(1, 2));
509+
auto ell = gko::matrix::Ell<value_type, index_type>::create(
510+
exec, gko::dim<2>(2, 1));
511+
512+
using view_t = gko::matrix::view::hybrid<value_type, index_type>;
513+
EXPECT_EXIT(view_t(ell->get_device_view(), coo->get_device_view()),
514+
check_assertion_exit_code, "");
515+
}

include/ginkgo/core/matrix/device_views.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,36 @@ struct ell {
144144
};
145145

146146

147+
/**
148+
* Non-owning view of a matrix::Hybrid to be used inside device kernels.
149+
* This type is used to provide a simple and stable ABI for passing data between
150+
* libraries.
151+
*
152+
* @tparam ValueType the value type used to store matrix values.
153+
* @tparam IndexType the index type used to store matrix columns.
154+
*/
155+
template <typename ValueType, typename IndexType>
156+
struct hybrid {
157+
dim<2> size;
158+
ell<ValueType, IndexType> ell_part;
159+
coo<ValueType, IndexType> coo_part;
160+
161+
/** Constructs a hybrid view */
162+
constexpr hybrid(ell<ValueType, IndexType> ell_,
163+
coo<ValueType, IndexType> coo_)
164+
: size(ell_.size), ell_part(ell_), coo_part(coo_)
165+
{
166+
assert(ell_part.size == coo_part.size);
167+
}
168+
169+
/** Returns a const view of the same values */
170+
constexpr hybrid<const ValueType, const IndexType> as_const() const
171+
{
172+
return {ell_part.as_const(), coo_part.as_const()};
173+
}
174+
};
175+
176+
147177
} // namespace view
148178
} // namespace matrix
149179
} // namespace gko

include/ginkgo/core/matrix/hybrid.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -77,6 +77,8 @@ class Hybrid
7777
using index_type = IndexType;
7878
using mat_data = matrix_data<ValueType, IndexType>;
7979
using device_mat_data = device_matrix_data<ValueType, IndexType>;
80+
using device_view = view::hybrid<value_type, index_type>;
81+
using const_device_view = view::hybrid<const value_type, const index_type>;
8082
using coo_type = Coo<ValueType, IndexType>;
8183
using ell_type = Ell<ValueType, IndexType>;
8284
using absolute_type = remove_complex<Hybrid>;
@@ -642,6 +644,20 @@ class Hybrid
642644
template <typename HybType>
643645
std::shared_ptr<typename HybType::strategy_type> get_strategy() const;
644646

647+
/**
648+
* Returns a non-owning device view of this matrix.
649+
*
650+
* @return a device view of this matrix.
651+
*/
652+
device_view get_device_view();
653+
654+
/**
655+
* Returns a non-owning const device view of this matrix.
656+
*
657+
* @return a const device view of this matrix.
658+
*/
659+
const_device_view get_const_device_view() const;
660+
645661
/**
646662
* Creates an uninitialized Hybrid matrix of specified method.
647663
* (ell_num_stored_elements_per_row is set to the number of cols of the

0 commit comments

Comments
 (0)