|
| 1 | +// SPDX-FileCopyrightText: 2025 The Ginkgo authors |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: BSD-3-Clause |
| 4 | + |
| 5 | +#include "core/solver/chebyshev_kernels.hpp" |
| 6 | + |
| 7 | +#include <type_traits> |
| 8 | + |
| 9 | +#include <ginkgo/core/base/std_extensions.hpp> |
| 10 | +#include <ginkgo/core/matrix/dense.hpp> |
| 11 | +#include <ginkgo/core/solver/chebyshev.hpp> |
| 12 | + |
| 13 | +#include "common/unified/base/kernel_launch.hpp" |
| 14 | + |
| 15 | + |
| 16 | +namespace gko { |
| 17 | +namespace kernels { |
| 18 | +namespace GKO_DEVICE_NAMESPACE { |
| 19 | +namespace chebyshev { |
| 20 | + |
| 21 | + |
| 22 | +#if GINKGO_DPCPP_SINGLE_MODE |
| 23 | + |
| 24 | + |
| 25 | +// we only change type in device code to keep the interface is the same as the |
| 26 | +// other backend. |
| 27 | +template <typename coeff_type> |
| 28 | +using if_single_only_type = |
| 29 | + std::conditional_t<std::is_same_v<coeff_type, double>, float, |
| 30 | + std::complex<float>>; |
| 31 | + |
| 32 | + |
| 33 | +#else |
| 34 | + |
| 35 | + |
| 36 | +template <typename coeff_type> |
| 37 | +using if_single_only_type = xstd::type_identity_t<coeff_type>; |
| 38 | + |
| 39 | + |
| 40 | +#endif |
| 41 | + |
| 42 | + |
| 43 | +template <typename ValueType> |
| 44 | +void init_update(std::shared_ptr<const DefaultExecutor> exec, |
| 45 | + const solver::detail::coeff_type<ValueType> alpha, |
| 46 | + const matrix::Dense<ValueType>* inner_sol, |
| 47 | + matrix::Dense<ValueType>* update_sol, |
| 48 | + matrix::Dense<ValueType>* output) |
| 49 | +{ |
| 50 | + using coeff_type = |
| 51 | + if_single_only_type<solver::detail::coeff_type<ValueType>>; |
| 52 | + // the coeff_type always be the highest precision, so we need |
| 53 | + // to cast the others from ValueType to this precision. |
| 54 | + using arithmetic_type = device_type<coeff_type>; |
| 55 | + |
| 56 | + auto alpha_val = static_cast<coeff_type>(alpha); |
| 57 | + |
| 58 | + run_kernel( |
| 59 | + exec, |
| 60 | + [] GKO_KERNEL(auto row, auto col, auto alpha, auto inner_sol, |
| 61 | + auto update_sol, auto output) { |
| 62 | + const auto inner_val = |
| 63 | + static_cast<arithmetic_type>(inner_sol(row, col)); |
| 64 | + update_sol(row, col) = |
| 65 | + static_cast<device_type<ValueType>>(inner_val); |
| 66 | + output(row, col) = static_cast<device_type<ValueType>>( |
| 67 | + static_cast<arithmetic_type>(output(row, col)) + |
| 68 | + alpha * inner_val); |
| 69 | + }, |
| 70 | + output->get_size(), alpha_val, inner_sol, update_sol, output); |
| 71 | +} |
| 72 | + |
| 73 | +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL); |
| 74 | + |
| 75 | + |
| 76 | +template <typename ValueType> |
| 77 | +void update(std::shared_ptr<const DefaultExecutor> exec, |
| 78 | + const solver::detail::coeff_type<ValueType> alpha, |
| 79 | + const solver::detail::coeff_type<ValueType> beta, |
| 80 | + matrix::Dense<ValueType>* inner_sol, |
| 81 | + matrix::Dense<ValueType>* update_sol, |
| 82 | + matrix::Dense<ValueType>* output) |
| 83 | +{ |
| 84 | + using coeff_type = |
| 85 | + if_single_only_type<solver::detail::coeff_type<ValueType>>; |
| 86 | + // the coeff_type always be the highest precision, so we need |
| 87 | + // to cast the others from ValueType to this precision. |
| 88 | + using arithmetic_type = device_type<coeff_type>; |
| 89 | + |
| 90 | + auto alpha_val = static_cast<coeff_type>(alpha); |
| 91 | + auto beta_val = static_cast<coeff_type>(beta); |
| 92 | + |
| 93 | + run_kernel( |
| 94 | + exec, |
| 95 | + [] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto inner_sol, |
| 96 | + auto update_sol, auto output) { |
| 97 | + const auto val = |
| 98 | + static_cast<arithmetic_type>(inner_sol(row, col)) + |
| 99 | + beta * static_cast<arithmetic_type>(update_sol(row, col)); |
| 100 | + inner_sol(row, col) = static_cast<device_type<ValueType>>(val); |
| 101 | + update_sol(row, col) = static_cast<device_type<ValueType>>(val); |
| 102 | + output(row, col) = static_cast<device_type<ValueType>>( |
| 103 | + static_cast<arithmetic_type>(output(row, col)) + alpha * val); |
| 104 | + }, |
| 105 | + output->get_size(), alpha_val, beta_val, inner_sol, update_sol, output); |
| 106 | +} |
| 107 | + |
| 108 | +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL); |
| 109 | + |
| 110 | + |
| 111 | +} // namespace chebyshev |
| 112 | +} // namespace GKO_DEVICE_NAMESPACE |
| 113 | +} // namespace kernels |
| 114 | +} // namespace gko |
0 commit comments