|
| 1 | +//@HEADER |
| 2 | +// ************************************************************************ |
| 3 | +// |
| 4 | +// Kokkos v. 4.0 |
| 5 | +// Copyright (2022) National Technology & Engineering |
| 6 | +// Solutions of Sandia, LLC (NTESS). |
| 7 | +// |
| 8 | +// Under the terms of Contract DE-NA0003525 with NTESS, |
| 9 | +// the U.S. Government retains certain rights in this software. |
| 10 | +// |
| 11 | +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. |
| 12 | +// See https://kokkos.org/LICENSE for license information. |
| 13 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 14 | +// |
| 15 | +//@HEADER |
| 16 | + |
| 17 | +#include <Kokkos_Core.hpp> |
| 18 | +#include <Kokkos_Random.hpp> |
| 19 | +#include <KokkosBatched_Gbtrf.hpp> |
| 20 | +#include <KokkosBatched_Gbtrs.hpp> |
| 21 | + |
| 22 | +using ExecutionSpace = Kokkos::DefaultExecutionSpace; |
| 23 | + |
| 24 | +/// \brief Example of batched gbtrf/gbtrs |
| 25 | +/// Solving A * x = b, where |
| 26 | +/// A: [[1, -3, -2, 0], |
| 27 | +/// [-1, 1, -3, -2], |
| 28 | +/// [2, -1, 1, -3], |
| 29 | +/// [0, 2, -1, 1]] |
| 30 | +/// b: [1, 1, 1, 1] |
| 31 | +/// x: [67/81, 22/81, -40/81, -1/27] |
| 32 | +/// |
| 33 | +/// In band storage, |
| 34 | +/// Ab: [[0, 0, 0, 0], |
| 35 | +/// [0, 0, 0, 0], |
| 36 | +/// [0, 0, -2, -2], |
| 37 | +/// [0, -3, -3, -3], |
| 38 | +/// [1, 1, 1, 1], |
| 39 | +/// [-1, -1, -1, 0], |
| 40 | +/// [2, 2, 0, 0]] |
| 41 | +/// |
| 42 | +/// This corresponds to the following system of equations: |
| 43 | +/// x0 - 3 x1 - 2 x2 = 1 |
| 44 | +/// - x0 + x1 - 3 x2 - 2 x3 = 1 |
| 45 | +/// 2 x0 - x1 + x3 - 3 x3 = 1 |
| 46 | +/// 2 x1 - x2 + x3 = 1 |
| 47 | +/// |
| 48 | +int main(int /*argc*/, char** /*argv*/) { |
| 49 | + Kokkos::initialize(); |
| 50 | + { |
| 51 | + using View2DType = Kokkos::View<double**, ExecutionSpace>; |
| 52 | + using View3DType = Kokkos::View<double***, ExecutionSpace>; |
| 53 | + using PivViewType = Kokkos::View<int**, ExecutionSpace>; |
| 54 | + const int Nb = 10, n = 4, kl = 2, ku = 2; |
| 55 | + const int ldab = 2 * kl + ku + 1; |
| 56 | + |
| 57 | + // Matrix Ab in band storage |
| 58 | + View3DType Ab("Ab", Nb, ldab, n); |
| 59 | + |
| 60 | + // Solution |
| 61 | + View2DType x("x", Nb, n); |
| 62 | + |
| 63 | + // Pivot |
| 64 | + PivViewType ipiv("ipiv", Nb, n); |
| 65 | + |
| 66 | + // Initialize Ab and x |
| 67 | + Kokkos::deep_copy(x, 1.0); |
| 68 | + auto h_Ab = Kokkos::create_mirror_view(Ab); |
| 69 | + |
| 70 | + // Upper triangular matrix |
| 71 | + for (int ib = 0; ib < Nb; ib++) { |
| 72 | + // Fill the band matrix |
| 73 | + h_Ab(ib, 2, 2) = -2; |
| 74 | + h_Ab(ib, 2, 3) = -2; |
| 75 | + h_Ab(ib, 3, 1) = -3; |
| 76 | + h_Ab(ib, 3, 2) = -3; |
| 77 | + h_Ab(ib, 3, 3) = -3; |
| 78 | + h_Ab(ib, 4, 0) = 1; |
| 79 | + h_Ab(ib, 4, 1) = 1; |
| 80 | + h_Ab(ib, 4, 2) = 1; |
| 81 | + h_Ab(ib, 4, 3) = 1; |
| 82 | + h_Ab(ib, 5, 0) = -1; |
| 83 | + h_Ab(ib, 5, 1) = -1; |
| 84 | + h_Ab(ib, 5, 2) = -1; |
| 85 | + h_Ab(ib, 6, 0) = 2; |
| 86 | + h_Ab(ib, 6, 1) = 2; |
| 87 | + } |
| 88 | + Kokkos::deep_copy(Ab, h_Ab); |
| 89 | + |
| 90 | + // solve A * x = b with gbtrf and gbtrs |
| 91 | + ExecutionSpace exec; |
| 92 | + using policy_type = Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<int>>; |
| 93 | + policy_type policy{exec, 0, Nb}; |
| 94 | + Kokkos::parallel_for( |
| 95 | + "gbtrf-gbtrs", policy, KOKKOS_LAMBDA(int ib) { |
| 96 | + auto sub_Ab = Kokkos::subview(Ab, ib, Kokkos::ALL, Kokkos::ALL); |
| 97 | + auto sub_ipiv = Kokkos::subview(ipiv, ib, Kokkos::ALL); |
| 98 | + auto sub_x = Kokkos::subview(x, ib, Kokkos::ALL); |
| 99 | + |
| 100 | + // Factorize Ab by gbtrf |
| 101 | + KokkosBatched::SerialGbtrf<KokkosBatched::Algo::Gbtrf::Unblocked>::invoke(sub_Ab, sub_ipiv, kl, ku); |
| 102 | + |
| 103 | + // Solve A * x = b with gbtrs |
| 104 | + KokkosBatched::SerialGbtrs<KokkosBatched::Trans::NoTranspose, KokkosBatched::Algo::Gbtrs::Unblocked>::invoke( |
| 105 | + sub_Ab, sub_ipiv, sub_x, kl, ku); |
| 106 | + }); |
| 107 | + |
| 108 | + // Confirm that the results are correct |
| 109 | + auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x); |
| 110 | + bool correct = true; |
| 111 | + double eps = 1.0e-12; |
| 112 | + for (int ib = 0; ib < Nb; ib++) { |
| 113 | + if (Kokkos::abs(h_x(ib, 0) - 67.0 / 81.0) > eps) correct = false; |
| 114 | + if (Kokkos::abs(h_x(ib, 1) - 22.0 / 81.0) > eps) correct = false; |
| 115 | + if (Kokkos::abs(h_x(ib, 2) + 40.0 / 81.0) > eps) correct = false; |
| 116 | + if (Kokkos::abs(h_x(ib, 3) + 1.0 / 27.0) > eps) correct = false; |
| 117 | + } |
| 118 | + |
| 119 | + if (correct) { |
| 120 | + std::cout << "gbtrf/gbtrs works correctly!" << std::endl; |
| 121 | + } |
| 122 | + } |
| 123 | + Kokkos::finalize(); |
| 124 | +} |
0 commit comments