Skip to content

Commit 8da182a

Browse files
authored
Merge #1289 Add chebyshev iteration
This PR adds chebyshev Iteration Related PR: #1289
2 parents e797f4a + 7904562 commit 8da182a

25 files changed

Lines changed: 1974 additions & 81 deletions

common/unified/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set(UNIFIED_SOURCES
2525
solver/bicgstab_kernels.cpp
2626
solver/cg_kernels.cpp
2727
solver/cgs_kernels.cpp
28+
solver/chebyshev_kernels.cpp
2829
solver/common_gmres_kernels.cpp
2930
solver/fcg_kernels.cpp
3031
solver/gcr_kernels.cpp
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "core/solver/chebyshev_kernels.hpp"
6+
7+
#include <type_traits>
8+
9+
#include <ginkgo/core/base/std_extensions.hpp>
10+
#include <ginkgo/core/matrix/dense.hpp>
11+
#include <ginkgo/core/solver/chebyshev.hpp>
12+
13+
#include "common/unified/base/kernel_launch.hpp"
14+
15+
16+
namespace gko {
17+
namespace kernels {
18+
namespace GKO_DEVICE_NAMESPACE {
19+
namespace chebyshev {
20+
21+
22+
#if GINKGO_DPCPP_SINGLE_MODE
23+
24+
25+
// we only change type in device code to keep the interface is the same as the
26+
// other backend.
27+
template <typename coeff_type>
28+
using if_single_only_type =
29+
std::conditional_t<std::is_same_v<coeff_type, double>, float,
30+
std::complex<float>>;
31+
32+
33+
#else
34+
35+
36+
template <typename coeff_type>
37+
using if_single_only_type = xstd::type_identity_t<coeff_type>;
38+
39+
40+
#endif
41+
42+
43+
template <typename ValueType>
44+
void init_update(std::shared_ptr<const DefaultExecutor> exec,
45+
const solver::detail::coeff_type<ValueType> alpha,
46+
const matrix::Dense<ValueType>* inner_sol,
47+
matrix::Dense<ValueType>* update_sol,
48+
matrix::Dense<ValueType>* output)
49+
{
50+
using coeff_type =
51+
if_single_only_type<solver::detail::coeff_type<ValueType>>;
52+
// the coeff_type always be the highest precision, so we need
53+
// to cast the others from ValueType to this precision.
54+
using arithmetic_type = device_type<coeff_type>;
55+
56+
auto alpha_val = static_cast<coeff_type>(alpha);
57+
58+
run_kernel(
59+
exec,
60+
[] GKO_KERNEL(auto row, auto col, auto alpha, auto inner_sol,
61+
auto update_sol, auto output) {
62+
const auto inner_val =
63+
static_cast<arithmetic_type>(inner_sol(row, col));
64+
update_sol(row, col) =
65+
static_cast<device_type<ValueType>>(inner_val);
66+
output(row, col) = static_cast<device_type<ValueType>>(
67+
static_cast<arithmetic_type>(output(row, col)) +
68+
alpha * inner_val);
69+
},
70+
output->get_size(), alpha_val, inner_sol, update_sol, output);
71+
}
72+
73+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
74+
75+
76+
template <typename ValueType>
77+
void update(std::shared_ptr<const DefaultExecutor> exec,
78+
const solver::detail::coeff_type<ValueType> alpha,
79+
const solver::detail::coeff_type<ValueType> beta,
80+
matrix::Dense<ValueType>* inner_sol,
81+
matrix::Dense<ValueType>* update_sol,
82+
matrix::Dense<ValueType>* output)
83+
{
84+
using coeff_type =
85+
if_single_only_type<solver::detail::coeff_type<ValueType>>;
86+
// the coeff_type always be the highest precision, so we need
87+
// to cast the others from ValueType to this precision.
88+
using arithmetic_type = device_type<coeff_type>;
89+
90+
auto alpha_val = static_cast<coeff_type>(alpha);
91+
auto beta_val = static_cast<coeff_type>(beta);
92+
93+
run_kernel(
94+
exec,
95+
[] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto inner_sol,
96+
auto update_sol, auto output) {
97+
const auto val =
98+
static_cast<arithmetic_type>(inner_sol(row, col)) +
99+
beta * static_cast<arithmetic_type>(update_sol(row, col));
100+
inner_sol(row, col) = static_cast<device_type<ValueType>>(val);
101+
update_sol(row, col) = static_cast<device_type<ValueType>>(val);
102+
output(row, col) = static_cast<device_type<ValueType>>(
103+
static_cast<arithmetic_type>(output(row, col)) + alpha * val);
104+
},
105+
output->get_size(), alpha_val, beta_val, inner_sol, update_sol, output);
106+
}
107+
108+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);
109+
110+
111+
} // namespace chebyshev
112+
} // namespace GKO_DEVICE_NAMESPACE
113+
} // namespace kernels
114+
} // namespace gko

