Skip to content

Commit e56863a

Browse files
yasahi-hpc and Yuuichi Asahi authored
refactor serial pbtrf implementation details and tests (kokkos#2503)
Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
1 parent b8539d7 commit e56863a

File tree

7 files changed

+127
-309
lines changed

7 files changed

+127
-309
lines changed

batched/dense/impl/KokkosBatched_Pbtrf_Serial_Impl.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
/// \author Yuuichi Asahi ([email protected])
2323

2424
namespace KokkosBatched {
25-
25+
namespace Impl {
2626
template <typename ABViewType>
2727
KOKKOS_INLINE_FUNCTION static int checkPbtrfInput([[maybe_unused]] const ABViewType &Ab) {
2828
static_assert(Kokkos::is_view_v<ABViewType>, "KokkosBatched::pbtrf: ABViewType is not a Kokkos::View.");
@@ -41,6 +41,7 @@ KOKKOS_INLINE_FUNCTION static int checkPbtrfInput([[maybe_unused]] const ABViewT
4141
#endif
4242
return 0;
4343
}
44+
} // namespace Impl
4445

4546
//// Lower ////
4647
template <>
@@ -51,11 +52,12 @@ struct SerialPbtrf<Uplo::Lower, Algo::Pbtrf::Unblocked> {
5152
const int n = Ab.extent(1);
5253
if (n == 0) return 0;
5354

54-
auto info = checkPbtrfInput(Ab);
55+
auto info = Impl::checkPbtrfInput(Ab);
5556
if (info) return info;
5657

5758
const int kd = Ab.extent(0) - 1;
58-
return SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(), kd);
59+
return Impl::SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(),
60+
kd);
5961
}
6062
};
6163

@@ -68,11 +70,12 @@ struct SerialPbtrf<Uplo::Upper, Algo::Pbtrf::Unblocked> {
6870
const int n = Ab.extent(1);
6971
if (n == 0) return 0;
7072

71-
auto info = checkPbtrfInput(Ab);
73+
auto info = Impl::checkPbtrfInput(Ab);
7274
if (info) return info;
7375

7476
const int kd = Ab.extent(0) - 1;
75-
return SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(), kd);
77+
return Impl::SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(),
78+
kd);
7679
}
7780
};
7881

batched/dense/impl/KokkosBatched_Pbtrf_Serial_Internal.hpp

Lines changed: 24 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020
#include "KokkosBatched_Util.hpp"
2121
#include "KokkosBlas1_serial_scal_impl.hpp"
22+
#include "KokkosBatched_Syr_Serial_Internal.hpp"
23+
#include "KokkosBatched_Lacgv_Serial_Internal.hpp"
2224

2325
namespace KokkosBatched {
26+
namespace Impl {
2427

2528
///
2629
/// Serial Internal Impl
@@ -36,17 +39,8 @@ struct SerialPbtrfInternalLower {
3639
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
3740
/**/ ValueType *KOKKOS_RESTRICT AB, const int as0, const int as1,
3841
const int kd);
39-
40-
template <typename ValueType>
41-
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
42-
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0,
43-
const int as1, const int kd);
4442
};
4543

46-
///
47-
/// Real matrix
48-
///
49-
5044
template <>
5145
template <typename ValueType>
5246
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(const int an,
@@ -55,7 +49,7 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::inv
5549
const int kd) {
5650
// Compute the Cholesky factorization A = L*L'.
5751
for (int j = 0; j < an; ++j) {
58-
auto a_jj = AB[0 * as0 + j * as1];
52+
auto a_jj = Kokkos::ArithTraits<ValueType>::real(AB[0 * as0 + j * as1]);
5953

6054
// Check if L (j, j) is positive definite
6155
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
@@ -75,68 +69,13 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::inv
7569
const ValueType alpha = 1.0 / a_jj;
7670
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[1 * as0 + j * as1]), 1);
7771

