Skip to content

Commit 34b694b

Browse files
authored
Batched - SVD: adding new special matrix test for issue kokkos#2650 (kokkos#2706)
The issue reports that Batched SVD hangs on a 3x6 matrix. Adding that matrix to our list of special test matrices guards against future regressions. The number of iterations is now bounded: the `while (true)` loop is replaced by a `for` loop, and a new `max_iters` input parameter sets the maximum number of iterations. Signed-off-by: Luc Berger-Vergiat <[email protected]>
1 parent 838633e commit 34b694b

File tree

4 files changed

+82
-33
lines changed

4 files changed

+82
-33
lines changed

batched/dense/impl/KokkosBatched_SVD_Serial_Impl.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace KokkosBatched {
2525
template <typename AViewType, typename UViewType, typename VViewType, typename SViewType, typename WViewType>
2626
KOKKOS_INLINE_FUNCTION int SerialSVD::invoke(SVD_USV_Tag, const AViewType &A, const UViewType &U,
2727
const SViewType &sigma, const VViewType &Vt, const WViewType &work,
28-
typename AViewType::const_value_type tol) {
28+
typename AViewType::const_value_type tol, int max_iters) {
2929
static_assert(Kokkos::is_view_v<AViewType> && AViewType::rank == 2, "SVD: A must be a rank-2 view");
3030
static_assert(Kokkos::is_view_v<UViewType> && UViewType::rank == 2, "SVD: U must be a rank-2 view");
3131
static_assert(Kokkos::is_view_v<SViewType> && SViewType::rank == 1, "SVD: s must be a rank-1 view");
@@ -36,13 +36,14 @@ KOKKOS_INLINE_FUNCTION int SerialSVD::invoke(SVD_USV_Tag, const AViewType &A, co
3636
using value_type = typename AViewType::non_const_value_type;
3737
return KokkosBatched::SerialSVDInternal::invoke<value_type>(
3838
A.extent(0), A.extent(1), A.data(), A.stride(0), A.stride(1), U.data(), U.stride(0), U.stride(1), Vt.data(),
39-
Vt.stride(0), Vt.stride(1), sigma.data(), sigma.stride(0), work.data(), tol);
39+
Vt.stride(0), Vt.stride(1), sigma.data(), sigma.stride(0), work.data(), tol, max_iters);
4040
}
4141

4242
// Version which computes only singular values
4343
template <typename AViewType, typename SViewType, typename WViewType>
4444
KOKKOS_INLINE_FUNCTION int SerialSVD::invoke(SVD_S_Tag, const AViewType &A, const SViewType &sigma,
45-
const WViewType &work, typename AViewType::const_value_type tol) {
45+
const WViewType &work, typename AViewType::const_value_type tol,
46+
int max_iters) {
4647
static_assert(Kokkos::is_view_v<AViewType> && AViewType::rank == 2, "SVD: A must be a rank-2 view");
4748
static_assert(Kokkos::is_view_v<SViewType> && SViewType::rank == 1, "SVD: s must be a rank-1 view");
4849
static_assert(Kokkos::is_view_v<WViewType> && WViewType::rank == 1, "SVD: W must be a rank-1 view");
@@ -51,7 +52,7 @@ KOKKOS_INLINE_FUNCTION int SerialSVD::invoke(SVD_S_Tag, const AViewType &A, cons
5152
using value_type = typename AViewType::non_const_value_type;
5253
return KokkosBatched::SerialSVDInternal::invoke<value_type>(A.extent(0), A.extent(1), A.data(), A.stride(0),
5354
A.stride(1), nullptr, 0, 0, nullptr, 0, 0, sigma.data(),
54-
sigma.stride(0), work.data(), tol);
55+
sigma.stride(0), work.data(), tol, max_iters);
5556
}
5657

5758
} // namespace KokkosBatched

