Skip to content

Generalize and fix LAPACK routines and add a solve function for linear systems #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: unstable
Choose a base branch
from
Open
42 changes: 30 additions & 12 deletions c++/nda/blas/tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,45 @@ namespace nda::blas {
}();

/**
* @brief Get the leading dimension in LAPACK jargon of an nda::MemoryMatrix.
* @brief Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for LAPACK calls.
*
* @tparam A nda::MemoryMatrix type.
* @param a nda::MemoryMatrix object.
* @return Leading dimension.
* @details The leading dimension is the stride between two consecutive columns (rows) of a matrix in Fortran (C)
* layout. For 1-dimensional arrays, we simply return the size of the array.
*
* @tparam A nda::MemoryArray type.
* @param a nda::MemoryArray object.
* @return Leading dimension for BLAS/LAPACK calls.
*/
template <MemoryMatrix A>
template <MemoryArray A>
requires(get_rank<A> == 1 or get_rank<A> == 2)
int get_ld(A const &a) {
return a.indexmap().strides()[has_F_layout<A> ? 1 : 0];
if constexpr (get_rank<A> == 1) {
return a.size();
} else {
return a.indexmap().strides()[has_F_layout<A> ? 1 : 0];
}
}

/**
* @brief Get the number of columns in LAPACK jargon of an nda::MemoryMatrix.
* @brief Get the number of columns of an nda::MemoryArray for BLAS/LAPACK calls.
*
* @details The number of columns corresponds to the extent of the second (first) dimension of a matrix in Fortran
* (C) layout. For 1-dimensional arrays, we return 1.
*
* @tparam A nda::MemoryMatrix type.
* @param a nda::MemoryMatrix object.
* @return Number of columns.
* @note This is not necessarily the same as the number of columns in the mathematical sense.
*
* @tparam A nda::MemoryArray type.
* @param a nda::MemoryArray object.
* @return Number of columns for BLAS/LAPACK calls.
*/
template <MemoryMatrix A>
template <MemoryArray A>
requires(get_rank<A> == 1 or get_rank<A> == 2)
int get_ncols(A const &a) {
return a.shape()[has_F_layout<A> ? 1 : 0];
if constexpr (get_rank<A> == 1) {
return 1;
} else {
return a.shape()[has_F_layout<A> ? 1 : 0];
}
}

/** @} */
Expand Down
36 changes: 26 additions & 10 deletions c++/nda/lapack/gesvd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#pragma once

#include "./interface/cxx_interface.hpp"
#include "../basic_functions.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../exceptions.hpp"
Expand All @@ -38,6 +39,7 @@
#include <algorithm>
#include <cmath>
#include <complex>
#include <concepts>
#include <utility>