78-
// syr (lower) with alpha = -1.0 to diagonal elements
79-
for (int k = 0; k < kn; ++k) {
80-
auto x_k = AB[(k + 1) * as0 + j * as1];
81-
if (x_k != 0) {
82-
auto temp = -1.0 * x_k;
83-
for (int i = k; i < kn; ++i) {
84-
auto x_i = AB[(i + 1) * as0 + j * as1];
85-
AB[i * as0 + (j + 1 + k - i) * as1] += x_i * temp;
86-
}
87-
}
88-
}
89-
}
90-
}
91-
92-
return 0;
93-
}
94-
95-
///
96-
/// Complex matrix
97-
///
98-
template <>
99-
template <typename ValueType>
100-
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(
101-
const int an,
102-
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0, const int as1, const int kd) {
103-
// Compute the Cholesky factorization A = L*L**H
104-
for (int j = 0; j < an; ++j) {
105-
auto a_jj = AB[0 * as0 + j * as1].real();
106-
107-
// Check if L (j, j) is positive definite
108-
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
109-
if (a_jj <= 0) {
110-
AB[0 * as0 + j * as1] = a_jj;
111-
return j + 1;
112-
}
113-
#endif
114-
115-
a_jj = Kokkos::sqrt(a_jj);
116-
AB[0 * as0 + j * as1] = a_jj;
117-
118-
// Compute elements J+1:J+KN of column J and update the
119-
// trailing submatrix within the band.
120-
int kn = Kokkos::min(kd, an - j - 1);
121-
if (kn > 0) {
122-
// scale to diagonal elements
123-
const ValueType alpha = 1.0 / a_jj;
124-
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[1 * as0 + j * as1]), 1);
125-
126-
// zher (lower) with alpha = -1.0 to diagonal elements
127-
for (int k = 0; k < kn; ++k) {
128-
auto x_k = AB[(k + 1) * as0 + j * as1];
129-
if (x_k != 0) {
130-
auto temp = -1.0 * Kokkos::conj(x_k);
131-
AB[k * as0 + (j + 1) * as1] = AB[k * as0 + (j + 1) * as1].real() + (temp * x_k).real();
132-
for (int i = k + 1; i < kn; ++i) {
133-
auto x_i = AB[(i + 1) * as0 + j * as1];
134-
AB[i * as0 + (j + 1 + k - i) * as1] += x_i * temp;
135-
}
136-
} else {
137-
AB[k * as0 + (j + 1) * as1] = AB[k * as0 + (j + 1) * as1].real();
138-
}
139-
}
72+
// syr or zher (lower) with alpha = -1.0 to diagonal elements
73+
using op = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj,
74+
KokkosBlas::Impl::OpID>;
75+
using op_sym = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpReal,
76+
KokkosBlas::Impl::OpID>;
77+
SerialSyrInternalLower::invoke(op(), op_sym(), kn, -1.0, &(AB[1 * as0 + j * as1]), as0,
78+
&(AB[0 * as0 + (j + 1) * as1]), as0, (as1 - as0));
14079
}
14180
}
14281

@@ -153,16 +92,8 @@ struct SerialPbtrfInternalUpper {
15392
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
15493
/**/ ValueType *KOKKOS_RESTRICT AB, const int as0, const int as1,
15594
const int kd);
156-
157-
template <typename ValueType>
158-
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
159-
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0,
160-
const int as1, const int kd);
16195
};
16296

163-
///
164-
/// Real matrix
165-
///
16697
template <>
16798
template <typename ValueType>
16899
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(const int an,
@@ -171,7 +102,7 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::inv
171102
const int kd) {
172103
// Compute the Cholesky factorization A = U'*U.
173104
for (int j = 0; j < an; ++j) {
174-
auto a_jj = AB[kd * as0 + j * as1];
105+
auto a_jj = Kokkos::ArithTraits<ValueType>::real(AB[kd * as0 + j * as1]);
175106

176107
// Check if U (j,j) is positive definite
177108
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
@@ -191,82 +122,26 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::inv
191122
const ValueType alpha = 1.0 / a_jj;
192123
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[(kd - 1) * as0 + (j + 1) * as1]), kld);
193124

