|
1 | 1 | #pragma once |
| 2 | +#include <math.h> |
| 3 | +#include <stddef.h> |
| 4 | +#include <stdbool.h> |
| 5 | + |
| 6 | +#ifdef __cplusplus |
| 7 | +#include <array> |
| 8 | +#endif |
| 9 | + |
| 10 | +#include "openblas/lapack.h" |
2 | 11 |
|
3 | 12 | #ifdef __cplusplus |
4 | 13 | extern "C" |
5 | 14 | { |
6 | 15 | #endif |
7 | 16 |
|
8 | | -#include <math.h> |
9 | | -#include <stddef.h> |
10 | | -#include <stdbool.h> |
11 | | - |
12 | 17 | /** |
13 | 18 | * @defgroup LinearAlgebra Linear Algebra |
14 | 19 | * @brief Linear algebra utility for matrix and vector computations |
@@ -180,7 +185,7 @@ bool fsb_linalg_is_posdef(const double_t mat[], size_t dim, size_t work_len, dou |
180 | 185 | */ |
181 | 186 | FsbLinalgErrorType fsb_linalg_matrix_sqr_solve( |
182 | 187 | const double_t mat[], const double_t y_vec[], size_t nrhs, size_t dim, size_t work_len, |
183 | | - size_t iwork_len, double_t work[], int iwork[], double_t x_vec[]); |
| 188 | + size_t iwork_len, double_t work[], lapack_int iwork[], double_t x_vec[]); |
184 | 189 |
|
185 | 190 | /** |
186 | 191 | * @brief Inverse of a matrix where number of columns are great er than or equal to number of rows |
@@ -231,4 +236,126 @@ void sample_dgels_example(void); |
231 | 236 |
|
232 | 237 | #ifdef __cplusplus |
233 | 238 | } |
| 239 | + |
| 240 | +/** |
| 241 | + * @brief C++ convenience wrappers that accept std::array. |
| 242 | + * |
| 243 | + * These are header-only and forward to the C ABI functions above. |
| 244 | + */ |
| 245 | +template <size_t Rows, size_t Cols, size_t WorkLen> |
| 246 | +inline FsbLinalgErrorType fsb_linalg_svd_array( |
| 247 | + const std::array<double_t, Rows * Cols>& mat, |
| 248 | + bool u_full, |
| 249 | + bool v_full, |
| 250 | + std::array<double_t, WorkLen>& work, |
| 251 | + std::array<double_t, Rows * Rows>& unitary_u, |
| 252 | + std::array<double_t, FSB_MIN(Rows, Cols)>& sing_val, |
| 253 | + std::array<double_t, Cols * Cols>& unitary_vt) |
| 254 | +{ |
| 255 | + return fsb_linalg_svd( |
| 256 | + mat.data(), Rows, Cols, |
| 257 | + u_full, v_full, |
| 258 | + WorkLen, work.data(), |
| 259 | + unitary_u.data(), sing_val.data(), unitary_vt.data()); |
| 260 | +} |
| 261 | + |
| 262 | +template <size_t Dim, size_t WorkLen> |
| 263 | +inline FsbLinalgErrorType fsb_linalg_matrix_eig_array( |
| 264 | + const std::array<double_t, Dim * Dim>& mat, |
| 265 | + std::array<double_t, WorkLen>& work, |
| 266 | + std::array<double_t, Dim>& val_real, |
| 267 | + std::array<double_t, Dim>& val_imag, |
| 268 | + std::array<double_t, Dim * Dim>& vec_real, |
| 269 | + std::array<double_t, Dim * Dim>& vec_imag) |
| 270 | +{ |
| 271 | + return fsb_linalg_matrix_eig( |
| 272 | + mat.data(), Dim, WorkLen, work.data(), |
| 273 | + val_real.data(), val_imag.data(), |
| 274 | + vec_real.data(), vec_imag.data()); |
| 275 | +} |
| 276 | + |
| 277 | +template <size_t Dim, size_t WorkLen> |
| 278 | +inline FsbLinalgErrorType fsb_linalg_sym_lt_eig_array( |
| 279 | + const std::array<double_t, Dim * Dim>& mat, |
| 280 | + std::array<double_t, WorkLen>& work, |
| 281 | + std::array<double_t, Dim>& val, |
| 282 | + std::array<double_t, Dim * Dim>& vec) |
| 283 | +{ |
| 284 | + return fsb_linalg_sym_lt_eig( |
| 285 | + mat.data(), Dim, WorkLen, work.data(), |
| 286 | + val.data(), vec.data()); |
| 287 | +} |
| 288 | + |
| 289 | +template <size_t Dim> |
| 290 | +inline FsbLinalgErrorType fsb_linalg_cholesky_decomposition_array( |
| 291 | + const std::array<double_t, Dim * Dim>& mat, |
| 292 | + std::array<double_t, Dim * Dim>& mat_chol) |
| 293 | +{ |
| 294 | + return fsb_linalg_cholesky_decomposition(mat.data(), Dim, mat_chol.data()); |
| 295 | +} |
| 296 | + |
| 297 | +template <size_t Dim, size_t WorkLen> |
| 298 | +inline bool fsb_linalg_is_posdef_array( |
| 299 | + const std::array<double_t, Dim * Dim>& mat, |
| 300 | + std::array<double_t, WorkLen>& work) |
| 301 | +{ |
| 302 | + return fsb_linalg_is_posdef(mat.data(), Dim, WorkLen, work.data()); |
| 303 | +} |
| 304 | + |
| 305 | +template <size_t Dim, size_t Nrhs, size_t WorkLen> |
| 306 | +inline FsbLinalgErrorType fsb_linalg_cholesky_solve_array( |
| 307 | + const std::array<double_t, Dim * Dim>& mat, |
| 308 | + const std::array<double_t, Dim * Nrhs>& b_vec, |
| 309 | + std::array<double_t, WorkLen>& work, |
| 310 | + std::array<double_t, Dim * Nrhs>& x_vec) |
| 311 | +{ |
| 312 | + return fsb_linalg_cholesky_solve( |
| 313 | + mat.data(), b_vec.data(), |
| 314 | + Nrhs, Dim, |
| 315 | + WorkLen, work.data(), |
| 316 | + x_vec.data()); |
| 317 | +} |
| 318 | + |
| 319 | +template <size_t Dim, size_t Nrhs, size_t WorkLen, size_t IWorkLen> |
| 320 | +inline FsbLinalgErrorType fsb_linalg_matrix_sqr_solve_array( |
| 321 | + const std::array<double_t, Dim * Dim>& mat, |
| 322 | + const std::array<double_t, Dim * Nrhs>& y_vec, |
| 323 | + std::array<double_t, WorkLen>& work, |
| 324 | + std::array<lapack_int, IWorkLen>& iwork, |
| 325 | + std::array<double_t, Dim * Nrhs>& x_vec) |
| 326 | +{ |
| 327 | + return fsb_linalg_matrix_sqr_solve( |
| 328 | + mat.data(), y_vec.data(), |
| 329 | + Nrhs, Dim, |
| 330 | + WorkLen, IWorkLen, |
| 331 | + work.data(), iwork.data(), |
| 332 | + x_vec.data()); |
| 333 | +} |
| 334 | + |
| 335 | +template <size_t Rows, size_t Cols, size_t WorkLen> |
| 336 | +inline FsbLinalgErrorType fsb_linalg_pseudoinverse_array( |
| 337 | + const std::array<double_t, Rows * Cols>& mat, |
| 338 | + std::array<double_t, WorkLen>& work, |
| 339 | + std::array<double_t, Cols * Rows>& inv_mat) |
| 340 | +{ |
| 341 | + return fsb_linalg_pseudoinverse( |
| 342 | + mat.data(), Rows, Cols, |
| 343 | + WorkLen, work.data(), |
| 344 | + inv_mat.data()); |
| 345 | +} |
| 346 | + |
| 347 | +template <size_t Rows, size_t Cols, size_t Nrhs, size_t WorkLen> |
| 348 | +inline FsbLinalgErrorType fsb_linalg_leastsquares_solve_array( |
| 349 | + const std::array<double_t, Rows * Cols>& mat, |
| 350 | + const std::array<double_t, Rows * Nrhs>& b_vec, |
| 351 | + std::array<double_t, WorkLen>& work, |
| 352 | + std::array<double_t, Cols * Nrhs>& x_vec) |
| 353 | +{ |
| 354 | + return fsb_linalg_leastsquares_solve( |
| 355 | + mat.data(), Rows, Cols, |
| 356 | + b_vec.data(), Nrhs, |
| 357 | + WorkLen, work.data(), |
| 358 | + x_vec.data()); |
| 359 | +} |
| 360 | + |
234 | 361 | #endif |
0 commit comments