namespace nda::lapack {
Expand Down Expand Up @@ -71,14 +73,19 @@ namespace nda::lapack {
* @return Integer return code from the LAPACK call.
*/
template <MemoryMatrix A, MemoryVector S, MemoryMatrix U, MemoryMatrix VT>
requires(have_same_value_type_v<A, U, VT> and mem::have_compatible_addr_space<A, S, U, VT> and is_blas_lapack_v<get_value_t<A>>)
requires(have_same_value_type_v<A, U, VT> and mem::have_compatible_addr_space<A, S, U, VT> and is_blas_lapack_v<get_value_t<A>>
and std::same_as<double, get_value_t<S>>)
int gesvd(A &&a, S &&s, U &&u, VT &&vt) { // NOLINT (temporary views are allowed here)
static_assert(has_F_layout<A> and has_F_layout<U> and has_F_layout<VT>, "Error in nda::lapack::gesvd: C order not supported");
static_assert(has_C_layout<A> == has_C_layout<U> and has_C_layout<A> == has_C_layout<VT>,
"Error in nda::lapack::gesvd: Matrix layouts have to be the same");

// check the dimensions of the output arrays/views and resize if necessary
auto dm = std::min(a.extent(0), a.extent(1));
if (s.size() < dm) s.resize(dm);
resize_or_check_if_view(s, {dm});
resize_or_check_if_view(u, {a.extent(0), a.extent(0)});
resize_or_check_if_view(vt, {a.extent(1), a.extent(1)});

// must be lapack compatible
// arrays/views must be LAPACK compatible
EXPECTS(a.indexmap().min_stride() == 1);
EXPECTS(s.indexmap().min_stride() == 1);
EXPECTS(u.indexmap().min_stride() == 1);
Expand All @@ -102,16 +109,25 @@ namespace nda::lapack {
value_type bufferSize_T{};
auto rwork = array<double, 1, C_layout, heap<mem::get_addr_space<A>>>(5 * dm);
int info = 0;
gesvd_call('A', 'A', a.extent(0), a.extent(1), a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), &bufferSize_T, -1,
rwork.data(), info);
if constexpr (has_C_layout<A>) {
gesvd_call('A', 'A', a.extent(1), a.extent(0), a.data(), get_ld(a), s.data(), vt.data(), get_ld(vt), u.data(), get_ld(u), &bufferSize_T, -1,
rwork.data(), info);
} else {
gesvd_call('A', 'A', a.extent(0), a.extent(1), a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), &bufferSize_T, -1,
rwork.data(), info);
}
int bufferSize = static_cast<int>(std::ceil(std::real(bufferSize_T)));

// allocate work buffer and perform actual library call
nda::array<value_type, 1, C_layout, heap<mem::get_addr_space<A>>> work(bufferSize);
gesvd_call('A', 'A', a.extent(0), a.extent(1), a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), work.data(), bufferSize,
rwork.data(), info);
array<value_type, 1, C_layout, heap<mem::get_addr_space<A>>> work(bufferSize);
if constexpr (has_C_layout<A>) {
gesvd_call('A', 'A', a.extent(1), a.extent(0), a.data(), get_ld(a), s.data(), vt.data(), get_ld(vt), u.data(), get_ld(u), work.data(),
bufferSize, rwork.data(), info);
} else {
gesvd_call('A', 'A', a.extent(0), a.extent(1), a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), work.data(),
bufferSize, rwork.data(), info);
}

if (info) NDA_RUNTIME_ERROR << "Error in nda::lapack::gesvd: info = " << info;
return info;
}

Expand Down
13 changes: 11 additions & 2 deletions c++/nda/lapack/getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include "../mem/address_space.hpp"
#include "../traits.hpp"

#ifndef NDA_HAVE_DEVICE
#include "../device.hpp"
#endif // NDA_HAVE_DEVICE

#include <algorithm>
#include <type_traits>

