Skip to content

Commit ccf93b8

Browse files
Thoemi09Wentzell
authored andcommitted
Add svd and svd_in_place functions to compute SVDs
1 parent 362dc78 commit ccf93b8

File tree

3 files changed

+146
-2
lines changed

3 files changed

+146
-2
lines changed

c++/nda/linalg.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131
#include "./linalg/matmul.hpp"
3232
#include "./linalg/norm.hpp"
3333
#include "./linalg/solve.hpp"
34+
#include "./linalg/svd.hpp"

c++/nda/linalg/svd.hpp

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright (c) 2019-2024 Simons Foundation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0.txt
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16+
17+
/**
18+
* @file
19+
* @brief Provides functions to compute the singular value decoomposition of a matrix.
20+
*/
21+
22+
#pragma once
23+
24+
#include "../basic_array.hpp"
25+
#include "../blas/tools.hpp"
26+
#include "../declarations.hpp"
27+
#include "../exceptions.hpp"
28+
#include "../lapack/gesvd.hpp"
29+
#include "../layout/policies.hpp"
30+
#include "../macros.hpp"
31+
#include "../mem/address_space.hpp"
32+
#include "../mem/policies.hpp"
33+
#include "../traits.hpp"
34+
35+
#include <algorithm>
36+
#include <tuple>
37+
#include <type_traits>
38+
39+
namespace nda {
40+
41+
/**
42+
* @addtogroup linalg_tools
43+
* @{
44+
*/
45+
46+
/**
47+
* @brief Compute the singular value decomposition (SVD) of a matrix in place.
48+
*
49+
* @details The function computes the SVD of a given m-by-n matrix \f$ \mathbf{A} \f$:
50+
* \f[
51+
* \mathbf{A} = \mathbf{U} \mathbf{S} \mathbf{V}^H \; ,
52+
* \f]
53+
* where \f$ \mathbf{U} \f$ is a unitary m-by-m matrix, \f$ \mathbf{V} \f$ is a unitary n-by-n matrix and \f$
54+
* \mathbf{S} \f$ is an m-by-n matrix with non-negative real numbers on the diagonal.
55+
*
56+
* It first constructs the output vector \f$ \mathbf{s} \f$, which contains the singular values, and the output
57+
* matrices \f$ \mathbf{U} \f$ and \f$ \mathbf{V}^H \f$. It then calls nda::lapack::gesvd to compute the SVD.
58+
*
59+
* @note If the input matrix \f$ \mathbf{A} \f$ is in Fortran layout, the output matrices \f$ \mathbf{U} \f$ and
60+
* \f$ \mathbf{V}^H \f$ are also in Fortran layout. Otherwise, they are in C layout.
61+
*
62+
* @tparam A nda::MemoryMatrix type.
63+
* @param a Input/output matrix. On entry, the m-by-n matrix \f$ \mathbf{A} \f$. On exit, the contents of \f$
64+
* \mathbf{A} \f$ are destroyed.
65+
* @return `std::tuple` containing \f$ \mathbf{U} \f$, \f$ \mathbf{s} \f$ and \f$ \mathbf{V}^H \f$.
66+
*/
67+
template <MemoryMatrix A>
68+
requires(is_blas_lapack_v<get_value_t<A>>)
69+
auto svd_in_place(A &&a) { // NOLINT (temporary views are allowed here)
70+
using layout_policy = detail::layout_to_policy<typename std::remove_cvref_t<A>::layout_t>::type;
71+
constexpr auto addr_space = mem::get_addr_space<A>;
72+
73+
// vector s and matrices U and V^H
74+
auto s = vector<double, heap<addr_space>>(std::min(a.extent(0), a.extent(1)));
75+
auto U = matrix<get_value_t<A>, layout_policy, heap<addr_space>>(a.extent(0), a.extent(0));
76+
auto VH = matrix<get_value_t<A>, layout_policy, heap<addr_space>>(a.extent(1), a.extent(1));
77+
78+
// call lapack gesvd
79+
int info = lapack::gesvd(a, s, U, VH);
80+
if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::svd_in_place: gesvd returned a non-zero value: info = " << info;
81+
82+
return std::make_tuple(U, s, VH);
83+
}
84+
85+
/**
86+
* @brief Compute the singular value decomposition (SVD) of a matrix.
87+
*
88+
* @details The function computes the SVD of a given m-by-n matrix \f$ \mathbf{A} \f$:
89+
* \f[
90+
* \mathbf{A} = \mathbf{U} \mathbf{S} \mathbf{V}^H \; ,
91+
* \f]
92+
* where \f$ \mathbf{U} \f$ is a unitary m-by-m matrix, \f$ \mathbf{V} \f$ is a unitary n-by-n matrix and \f$
93+
* \mathbf{S} \f$ is an m-by-n matrix with non-negative real numbers on the diagonal.
94+
*
95+
* It first makes a copy of the input matrix \f$ \mathbf{A} \f$ and then calls nda::svd_in_place with the copy.
96+
*
97+
* @tparam A nda::MemoryMatrix type.
98+
* @param a Input matrix \f$ \mathbf{A} \f$.
99+
* @return `std::tuple` containing \f$ \mathbf{U} \f$, \f$ \mathbf{s} \f$ and \f$ \mathbf{V}^H \f$.
100+
*/
101+
template <Matrix A>
102+
requires(is_blas_lapack_v<get_value_t<A>>)
103+
auto svd(A const &a) { // NOLINT (temporary views are allowed here)
104+
using layout_policy = detail::layout_to_policy<typename A::layout_t>::type;
105+
constexpr auto addr_space = mem::get_addr_space<A>;
106+
auto a_copy = matrix<get_value_t<A>, layout_policy, heap<addr_space>>(a);
107+
return svd_in_place(a_copy);
108+
}
109+
110+
/** @} */
111+
112+
} // namespace nda