batched/dense/impl/KokkosBatched_SVD_Serial_Internal.hpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,14 +212,14 @@ struct SerialSVDInternal {
212212
// U and Vt to maintain the product U*B*Vt. At the end, the singular values
213213
// are copied to sigma.
214214
template <typename value_type>
215-
KOKKOS_INLINE_FUNCTION static void bidiSVD(int m, int n, value_type* B, int Bs0, int Bs1, value_type* U, int Us0,
216-
int Us1, value_type* Vt, int Vts0, int Vts1, value_type* sigma, int ss,
217-
const value_type& tol) {
215+
KOKKOS_INLINE_FUNCTION static int bidiSVD(int m, int n, value_type* B, int Bs0, int Bs1, value_type* U, int Us0,
216+
int Us1, value_type* Vt, int Vts0, int Vts1, value_type* sigma, int ss,
217+
const value_type& tol, int max_iters) {
218218
using KAT = Kokkos::ArithTraits<value_type>;
219219
const value_type eps = Kokkos::ArithTraits<value_type>::epsilon();
220220
int p = 0;
221221
int q = 0;
222-
while (true) {
222+
for (int iters = 0; iters < max_iters; ++iters) {
223223
// Zero out tiny superdiagonal entries
224224
for (int i = 0; i < n - 1; i++) {
225225
if (Kokkos::abs(SVDIND(B, i, i + 1)) <
@@ -271,10 +271,16 @@ struct SerialSVDInternal {
271271
}
272272
// B22 is nsub * nsub, Usub is m * nsub, and Vtsub is nsub * n
273273
svdStep(Bsub, Usub, Vtsub, m, n, nsub, Bs0, Bs1, Us0, Us1, Vts0, Vts1);
274+
275+
if (iters + 1 == max_iters) {
276+
return -1;
277+
}
274278
}
275279
for (int i = 0; i < n; i++) {
276280
sigma[i * ss] = SVDIND(B, i, i);
277281
}
282+
283+
return 0;
278284
}
279285

280286
// Convert SVD into conventional form: singular values positive and in
@@ -322,7 +328,8 @@ struct SerialSVDInternal {
322328
template <typename value_type>
323329
KOKKOS_INLINE_FUNCTION static int invoke(int m, int n, value_type* A, int As0, int As1, value_type* U, int Us0,
324330
int Us1, value_type* Vt, int Vts0, int Vts1, value_type* sigma, int ss,
325-
value_type* work, value_type tol = Kokkos::ArithTraits<value_type>::zero()) {
331+
value_type* work, value_type tol = Kokkos::ArithTraits<value_type>::zero(),
332+
int max_iters = 1000000000) {
326333
// First, if m < n, need to instead compute (V, s, U^T) = A^T.
327334
// This just means swapping U & Vt, and implicitly transposing A, U and Vt.
328335
if (m < n) {
@@ -345,9 +352,9 @@ struct SerialSVDInternal {
345352
return 0;
346353
}
347354
bidiagonalize(m, n, A, As0, As1, U, Us0, Us1, Vt, Vts0, Vts1, work);
348-
bidiSVD(m, n, A, As0, As1, U, Us0, Us1, Vt, Vts0, Vts1, sigma, ss, tol);
355+
int iter_err = bidiSVD(m, n, A, As0, As1, U, Us0, Us1, Vt, Vts0, Vts1, sigma, ss, tol, max_iters);
349356
postprocessSVD(m, n, U, Us0, Us1, Vt, Vts0, Vts1, sigma, ss);
350-
return 0;
357+
return iter_err;
351358
}
352359
};
353360

batched/dense/src/KokkosBatched_SVD_Decl.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ struct SerialSVD {
5959
template <typename AViewType, typename UViewType, typename VtViewType, typename SViewType, typename WViewType>
6060
KOKKOS_INLINE_FUNCTION static int invoke(
6161
SVD_USV_Tag, const AViewType &A, const UViewType &U, const SViewType &s, const VtViewType &Vt, const WViewType &W,
62-
typename AViewType::const_value_type tol = Kokkos::ArithTraits<typename AViewType::value_type>::zero());
62+
typename AViewType::const_value_type tol = Kokkos::ArithTraits<typename AViewType::value_type>::zero(),
63+
int max_iters = 1000000000);
6364

6465
// Version which computes only singular values
6566
template <typename AViewType, typename SViewType, typename WViewType>
6667
KOKKOS_INLINE_FUNCTION static int invoke(
6768
SVD_S_Tag, const AViewType &A, const SViewType &s, const WViewType &W,
68-
typename AViewType::const_value_type tol = Kokkos::ArithTraits<typename AViewType::value_type>::zero());
69+
typename AViewType::const_value_type tol = Kokkos::ArithTraits<typename AViewType::value_type>::zero(),
70+
int max_iters = 1000000000);
6971
};
7072

7173
} // namespace KokkosBatched

batched/dense/unit_test/Test_Batched_SerialSVD.hpp

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,36 @@ typename V::non_const_value_type simpleNorm2(const V& v) {
5656

5757
// Check that all columns of X are unit length and pairwise orthogonal
5858
template <typename Mat>
59-
void verifyOrthogonal(const Mat& X) {
60-
using Scalar = typename Mat::non_const_value_type;
61-
int k = X.extent(1);
59+
void verifyOrthogonal(const Mat& X, const double epsilon = -1) {
60+
using Scalar = typename Mat::non_const_value_type;
61+
int k = X.extent(1);
62+
const double tol = (epsilon <= 0 ? Test::svdEpsilon<Scalar>() : epsilon);
6263
for (int i = 0; i < k; i++) {
6364
auto col1 = Kokkos::subview(X, Kokkos::ALL(), i);
6465
double len = simpleNorm2(col1);
65-
Test::EXPECT_NEAR_KK(len, 1.0, Test::svdEpsilon<Scalar>());
66+
Test::EXPECT_NEAR_KK(len, 1.0, tol);
6667
for (int j = 0; j < i; j++) {
6768
auto col2 = Kokkos::subview(X, Kokkos::ALL(), j);
6869
double d = Kokkos::ArithTraits<Scalar>::abs(simpleDot(col1, col2));
69-
Test::EXPECT_NEAR_KK(d, 0.0, Test::svdEpsilon<Scalar>());
70+
Test::EXPECT_NEAR_KK(d, 0.0, tol);
7071
}
7172
}
7273
}
7374

7475
template <typename AView, typename UView, typename VtView, typename SigmaView>
75-
void verifySVD(const AView& A, const UView& U, const VtView& Vt, const SigmaView& sigma) {
76+
void verifySVD(const AView& A, const UView& U, const VtView& Vt, const SigmaView& sigma, const double epsilon = -1) {
7677
using Scalar = typename AView::non_const_value_type;
7778
using KAT = Kokkos::ArithTraits<Scalar>;
78-
// Check that U/V columns are unit length and orthogonal, and that U *
79-
// diag(sigma) * V^T == A
80-
int m = A.extent(0);
81-
int n = A.extent(1);
82-
int maxrank = std::min(m, n);
83-
verifyOrthogonal(U);
79+
// Check that U/V columns are unit length and orthogonal
80+
// and that: U * diag(sigma) * V^T == A
81+
int m = A.extent(0);
82+
int n = A.extent(1);
83+
int maxrank = std::min(m, n);
84+
const double tol = (epsilon <= 0 ? Test::svdEpsilon<Scalar>() : epsilon);
85+
verifyOrthogonal(U, epsilon);
8486
// NOTE: V^T being square and orthonormal implies that V is, so we don't have
8587
// to transpose it here.
86-
verifyOrthogonal(Vt);
88+
verifyOrthogonal(Vt, epsilon);
8789
Kokkos::View<Scalar**, typename AView::device_type> usvt("USV^T", m, n);
8890
for (int i = 0; i < maxrank; i++) {
8991
auto Ucol = Kokkos::subview(U, Kokkos::ALL(), Kokkos::make_pair<int>(i, i + 1));
@@ -92,7 +94,7 @@ void verifySVD(const AView& A, const UView& U, const VtView& Vt, const SigmaView
9294
}
9395
for (int i = 0; i < m; i++) {
9496
for (int j = 0; j < n; j++) {
95-
Test::EXPECT_NEAR_KK(usvt(i, j), A(i, j), Test::svdEpsilon<Scalar>());
97+
Test::EXPECT_NEAR_KK(usvt(i, j), A(i, j), tol);
9698
}
9799
}
98100
// Make sure all singular values are positive
@@ -109,19 +111,26 @@ template <typename Matrix, typename Vector>
109111
struct SerialSVDFunctor_Full {
110112
SerialSVDFunctor_Full(const Matrix& A_, const Matrix& U_, const Matrix& Vt_, const Vector& sigma_,
111113
const Vector& work_)
112-
: A(A_), U(U_), Vt(Vt_), sigma(sigma_), work(work_) {}
114+
: A(A_), U(U_), Vt(Vt_), sigma(sigma_), work(work_) {
115+
tol = Kokkos::ArithTraits<double>::zero();
116+
}
117+
118+
SerialSVDFunctor_Full(const Matrix& A_, const Matrix& U_, const Matrix& Vt_, const Vector& sigma_,
119+
const Vector& work_, const double tol_)
120+
: A(A_), U(U_), Vt(Vt_), sigma(sigma_), work(work_), tol(tol_) {}
113121

114122
// NOTE: this functor is only meant to be launched with a single element range
115123
// policy
116124
KOKKOS_INLINE_FUNCTION void operator()(int) const {
117-
KokkosBatched::SerialSVD::invoke(KokkosBatched::SVD_USV_Tag(), A, U, sigma, Vt, work);
125+
KokkosBatched::SerialSVD::invoke(KokkosBatched::SVD_USV_Tag(), A, U, sigma, Vt, work, tol);
118126
}
119127

120128
Matrix A;
121129
Matrix U;
122130
Matrix Vt;
123131
Vector sigma;
124132
Vector work;
133+
double tol;
125134
};
126135

127136
template <typename Matrix, typename Vector>
@@ -497,6 +506,27 @@ Kokkos::View<Scalar**, Layout, Device> getTestCase(int testCase) {
497506
Ahost = MatrixHost("A5", m, n);
498507
break;
499508
}
509+
case 6: {
510+
m = 3;
511+
n = 6;
512+
Ahost = MatrixHost("A6", m, n);
513+
Ahost(0, 0) = -2.3588494081694974e-03;
514+
Ahost(0, 1) = -2.3602176428346553e-03;
515+
Ahost(0, 2) = -3.3360574050870077e-03;
516+
Ahost(0, 3) = -2.3589487578561312e-03;
517+
Ahost(0, 4) = -3.3359167956075490e-03;
518+
Ahost(0, 5) = -3.3378517656821728e-03;
519+
Ahost(1, 0) = 3.3359168246290603e-03;
520+
Ahost(1, 1) = 3.3378518006490351e-03;
521+
Ahost(1, 3) = 3.3360573263032968e-03;
522+
Ahost(2, 0) = -2.3588494081695022e-03;
523+
Ahost(2, 1) = -2.3602176428346587e-03;
524+
Ahost(2, 2) = 3.3360574050869769e-03;
525+
Ahost(2, 3) = -2.3589487578561286e-03;
526+
Ahost(2, 4) = 3.3359167956075399e-03;
527+
Ahost(2, 5) = 3.3378517656821581e-03;
528+
break;
529+
}
500530
default: throw std::runtime_error("Test case out of bounds.");
501531
}
502532
Kokkos::View<Scalar**, Layout, Device> A(Ahost.label(), m, n);
@@ -509,7 +539,7 @@ void testSpecialCases() {
509539
using Matrix = Kokkos::View<Scalar**, Layout, Device>;
510540
using Vector = Kokkos::View<Scalar*, Device>;
511541
using ExecSpace = typename Device::execution_space;
512-
for (int i = 0; i < 6; i++) {
542+
for (int i = 0; i < 7; i++) {
513543
Matrix A = getTestCase<Scalar, Layout, Device>(i);
514544
int m = A.extent(0);
515545
int n = A.extent(1);
@@ -527,15 +557,24 @@ void testSpecialCases() {
527557
typename Matrix::host_mirror_type Acopy("Acopy", m, n);
528558
Kokkos::deep_copy(Acopy, A);
529559
// Run the SVD
530-
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, 1),
531-
SerialSVDFunctor_Full<Matrix, Vector>(A, U, Vt, sigma, work));
560+
if (std::is_same_v<Scalar, double> && i == 6) {
561+
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, 1),
562+
SerialSVDFunctor_Full<Matrix, Vector>(A, U, Vt, sigma, work, 1e-9));
563+
} else {
564+
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, 1),
565+
SerialSVDFunctor_Full<Matrix, Vector>(A, U, Vt, sigma, work));
566+
}
532567
// Get the results back
533568
auto Uhost = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), U);
534569
auto Vthost = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Vt);
535570
auto sigmaHost = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), sigma);
536571

537572
// Verify the SVD is correct
538-
verifySVD(Acopy, Uhost, Vthost, sigmaHost);
573+
if (std::is_same_v<Scalar, double> && i == 6) {
574+
verifySVD(Acopy, Uhost, Vthost, sigmaHost, 1e-11);
575+
} else {
576+
verifySVD(Acopy, Uhost, Vthost, sigmaHost);
577+
}
539578
}
540579
}
541580

0 commit comments

Comments
 (0)