Expand Down Expand Up @@ -62,10 +66,14 @@ namespace nda::lapack {
int getrf(A &&a, IPIV &&ipiv) { // NOLINT (temporary views are allowed here)
static_assert(std::is_same_v<get_value_t<IPIV>, int>, "Error in nda::lapack::getri: Pivoting array must have elements of type int");

// for C-layout arrays, call getrf with the transpose
if constexpr (has_C_layout<A>) return getrf(transpose(a), ipiv);

// check the dimensions of the input/output arrays/views and resize if necessary
auto dm = std::min(a.extent(0), a.extent(1));
if (ipiv.size() < dm) ipiv.resize(dm); // ipiv needs to be a regular array?
if (ipiv.size() < dm) ipiv.resize(dm);

// must be lapack compatible
// arrays/views must be LAPACK compatible
EXPECTS(a.indexmap().min_stride() == 1);
EXPECTS(ipiv.indexmap().min_stride() == 1);

Expand All @@ -75,6 +83,7 @@ namespace nda::lapack {
#endif
#endif

// perform actual library call
int info = 0;
if constexpr (mem::have_device_compatible_addr_space<A, IPIV>) {
#if defined(NDA_HAVE_DEVICE)
Expand Down
13 changes: 9 additions & 4 deletions c++/nda/lapack/getri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include "../mem/address_space.hpp"
#include "../traits.hpp"

#ifndef NDA_HAVE_DEVICE
#include "../device.hpp"
#endif // NDA_HAVE_DEVICE

#include <algorithm>
#include <cmath>
#include <complex>
Expand Down Expand Up @@ -58,15 +62,16 @@ namespace nda::lapack {
requires(mem::have_compatible_addr_space<A, IPIV> and is_blas_lapack_v<get_value_t<A>>)
int getri(A &&a, IPIV const &ipiv) { // NOLINT (temporary views are allowed here)
static_assert(std::is_same_v<get_value_t<IPIV>, int>, "Error in nda::lapack::getri: Pivoting array must have elements of type int");
auto dm = std::min(a.extent(0), a.extent(1));

if (ipiv.size() < dm)
NDA_RUNTIME_ERROR << "Error in nda::lapack::getri: Pivot index array size " << ipiv.size() << " smaller than required size " << dm;
// check the dimensions of the input arrays/views (nothing is resized here)
EXPECTS(a.extent(0) == a.extent(1));
EXPECTS(ipiv.size() >= a.extent(0));

// must be lapack compatible
// arrays/views must be LAPACK compatible
EXPECTS(a.indexmap().min_stride() == 1);
EXPECTS(ipiv.indexmap().min_stride() == 1);

// perform the LAPACK calls
int info = 0;
if constexpr (mem::have_device_compatible_addr_space<A, IPIV>) {
#if defined(NDA_HAVE_DEVICE)
Expand Down
18 changes: 12 additions & 6 deletions c++/nda/lapack/getrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "./interface/cxx_interface.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../macros.hpp"
#include "../mem/address_space.hpp"
#include "../traits.hpp"
Expand All @@ -49,7 +50,7 @@ namespace nda::lapack {
* with a general n-by-n matrix \f$ \mathbf{A} \f$ using the LU factorization computed by `getrf`.
*
* @tparam A nda::MemoryMatrix type.
* @tparam B nda::MemoryMatrix type.
* @tparam B nda::MemoryArray type.
* @tparam IPIV nda::MemoryVector type.
* @param a Input matrix. The factors \f$ \mathbf{L} \f$ and \f$ \mathbf{U} \f$ from the factorization \f$ \mathbf{A}
* = \mathbf{P L U} \f$ as computed by `getrf`.
Expand All @@ -59,20 +60,25 @@ namespace nda::lapack {
* interchanged with row `ipiv(i)`.
* @return Integer return code from the LAPACK call.
*/
template <MemoryMatrix A, MemoryMatrix B, MemoryVector IPIV>
template <MemoryMatrix A, MemoryArray B, MemoryVector IPIV>
requires(have_same_value_type_v<A, B> and mem::have_compatible_addr_space<A, B, IPIV> and is_blas_lapack_v<get_value_t<A>>)
int getrs(A const &a, B &&b, IPIV const &ipiv) { // NOLINT (temporary views are allowed here)
static_assert(std::is_same_v<get_value_t<IPIV>, int>, "Error in nda::lapack::getrs: Pivoting array must have elements of type int");
static_assert(get_rank<B> == 1 || get_rank<B> == 2, "Error in nda::lapack::getrs: Right hand side must have rank 1 or 2");
static_assert(has_F_layout<B> or get_rank<B> == 1, "Error in nda::lapack::getrs: B must have Fortran layout or rank 1");

// check the dimensions of the input/output arrays/views (nothing is resized here)
EXPECTS(a.extent(0) == a.extent(1));
EXPECTS(b.extent(0) == a.extent(0));
EXPECTS(ipiv.size() >= std::min(a.extent(0), a.extent(1)));

// must be lapack compatible
// arrays/views must be LAPACK compatible
EXPECTS(a.indexmap().min_stride() == 1);
EXPECTS(b.indexmap().min_stride() == 1);
EXPECTS(ipiv.indexmap().min_stride() == 1);

// check for lazy expressions
static constexpr bool conj_A = is_conj_array_expr<A>;
char op_a = get_op<conj_A, /* transpose = */ has_C_layout<A>>;
// check for lazy expressions and C-layout
char op_a = get_op<is_conj_array_expr<A>, has_C_layout<A>>;

// perform actual library call
int info = 0;
Expand Down
18 changes: 11 additions & 7 deletions c++/nda/lapack/gtsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "./interface/cxx_interface.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../macros.hpp"
#include "../mem/address_space.hpp"
#include "../traits.hpp"
Expand All @@ -42,6 +43,10 @@ namespace nda::lapack {
* Note that the equation \f$ \mathbf{A}^T \mathbf{X} = \mathbf{B} \f$ may be solved by interchanging the order of the
* arguments containing the subdiagonal elements.
*
* @note The array \f$ \mathbf{B} \f$ must be either a vector or a matrix in Fortran layout. A matrix in C-layout is
* rejected at compile time, i.e. no temporary copy with Fortran layout is created by this function. If the right
* hand side is stored in C-layout, it has to be copied to an array with Fortran layout before the call.
*
* @tparam DL nda::MemoryVector type.
* @tparam D nda::MemoryVector type.
* @tparam DU nda::MemoryVector type.
Expand All @@ -61,17 +66,16 @@ namespace nda::lapack {
requires(have_same_value_type_v<DL, D, DU, B> and mem::on_host<DL, D, DU, B> and is_blas_lapack_v<get_value_t<DL>>)
int gtsv(DL &&dl, D &&d, DU &&du, B &&b) { // NOLINT (temporary views are allowed here)
static_assert((get_rank<B> == 1 or get_rank<B> == 2), "Error in nda::lapack::gtsv: B must be an matrix/array/view of rank 1 or 2");
static_assert(has_F_layout<B> or get_rank<B> == 1, "Error in nda::lapack::getrs: B must have Fortran layout or rank 1");

// get and check dimensions of input arrays
EXPECTS(dl.extent(0) == d.extent(0) - 1); // "gtsv : dimension mismatch between sub-diagonal and diagonal vectors "
EXPECTS(du.extent(0) == d.extent(0) - 1); // "gtsv : dimension mismatch between super-diagonal and diagonal vectors "
EXPECTS(b.extent(0) == d.extent(0)); // "gtsv : dimension mismatch between diagonal vector and RHS matrix, "
// check the dimensions of the input/output arrays/views
EXPECTS(dl.extent(0) == d.extent(0) - 1);
EXPECTS(du.extent(0) == d.extent(0) - 1);
EXPECTS(b.extent(0) == d.extent(0));

// perform actual library call
int N = d.extent(0);
int NRHS = (get_rank<B> == 2 ? b.extent(1) : 1);
int info = 0;
f77::gtsv(N, NRHS, dl.data(), d.data(), du.data(), b.data(), N, info);
f77::gtsv(d.extent(0), (get_rank<B> == 2 ? b.extent(1) : 1), dl.data(), d.data(), du.data(), b.data(), get_ld(b), info);
return info;
}

Expand Down
2 changes: 2 additions & 0 deletions c++/nda/linalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@
#include "./linalg/eigenelements.hpp"
#include "./linalg/matmul.hpp"
#include "./linalg/norm.hpp"
#include "./linalg/solve.hpp"
#include "./linalg/svd.hpp"
4 changes: 2 additions & 2 deletions c++/nda/linalg/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ namespace nda {
static constexpr bool is_valid_gemm_triple = []() {
using blas::has_F_layout;
if constexpr (has_F_layout<C>) {
return !(conj_A and has_F_layout<A>)and!(conj_B and has_F_layout<B>);
return !(conj_A and has_F_layout<A>) and !(conj_B and has_F_layout<B>);
} else {
return !(conj_B and !has_F_layout<B>)and!(conj_A and !has_F_layout<A>);
return !(conj_B and !has_F_layout<B>) and !(conj_A and !has_F_layout<A>);
}
}();

Expand Down
Loading
Loading