core/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ target_sources(
9393
matrix/scaled_permutation.cpp
9494
matrix/sellp.cpp
9595
matrix/sparsity_csr.cpp
96-
multigrid/pgm.cpp
9796
multigrid/fixed_coarsening.cpp
97+
multigrid/pgm.cpp
9898
preconditioner/batch_jacobi.cpp
9999
preconditioner/gauss_seidel.cpp
100100
preconditioner/sor.cpp
@@ -113,6 +113,7 @@ target_sources(
113113
solver/cb_gmres.cpp
114114
solver/cg.cpp
115115
solver/cgs.cpp
116+
solver/chebyshev.cpp
116117
solver/direct.cpp
117118
solver/fcg.cpp
118119
solver/gcr.cpp

core/config/config_helper.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ namespace gko {
2424
namespace config {
2525

2626

27-
#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
28-
GKO_INVALID_STATE(std::string("The value >" + _value + \
29-
"< is invalid for the entry >" + _entry + \
30-
"<"))
27+
#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
28+
GKO_INVALID_STATE(std::string("The value >") + _value + \
29+
"< is invalid for the entry >" + _entry + "<")
3130

3231

3332
#define GKO_MISSING_CONFIG_ENTRY(_entry) \
@@ -53,6 +52,7 @@ enum class LinOpFactoryType : int {
5352
Direct,
5453
LowerTrs,
5554
UpperTrs,
55+
Chebyshev,
5656
Factorization_Ic,
5757
Factorization_Ilu,
5858
Cholesky,

core/config/registry.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ configuration_map generate_config_map()
3333
{"solver::Direct", parse<LinOpFactoryType::Direct>},
3434
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
3535
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>},
36+
{"solver::Chebyshev", parse<LinOpFactoryType::Chebyshev>},
3637
{"factorization::Ic", parse<LinOpFactoryType::Factorization_Ic>},
3738
{"factorization::Ilu", parse<LinOpFactoryType::Factorization_Ilu>},
3839
{"factorization::Cholesky", parse<LinOpFactoryType::Cholesky>},

core/config/solver_config.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ginkgo/core/solver/cb_gmres.hpp>
1313
#include <ginkgo/core/solver/cg.hpp>
1414
#include <ginkgo/core/solver/cgs.hpp>
15+
#include <ginkgo/core/solver/chebyshev.hpp>
1516
#include <ginkgo/core/solver/direct.hpp>
1617
#include <ginkgo/core/solver/fcg.hpp>
1718
#include <ginkgo/core/solver/gcr.hpp>
@@ -45,6 +46,7 @@ GKO_PARSE_VALUE_TYPE(Minres, gko::solver::Minres);
4546
GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct);
4647
GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs);
4748
GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs);
49+
GKO_PARSE_VALUE_TYPE(Chebyshev, gko::solver::Chebyshev);
4850

4951

5052
template <>

core/device_hooks/common_kernels.inc.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
#include "core/solver/cb_gmres_kernels.hpp"
6464
#include "core/solver/cg_kernels.hpp"
6565
#include "core/solver/cgs_kernels.hpp"
66+
#include "core/solver/chebyshev_kernels.hpp"
6667
#include "core/solver/common_gmres_kernels.hpp"
6768
#include "core/solver/fcg_kernels.hpp"
6869
#include "core/solver/gcr_kernels.hpp"
@@ -677,6 +678,16 @@ GKO_STUB_CB_GMRES_CONST(GKO_DECLARE_CB_GMRES_SOLVE_KRYLOV_KERNEL);
677678
} // namespace cb_gmres
678679

679680

681+
namespace chebyshev {
682+
683+
684+
GKO_STUB_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
685+
GKO_STUB_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);
686+
687+
688+
} // namespace chebyshev
689+
690+
680691
namespace ir {
681692

682693

0 commit comments

Comments
 (0)