194-
// syr (upper) with alpha = -1.0 to diagonal elements
195-
for (int k = 0; k < kn; ++k) {
196-
auto x_k = AB[(k + kd - 1) * as0 + (j + 1 - k) * as1];
197-
if (x_k != 0) {
198-
auto temp = -1.0 * x_k;
199-
for (int i = 0; i < k + 1; ++i) {
200-
auto x_i = AB[(i + kd - 1) * as0 + (j + 1 - i) * as1];
201-
AB[(kd + i) * as0 + (j + 1 + k - i) * as1] += x_i * temp;
202-
}
203-
}
204-
}
205-
}
206-
}
207-
208-
return 0;
209-
}
210-
211-
///
212-
/// Complex matrix
213-
///
214-
template <>
215-
template <typename ValueType>
216-
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(
217-
const int an,
218-
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0, const int as1, const int kd) {
219-
// Compute the Cholesky factorization A = U**H * U.
220-
for (int j = 0; j < an; ++j) {
221-
auto a_jj = AB[kd * as0 + j * as1].real();
222-
223-
// Check if U (j,j) is positive definite
224-
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
225-
if (a_jj <= 0) {
226-
AB[kd * as0 + j * as1] = a_jj;
227-
return j + 1;
228-
}
229-
#endif
230-
231-
a_jj = Kokkos::sqrt(a_jj);
232-
AB[kd * as0 + j * as1] = a_jj;
233-
234-
// Compute elements J+1:J+KN of row J and update the
235-
// trailing submatrix within the band.
236-
int kn = Kokkos::min(kd, an - j - 1);
237-
int kld = Kokkos::max(1, as0 - 1);
238-
if (kn > 0) {
239-
// scale to diagonal elements
240-
const ValueType alpha = 1.0 / a_jj;
241-
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[(kd - 1) * as0 + (j + 1) * as1]), kld);
242-
243-
// zlacgv to diagonal elements
244-
for (int i = 0; i < kn; ++i) {
245-
AB[(i + kd - 1) * as0 + (j + 1 - i) * as1] = Kokkos::conj(AB[(i + kd - 1) * as0 + (j + 1 - i) * as1]);
246-
}
125+
// zlacgv to diagonal elements (no op for real matrix)
126+
SerialLacgvInternal::invoke(kn, &(AB[(kd - 1) * as0 + (j + 1) * as1]), (as0 - as1));
247127

248-
// zher (upper) with alpha = -1.0 to diagonal elements
249-
for (int k = 0; k < kn; ++k) {
250-
auto x_k = AB[(k + kd - 1) * as0 + (j + 1 - k) * as1];
251-
if (x_k != 0) {
252-
auto temp = -1.0 * Kokkos::conj(x_k);
253-
for (int i = 0; i < k + 1; ++i) {
254-
auto x_i = AB[(i + kd - 1) * as0 + (j + 1 - i) * as1];
255-
AB[(kd + i) * as0 + (j + 1 + k - i) * as1] += x_i * temp;
256-
}
257-
}
258-
}
128+
// syr or zher (upper) with alpha = -1.0 to diagonal elements
129+
using op = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj,
130+
KokkosBlas::Impl::OpID>;
131+
using op_sym = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpReal,
132+
KokkosBlas::Impl::OpID>;
133+
SerialSyrInternalUpper::invoke(op(), op_sym(), kn, -1.0, &(AB[(kd - 1) * as0 + (j + 1) * as1]), as0,
134+
&(AB[kd * as0 + (j + 1) * as1]), as0, (as1 - as0));
259135

260-
// zlacgv to diagonal elements
261-
for (int i = 0; i < kn; ++i) {
262-
AB[(i + kd - 1) * as0 + (j + 1 - i) * as1] = Kokkos::conj(AB[(i + kd - 1) * as0 + (j + 1 - i) * as1]);
263-
}
136+
// zlacgv to diagonal elements (no op for real matrix)
137+
SerialLacgvInternal::invoke(kn, &(AB[(kd - 1) * as0 + (j + 1) * as1]), (as0 - as1));
264138
}
265139
}
266140

267141
return 0;
268142
}
269143

144+
} // namespace Impl
270145
} // namespace KokkosBatched
271146

272147
#endif // KOKKOSBATCHED_PBTRF_SERIAL_INTERNAL_HPP_

batched/dense/src/KokkosBatched_Pbtrf.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ namespace KokkosBatched {
3333
/// L is lower triangular.
3434
/// This is the unblocked version of the algorithm, calling Level 2 BLAS.
3535
///
36-
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D
37-
/// view
36+
/// \tparam ArgUplo: Type indicating whether A is the upper (Uplo::Upper) or lower (Uplo::Lower) triangular matrix
37+
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Pbtrf::Blocked) or unblocked
38+
/// (KokkosBatched::Algo::Pbtrf::Unblocked) algorithm to be used
39+
///
40+
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D view
3841
///
3942
/// \param ab [inout]: ab is a ldab by n banded matrix, with ( kd + 1 ) diagonals
4043
///
@@ -43,6 +46,10 @@ namespace KokkosBatched {
4346

4447
template <typename ArgUplo, typename ArgAlgo>
4548
struct SerialPbtrf {
49+
static_assert(
50+
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
51+
"KokkosBatched::pbtrf: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
52+
static_assert(std::is_same_v<ArgAlgo, Algo::Pbtrf::Unblocked>, "KokkosBatched::pbtrf: Use Algo::Pbtrf::Unblocked");
4653
template <typename ABViewType>
4754
KOKKOS_INLINE_FUNCTION static int invoke(const ABViewType &ab);
4855
};

batched/dense/unit_test/Test_Batched_Dense.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@
5656
#include "Test_Batched_SerialPttrs_Real.hpp"
5757
#include "Test_Batched_SerialPttrs_Complex.hpp"
5858
#include "Test_Batched_SerialPbtrf.hpp"
59-
#include "Test_Batched_SerialPbtrf_Real.hpp"
60-
#include "Test_Batched_SerialPbtrf_Complex.hpp"
6159
#include "Test_Batched_SerialPbtrs.hpp"
6260
#include "Test_Batched_SerialPbtrs_Real.hpp"
6361
#include "Test_Batched_SerialPbtrs_Complex.hpp"

0 commit comments

Comments
 (0)