1010#include < lapack.h>
1111#endif
1212
13+ // Wrapper to account for differences in
14+ // LAPACK implementations (basically how to pass the 'uplo' string to fortran).
15+ int strtri_wrapper (char uplo, char diag, float * matrix, int N) {
16+ int info;
17+
18+ #ifdef LAPACK_FORTRAN_STRLEN_END
19+ strtri_ (
20+ /* uplo = */ &uplo,
21+ /* diag = */ &diag,
22+ /* N = */ &N,
23+ /* a = */ matrix,
24+ /* lda = */ &N,
25+ /* info = */ &info,
26+ /* uplo_len = */ static_cast <size_t >(1 ),
27+ /* diag_len = */ static_cast <size_t >(1 ));
28+ #else
29+ strtri_ (
30+ /* uplo = */ &uplo,
31+ /* diag = */ &diag,
32+ /* N = */ &N,
33+ /* a = */ matrix,
34+ /* lda = */ &N,
35+ /* info = */ &info);
36+ #endif
37+
38+ return info;
39+ }
40+
1341namespace mlx ::core {
1442
15- void inverse_impl (const array& a, array& inv) {
43+ void general_inv (array& inv, int N, int i) {
44+ int info;
45+ auto ipiv = array::Data{allocator::malloc_or_wait (sizeof (int ) * N)};
46+ // Compute LU factorization.
47+ sgetrf_ (
48+ /* m = */ &N,
49+ /* n = */ &N,
50+ /* a = */ inv.data <float >() + N * N * i,
51+ /* lda = */ &N,
52+ /* ipiv = */ static_cast <int *>(ipiv.buffer .raw_ptr ()),
53+ /* info = */ &info);
54+
55+ if (info != 0 ) {
56+ std::stringstream ss;
57+ ss << " inverse_impl: LU factorization failed with error code " << info;
58+ throw std::runtime_error (ss.str ());
59+ }
60+
61+ static const int lwork_query = -1 ;
62+ float workspace_size = 0 ;
63+
64+ // Compute workspace size.
65+ sgetri_ (
66+ /* m = */ &N,
67+ /* a = */ nullptr ,
68+ /* lda = */ &N,
69+ /* ipiv = */ nullptr ,
70+ /* work = */ &workspace_size,
71+ /* lwork = */ &lwork_query,
72+ /* info = */ &info);
73+
74+ if (info != 0 ) {
75+ std::stringstream ss;
76+ ss << " inverse_impl: LU workspace calculation failed with error code "
77+ << info;
78+ throw std::runtime_error (ss.str ());
79+ }
80+
81+ const int lwork = workspace_size;
82+ auto scratch = array::Data{allocator::malloc_or_wait (sizeof (float ) * lwork)};
83+
84+ // Compute inverse.
85+ sgetri_ (
86+ /* m = */ &N,
87+ /* a = */ inv.data <float >() + N * N * i,
88+ /* lda = */ &N,
89+ /* ipiv = */ static_cast <int *>(ipiv.buffer .raw_ptr ()),
90+ /* work = */ static_cast <float *>(scratch.buffer .raw_ptr ()),
91+ /* lwork = */ &lwork,
92+ /* info = */ &info);
93+
94+ if (info != 0 ) {
95+ std::stringstream ss;
96+ ss << " inverse_impl: inversion failed with error code " << info;
97+ throw std::runtime_error (ss.str ());
98+ }
99+ }
100+
101+ void tri_inv (array& inv, int N, int i, bool upper) {
102+ const char uplo = upper ? ' L' : ' U' ;
103+ const char diag = ' N' ;
104+ int info = strtri_wrapper (uplo, diag, inv.data <float >() + N * N * i, N);
105+ if (info != 0 ) {
106+ std::stringstream ss;
107+ ss << " inverse_impl: triangular inversion failed with error code " << info;
108+ throw std::runtime_error (ss.str ());
109+ }
110+ }
111+
112+ void inverse_impl (const array& a, array& inv, bool tri, bool upper) {
16113 // Lapack uses the column-major convention. We take advantage of the following
17114 // identity to avoid transposing (see
18115 // https://math.stackexchange.com/a/340234):
@@ -24,63 +121,11 @@ void inverse_impl(const array& a, array& inv) {
24121 const int N = a.shape (-1 );
25122 const size_t num_matrices = a.size () / (N * N);
26123
27- int info;
28- auto ipiv = array::Data{allocator::malloc_or_wait (sizeof (int ) * N)};
29-
30124 for (int i = 0 ; i < num_matrices; i++) {
31- // Compute LU factorization.
32- sgetrf_ (
33- /* m = */ &N,
34- /* n = */ &N,
35- /* a = */ inv.data <float >() + N * N * i,
36- /* lda = */ &N,
37- /* ipiv = */ static_cast <int *>(ipiv.buffer .raw_ptr ()),
38- /* info = */ &info);
39-
40- if (info != 0 ) {
41- std::stringstream ss;
42- ss << " inverse_impl: LU factorization failed with error code " << info;
43- throw std::runtime_error (ss.str ());
44- }
45-
46- static const int lwork_query = -1 ;
47- float workspace_size = 0 ;
48-
49- // Compute workspace size.
50- sgetri_ (
51- /* m = */ &N,
52- /* a = */ nullptr ,
53- /* lda = */ &N,
54- /* ipiv = */ nullptr ,
55- /* work = */ &workspace_size,
56- /* lwork = */ &lwork_query,
57- /* info = */ &info);
58-
59- if (info != 0 ) {
60- std::stringstream ss;
61- ss << " inverse_impl: LU workspace calculation failed with error code "
62- << info;
63- throw std::runtime_error (ss.str ());
64- }
65-
66- const int lwork = workspace_size;
67- auto scratch =
68- array::Data{allocator::malloc_or_wait (sizeof (float ) * lwork)};
69-
70- // Compute inverse.
71- sgetri_ (
72- /* m = */ &N,
73- /* a = */ inv.data <float >() + N * N * i,
74- /* lda = */ &N,
75- /* ipiv = */ static_cast <int *>(ipiv.buffer .raw_ptr ()),
76- /* work = */ static_cast <float *>(scratch.buffer .raw_ptr ()),
77- /* lwork = */ &lwork,
78- /* info = */ &info);
79-
80- if (info != 0 ) {
81- std::stringstream ss;
82- ss << " inverse_impl: inversion failed with error code " << info;
83- throw std::runtime_error (ss.str ());
125+ if (tri) {
126+ tri_inv (inv, N, i, upper);
127+ } else {
128+ general_inv (inv, N, i);
84129 }
85130 }
86131}
@@ -89,7 +134,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
89134 if (inputs[0 ].dtype () != float32) {
90135 throw std::runtime_error (" [Inverse::eval] only supports float32." );
91136 }
92- inverse_impl (inputs[0 ], output);
137+ inverse_impl (inputs[0 ], output, tri_, upper_ );
93138}
94139
95140} // namespace mlx::core
0 commit comments