test/c++/nda_linear_algebra.cpp

+33-2
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ void test_solve() {
404404

405405
// solve A * X = B using the exact matrix inverse
406406
auto Ainv = matrix_t{{-24, 18, 5}, {20, -15, -4}, {-5, 4, 1}};
407-
auto X = matrix_t{Ainv * B};
407+
auto X = matrix_t{Ainv * B};
408408
EXPECT_ARRAY_NEAR(matrix_t{A * X}, B);
409409

410410
// solve A * X = B using solve_in_place
@@ -415,7 +415,7 @@ void test_solve() {
415415
EXPECT_ARRAY_NEAR(X, Bcopy);
416416

417417
// solve A * x = b using solve_in_place
418-
Acopy = A;
418+
Acopy = A;
419419
auto b = vector_t{B(nda::range::all, 0)};
420420
nda::solve_in_place(Acopy, b);
421421
EXPECT_ARRAY_NEAR(A * b, B(nda::range::all, 0));
@@ -438,3 +438,34 @@ TEST(NDA, LinearAlgebraSolve) {
438438
test_solve<std::complex<double>, nda::C_layout>();
439439
test_solve<std::complex<double>, nda::F_layout>();
440440
}
441+
442+
// Check the SVD of a matrix.
443+
template <typename value_t, typename Layout>
444+
void test_svd() {
445+
using matrix_t = nda::matrix<value_t, Layout>;
446+
447+
auto A = matrix_t{{2, -2, 1}, {-4, -8, -8}};
448+
auto s = nda::vector<double>{12, 3};
449+
450+
// compute the SVD of A
451+
auto [U_1, s_1, VH_1] = nda::svd(A);
452+
auto S_1 = matrix_t::zeros(A.shape());
453+
for (auto i : nda::range(2)) S_1(i, i) = s_1(i);
454+
EXPECT_ARRAY_NEAR(s_1, s, 1e-14);
455+
EXPECT_ARRAY_NEAR(A, U_1 * S_1 * VH_1, 1e-14);
456+
457+
// compute the SVD of A in place
458+
auto A_copy = A;
459+
auto [U_2, s_2, VH_2] = nda::svd_in_place(A_copy);
460+
auto S_2 = matrix_t::zeros(A.shape());
461+
for (auto i : nda::range(2)) S_2(i, i) = s_2(i);
462+
EXPECT_ARRAY_NEAR(s, s_2, 1e-14);
463+
EXPECT_ARRAY_NEAR(A, U_2 * S_2 * VH_2, 1e-14);
464+
}
465+
466+
TEST(NDA, LinearAlgebraSVD) {
467+
test_svd<double, nda::C_layout>();
468+
test_svd<double, nda::F_layout>();
469+
test_svd<std::complex<double>, nda::C_layout>();
470+
test_svd<std::complex<double>, nda::F_layout>();
471+
}

0 commit comments

Comments
 (0)