1515#include < ginkgo/core/base/exception_helpers.hpp>
1616#include < ginkgo/core/base/lin_op.hpp>
1717#include < ginkgo/core/base/precision_dispatch.hpp>
18- #include < ginkgo/core/base/utils_helper .hpp>
18+ #include < ginkgo/core/base/type_traits .hpp>
1919#include < ginkgo/core/config/config.hpp>
2020#include < ginkgo/core/config/registry.hpp>
2121#include < ginkgo/core/factorization/par_ic.hpp>
@@ -32,14 +32,12 @@ namespace preconditioner {
3232namespace detail {
3333
3434
35- template <typename SolverTypeOrValueType >
35+ template <typename Type >
3636constexpr bool support_ic_parse =
37- std::is_same_v<gko::detail::get_solver_type<SolverTypeOrValueType> , LinOp>;
37+ std::is_same_v<typename Type::l_solver_type , LinOp>;
3838
3939
40- template <typename Ic,
41- std::enable_if_t <!support_ic_parse<
42- typename gko::detail::get_first_template<Ic>::type>>* = nullptr >
40+ template <typename Ic, std::enable_if_t <!support_ic_parse<Ic>>* = nullptr >
4341typename Ic::parameters_type ic_parse (
4442 const config::pnode& config, const config::registry& context,
4543 const config::type_descriptor& td_for_child)
@@ -48,9 +46,7 @@ typename Ic::parameters_type ic_parse(
4846 " preconditioner::Ic only supports limited type for parse." );
4947}
5048
51- template <typename Ic,
52- std::enable_if_t <support_ic_parse<
53- typename gko::detail::get_first_template<Ic>::type>>* = nullptr >
49+ template <typename Ic, std::enable_if_t <support_ic_parse<Ic>>* = nullptr >
5450typename Ic::parameters_type ic_parse (
5551 const config::pnode& config, const config::registry& context,
5652 const config::type_descriptor& td_for_child);
@@ -116,7 +112,9 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
116112 friend class EnablePolymorphicObject <Ic, LinOp>;
117113
118114public:
119- using l_solver_type = gko::detail::get_solver_type<LSolverTypeOrValueType>;
115+ using l_solver_type =
116+ std::conditional_t <gko::detail::is_ginkgo_linop<LSolverTypeOrValueType>,
117+ LSolverTypeOrValueType, LinOp>;
120118 static_assert (std::is_same<gko::detail::transposed_type<
121119 gko::detail::transposed_type<l_solver_type>>,
122120 l_solver_type>::value,
@@ -334,26 +332,30 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
334332protected:
335333 void apply_impl (const LinOp* b, LinOp* x) const override
336334 {
337- this ->set_cache_to (b);
338- if (l_solver_->apply_uses_initial_guess ()) {
339- cache_.intermediate ->copy_from (b);
340- }
341- l_solver_->apply (b, cache_.intermediate );
342- if (lh_solver_->apply_uses_initial_guess ()) {
343- x->copy_from (cache_.intermediate );
344- }
345- lh_solver_->apply (cache_.intermediate , x);
335+ // take care of real-to-complex apply
336+ precision_dispatch_real_complex<value_type>(
337+ [&](auto dense_b, auto dense_x) {
338+ this ->set_cache_to (dense_b);
339+ l_solver_->apply (dense_b, cache_.intermediate );
340+ if (lh_solver_->apply_uses_initial_guess ()) {
341+ dense_x->copy_from (cache_.intermediate );
342+ }
343+ lh_solver_->apply (cache_.intermediate , dense_x);
344+ },
345+ b, x);
346346 }
347347
348348 void apply_impl (const LinOp* alpha, const LinOp* b, const LinOp* beta,
349349 LinOp* x) const override
350350 {
351- this ->set_cache_to (b);
352- if (l_solver_->apply_uses_initial_guess ()) {
353- cache_.intermediate ->copy_from (b);
354- }
355- l_solver_->apply (b, cache_.intermediate );
356- lh_solver_->apply (alpha, cache_.intermediate , beta, x);
351+ precision_dispatch_real_complex<value_type>(
352+ [&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
353+ this ->set_cache_to (dense_b);
354+ l_solver_->apply (dense_b, cache_.intermediate );
355+ lh_solver_->apply (dense_alpha, cache_.intermediate , dense_beta,
356+ dense_x);
357+ },
358+ alpha, b, beta, x);
357359 }
358360
359361 explicit Ic (std::shared_ptr<const Executor> exec)
@@ -432,6 +434,8 @@ class Ic : public EnableLinOp<Ic<LSolverTypeOrValueType, IndexType>>,
432434 cache_.intermediate =
433435 matrix::Dense<value_type>::create (this ->get_executor ());
434436 }
437+ // Use b as the initial guess for the first triangular solve
438+ cache_.intermediate ->copy_from (b);
435439 }
436440
437441 /* *
0 commit comments