Skip to content

Commit e22d736

Browse files
yasahi-hpcYuuichi Asahi
andauthored
Add batched serial gbtrf/gbtrs example (kokkos#2664)
Signed-off-by: Yuuichi Asahi <[email protected]> Co-authored-by: Yuuichi Asahi <[email protected]>
1 parent 45e68ce commit e22d736

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

example/batched_solve/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,8 @@ KOKKOSKERNELS_ADD_EXECUTABLE(
2020
serial_pbtrs
2121
SOURCES serial_pbtrs.cpp
2222
)
23+
24+
# Registers the serial_gbtrs example target, whose only source is
# serial_gbtrs.cpp (the batched serial gbtrf/gbtrs example).
KOKKOSKERNELS_ADD_EXECUTABLE(
25+
serial_gbtrs
26+
SOURCES serial_gbtrs.cpp
27+
)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#include <iostream>

#include <Kokkos_Core.hpp>
#include <Kokkos_Random.hpp>
#include <KokkosBatched_Gbtrf.hpp>
#include <KokkosBatched_Gbtrs.hpp>
21+
22+
using ExecutionSpace = Kokkos::DefaultExecutionSpace;
23+
24+
/// \brief Example of batched gbtrf/gbtrs
25+
/// Solving A * x = b, where
26+
/// A: [[1, -3, -2, 0],
27+
/// [-1, 1, -3, -2],
28+
/// [2, -1, 1, -3],
29+
/// [0, 2, -1, 1]]
30+
/// b: [1, 1, 1, 1]
31+
/// x: [67/81, 22/81, -40/81, -1/27]
32+
///
33+
/// In band storage,
34+
/// Ab: [[0, 0, 0, 0],
35+
/// [0, 0, 0, 0],
36+
/// [0, 0, -2, -2],
37+
/// [0, -3, -3, -3],
38+
/// [1, 1, 1, 1],
39+
/// [-1, -1, -1, 0],
40+
/// [2, 2, 0, 0]]
41+
///
42+
/// This corresponds to the following system of equations:
43+
/// x0 - 3 x1 - 2 x2 = 1
44+
/// - x0 + x1 - 3 x2 - 2 x3 = 1
45+
/// 2 x0 - x1 + x2 - 3 x3 = 1
46+
/// 2 x1 - x2 + x3 = 1
47+
///
48+
int main(int /*argc*/, char** /*argv*/) {
49+
Kokkos::initialize();
50+
{
51+
using View2DType = Kokkos::View<double**, ExecutionSpace>;
52+
using View3DType = Kokkos::View<double***, ExecutionSpace>;
53+
using PivViewType = Kokkos::View<int**, ExecutionSpace>;
54+
const int Nb = 10, n = 4, kl = 2, ku = 2;
55+
const int ldab = 2 * kl + ku + 1;
56+
57+
// Matrix Ab in band storage
58+
View3DType Ab("Ab", Nb, ldab, n);
59+
60+
// Solution
61+
View2DType x("x", Nb, n);
62+
63+
// Pivot
64+
PivViewType ipiv("ipiv", Nb, n);
65+
66+
// Initialize Ab and x
67+
Kokkos::deep_copy(x, 1.0);
68+
auto h_Ab = Kokkos::create_mirror_view(Ab);
69+
70+
// Upper triangular matrix
71+
for (int ib = 0; ib < Nb; ib++) {
72+
// Fill the band matrix
73+
h_Ab(ib, 2, 2) = -2;
74+
h_Ab(ib, 2, 3) = -2;
75+
h_Ab(ib, 3, 1) = -3;
76+
h_Ab(ib, 3, 2) = -3;
77+
h_Ab(ib, 3, 3) = -3;
78+
h_Ab(ib, 4, 0) = 1;
79+
h_Ab(ib, 4, 1) = 1;
80+
h_Ab(ib, 4, 2) = 1;
81+
h_Ab(ib, 4, 3) = 1;
82+
h_Ab(ib, 5, 0) = -1;
83+
h_Ab(ib, 5, 1) = -1;
84+
h_Ab(ib, 5, 2) = -1;
85+
h_Ab(ib, 6, 0) = 2;
86+
h_Ab(ib, 6, 1) = 2;
87+
}
88+
Kokkos::deep_copy(Ab, h_Ab);
89+
90+
// solve A * x = b with gbtrf and gbtrs
91+
ExecutionSpace exec;
92+
using policy_type = Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<int>>;
93+
policy_type policy{exec, 0, Nb};
94+
Kokkos::parallel_for(
95+
"gbtrf-gbtrs", policy, KOKKOS_LAMBDA(int ib) {
96+
auto sub_Ab = Kokkos::subview(Ab, ib, Kokkos::ALL, Kokkos::ALL);
97+
auto sub_ipiv = Kokkos::subview(ipiv, ib, Kokkos::ALL);
98+
auto sub_x = Kokkos::subview(x, ib, Kokkos::ALL);
99+
100+
// Factorize Ab by gbtrf
101+
KokkosBatched::SerialGbtrf<KokkosBatched::Algo::Gbtrf::Unblocked>::invoke(sub_Ab, sub_ipiv, kl, ku);
102+
103+
// Solve A * x = b with gbtrs
104+
KokkosBatched::SerialGbtrs<KokkosBatched::Trans::NoTranspose, KokkosBatched::Algo::Gbtrs::Unblocked>::invoke(
105+
sub_Ab, sub_ipiv, sub_x, kl, ku);
106+
});
107+
108+
// Confirm that the results are correct
109+
auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x);
110+
bool correct = true;
111+
double eps = 1.0e-12;
112+
for (int ib = 0; ib < Nb; ib++) {
113+
if (Kokkos::abs(h_x(ib, 0) - 67.0 / 81.0) > eps) correct = false;
114+
if (Kokkos::abs(h_x(ib, 1) - 22.0 / 81.0) > eps) correct = false;
115+
if (Kokkos::abs(h_x(ib, 2) + 40.0 / 81.0) > eps) correct = false;
116+
if (Kokkos::abs(h_x(ib, 3) + 1.0 / 27.0) > eps) correct = false;
117+
}
118+
119+
if (correct) {
120+
std::cout << "gbtrf/gbtrs works correctly!" << std::endl;
121+
}
122+
}
123+
Kokkos::finalize();
124+
}

0 commit comments

Comments
 (0)