Skip to content

Commit f70c7b6

Browse files
yhmtsaiMarcelKoch
andcommitted
move traits due to requirement of complete type, simplify the traits, revert the apply due to real apply on complex vector
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
1 parent 69a30a4 commit f70c7b6

5 files changed

Lines changed: 143 additions & 111 deletions

File tree

core/config/preconditioner_ic_config.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,30 @@ namespace gko {
1717
namespace config {
1818

1919

20-
GKO_PARSE_VALUE_AND_INDEX_TYPE(Ic, gko::preconditioner::Ic);
20+
template <>
21+
deferred_factory_parameter<gko::LinOpFactory>
22+
parse<gko::config::LinOpFactoryType::Ic>(const gko::config::pnode& config,
23+
const gko::config::registry& context,
24+
const gko::config::type_descriptor& td)
25+
{
26+
auto updated = gko::config::update_type(config, td);
27+
if (config.get("l_solver_type_or_value_type")) {
28+
GKO_INVALID_STATE(
29+
"preconditioner::Ic only allows value_type from "
30+
"l_solver_type_or_value_type. To avoid type confusion between "
31+
"these types and value_type, l_solver_type_or_value_type uses "
32+
"the value_type directly.");
33+
}
34+
return gko::config::dispatch<gko::LinOpFactory, gko::preconditioner::Ic>(
35+
config, context, updated,
36+
gko::config::make_type_selector(updated.get_value_typestr(),
37+
gko::config::value_type_list()),
38+
gko::config::make_type_selector(updated.get_index_typestr(),
39+
gko::config::index_type_list()));
40+
}
41+
static_assert(true,
42+
"This assert is used to counter the false positive extra "
43+
"semi-colon warnings");
2144

2245

2346
} // namespace config

core/preconditioner/ic.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,19 @@ namespace preconditioner {
1818
namespace detail {
1919

2020

21-
template <typename Ic,
22-
std::enable_if_t<support_ic_parse<
23-
typename gko::detail::get_first_template<Ic>::type>>* = nullptr>
21+
template <typename Ic, std::enable_if_t<support_ic_parse<Ic>>* = nullptr>
2422
typename Ic::parameters_type ic_parse(
2523
const config::pnode& config, const config::registry& context,
2624
const config::type_descriptor& td_for_child)
2725
{
2826
auto params = Ic::build();
29-
27+
using l_solver_type = typename Ic::l_solver_type;
28+
static_assert(std::is_same_v<l_solver_type, LinOp>,
29+
"only support IC parse when l_solver_type is LinOp.");
3030
if (auto& obj = config.get("l_solver")) {
31-
if constexpr (std::is_same_v<typename Ic::l_solver_type, LinOp>) {
32-
params.with_l_solver(
33-
gko::config::parse_or_get_factory<const LinOpFactory>(
34-
obj, context, td_for_child));
35-
} else {
36-
params.with_l_solver(gko::config::parse_or_get_specific_factory<
37-
const typename Ic::l_solver_type>(
31+
params.with_l_solver(
32+
gko::config::parse_or_get_factory<const LinOpFactory>(
3833
obj, context, td_for_child));
39-
}
4034
}
4135
if (auto& obj = config.get("factorization")) {
4236
params.with_factorization(
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_
6+
#define GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_
7+
8+
#include <type_traits>
9+
10+
#include <ginkgo/core/base/lin_op.hpp>
11+
12+
namespace gko {
13+
namespace detail {
14+
15+
16+
template <typename Type>
17+
constexpr bool is_ginkgo_linop = std::is_base_of_v<LinOp, Type>;
18+
19+
20+
// helper to get factory type of concrete type or LinOp
21+
template <typename Type>
22+
struct factory_type_impl {
23+
using type = typename Type::Factory;
24+
};
25+
26+
// It requires LinOp to be complete type
27+
template <>
28+
struct factory_type_impl<LinOp> {
29+
using type = LinOpFactory;
30+
};
31+
32+
33+
template <typename Type>
34+
using factory_type = typename factory_type_impl<Type>::type;
35+
36+
37+
template <typename Type>
38+
struct get_solver_type_impl {
39+
using type = std::conditional_t<is_ginkgo_linop<Type>, Type, LinOp>;
40+
};
41+
42+
template <typename Type>
43+
using get_solver_type = typename get_solver_type_impl<Type>::type;
44+
45+
46+
// helper for handle the transposed type of concrete type and LinOp
47+
template <typename Type>
48+
struct transposed_type_impl {
49+
using type = typename Type::transposed_type;
50+
};
51+
52+
// It requires LinOp to be complete type
53+
template <>
54+
struct transposed_type_impl<LinOp> {
55+
using type = LinOp;
56+
};
57+
58+
59+
template <typename Type>
60+
using transposed_type = typename transposed_type_impl<Type>::type;
61+
62+
63+
// helper to get value_type of concrete type or void for LinOp
64+
template <typename Type, typename = void>
65+
struct get_value_type_impl {
66+
using type = typename Type::value_type;
67+
};
68+
69+
// We need to use SFINAE not conditional_t because both type needs to be
70+
// valid in conditional_t
71+
template <typename Type>
72+
struct get_value_type_impl<Type, std::enable_if_t<!is_ginkgo_linop<Type>>> {
73+
using type = Type;
74+
};
75+
76+
77+
template <typename Type>
78+
using get_value_type = typename get_value_type_impl<Type>::type;
79+
80+
81+
} // namespace detail
82+
} // namespace gko
83+
84+
#endif // GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_

include/ginkgo/core/base/utils_helper.hpp

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -152,79 +152,6 @@ template <typename Pointer>
152152
using shared_type = std::shared_ptr<pointee<Pointer>>;
153153

154154

155-
// helper for handle the transposed type of concrete type and LinOp
156-
template <typename Type>
157-
struct transposed_type_impl {
158-
using type = typename Type::transposed_type;
159-
};
160-
161-
template <>
162-
struct transposed_type_impl<LinOp> {
163-
using type = LinOp;
164-
};
165-
166-
167-
template <typename Type>
168-
using transposed_type = typename transposed_type_impl<Type>::type;
169-
170-
171-
// helper to get factory type of concrete type or LinOp
172-
template <typename Type>
173-
struct factory_type_impl {
174-
using type = typename Type::Factory;
175-
};
176-
177-
template <>
178-
struct factory_type_impl<LinOp> {
179-
using type = LinOpFactory;
180-
};
181-
182-
183-
template <typename Type>
184-
using factory_type = typename factory_type_impl<Type>::type;
185-
186-
template <typename Type>
187-
constexpr bool is_ginkgo_linop = std::is_convertible_v<Type*, LinOp*>;
188-
189-
template <typename Type>
190-
struct get_solver_type_impl {
191-
using type = std::conditional_t<is_ginkgo_linop<Type>, Type, LinOp>;
192-
};
193-
194-
template <typename Type>
195-
using get_solver_type = typename get_solver_type_impl<Type>::type;
196-
197-
198-
// helper to get value_type of concrete type or void for LinOp
199-
template <typename Type, typename = void>
200-
struct get_value_type_impl {
201-
using type = typename Type::value_type;
202-
};
203-
204-
// We need to use SFINAE not conditional_t because both type needs to be valid
205-
// in conditional_t
206-
template <typename Type>
207-
struct get_value_type_impl<Type, std::enable_if_t<!is_ginkgo_linop<Type>>> {
208-
using type = Type;
209-
};
210-
211-
212-
template <typename Type>
213-
using get_value_type = typename get_value_type_impl<Type>::type;
214-
215-
216-
// get_first_template is to get the first template argument of class.
217-
// It can be easily done by introducing another member type of IC to alias the
218-
// first template argument, but it introduces another public interface.
219-
template <class>
220-
struct get_first_template {};
221-
222-
template <template <typename...> class Base, class First, class... Rest>
223-
struct get_first_template<Base<First, Rest...>> {
224-
using type = First;
225-
};
226-
227-
228155
} // namespace detail
229156

230157

include/ginkgo/core/preconditioner/ic.hpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <ginkgo/core/base/exception_helpers.hpp>
1616
#include <ginkgo/core/base/lin_op.hpp>
1717
#include <ginkgo/core/base/precision_dispatch.hpp>
18-
#include <ginkgo/core/base/utils_helper.hpp>
18+
#include <ginkgo/core/base/type_traits.hpp>
1919
#include <ginkgo/core/config/config.hpp>
2020
#include <ginkgo/core/config/registry.hpp>
2121
#include <ginkgo/core/factorization/par_ic.hpp>
@@ -32,14 +32,12 @@ namespace preconditioner {
3232
namespace detail {
3333

3434

35-
template <typename SolverTypeOrValueType>
35+
template <typename Type>
3636
constexpr bool support_ic_parse =
37-
std::is_same_v<gko::detail::get_solver_type<SolverTypeOrValueType>, LinOp>;
37+
std::is_same_v<typename Type::l_solver_type, LinOp>;
3838

3939

40-
template <typename Ic,
41-
std::enable_if_t<!support_ic_parse<
42-
typename gko::detail::get_first_template<Ic>::type>>* = nullptr>
40+
template <typename Ic, std::enable_if_t<!support_ic_parse<Ic>>* = nullptr>
4341
typename Ic::parameters_type ic_parse(
4442
const config::pnode& config, const config::registry& context,
4543
const config::type_descriptor& td_for_child)
@@ -48,9 +46,7 @@ typename Ic::parameters_type ic_parse(
4846
"preconditioner::Ic only supports limited type for parse.");
4947
}
5048

51-
template <typename Ic,
52-
std::enable_if_t<support_ic_parse<
53-
typename gko::detail::get_first_template<Ic>::type>>* = nullptr>
49+
template <typename Ic, std::enable_if_t<support_ic_parse<Ic>>* = nullptr>
5450
typename Ic::parameters_type ic_parse(
5551
const config::pnode& config, const config::registry& context,
5652
const config::type_descriptor& td_for_child);
@@ -116,7 +112,9 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
116112
friend class EnablePolymorphicObject<Ic, LinOp>;
117113

118114
public:
119-
using l_solver_type = gko::detail::get_solver_type<LSolverTypeOrValueType>;
115+
using l_solver_type =
116+
std::conditional_t<gko::detail::is_ginkgo_linop<LSolverTypeOrValueType>,
117+
LSolverTypeOrValueType, LinOp>;
120118
static_assert(std::is_same<gko::detail::transposed_type<
121119
gko::detail::transposed_type<l_solver_type>>,
122120
l_solver_type>::value,
@@ -334,26 +332,30 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
334332
protected:
335333
void apply_impl(const LinOp* b, LinOp* x) const override
336334
{
337-
this->set_cache_to(b);
338-
if (l_solver_->apply_uses_initial_guess()) {
339-
cache_.intermediate->copy_from(b);
340-
}
341-
l_solver_->apply(b, cache_.intermediate);
342-
if (lh_solver_->apply_uses_initial_guess()) {
343-
x->copy_from(cache_.intermediate);
344-
}
345-
lh_solver_->apply(cache_.intermediate, x);
335+
// take care of real-to-complex apply
336+
precision_dispatch_real_complex<value_type>(
337+
[&](auto dense_b, auto dense_x) {
338+
this->set_cache_to(dense_b);
339+
l_solver_->apply(dense_b, cache_.intermediate);
340+
if (lh_solver_->apply_uses_initial_guess()) {
341+
dense_x->copy_from(cache_.intermediate);
342+
}
343+
lh_solver_->apply(cache_.intermediate, dense_x);
344+
},
345+
b, x);
346346
}
347347

348348
void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
349349
LinOp* x) const override
350350
{
351-
this->set_cache_to(b);
352-
if (l_solver_->apply_uses_initial_guess()) {
353-
cache_.intermediate->copy_from(b);
354-
}
355-
l_solver_->apply(b, cache_.intermediate);
356-
lh_solver_->apply(alpha, cache_.intermediate, beta, x);
351+
precision_dispatch_real_complex<value_type>(
352+
[&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
353+
this->set_cache_to(dense_b);
354+
l_solver_->apply(dense_b, cache_.intermediate);
355+
lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
356+
dense_x);
357+
},
358+
alpha, b, beta, x);
357359
}
358360

359361
explicit Ic(std::shared_ptr<const Executor> exec)
@@ -432,6 +434,8 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
432434
cache_.intermediate =
433435
matrix::Dense<value_type>::create(this->get_executor());
434436
}
437+
// Use b as the initial guess for the first triangular solve
438+
cache_.intermediate->copy_from(b);
435439
}
436440

437441
/**

0 commit comments

Comments
 (0)