2
2
3
3
#include " Eigen/Eigenvalues"
4
4
5
+ #include < algorithm>
5
6
#include < iostream>
7
+ #include < numeric>
8
+ #include < vector>
6
9
7
10
template <typename DataType>
8
- void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
9
- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
11
+ void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out,
12
+ DataType *eigenvectors_out,
13
+ DataType *in, uint64_t m, uint64_t n) {
14
+ typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
15
+ Eigen::RowMajor>
16
+ RowMajorMatrix;
10
17
11
18
// Map the input matrix
12
19
Eigen::Map<RowMajorMatrix> input (in, m, n);
@@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
20
27
}
21
28
22
29
// Get the eigenvalues and eigenvectors
23
- Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues = eigensolver.eigenvalues ();
30
+ Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues =
31
+ eigensolver.eigenvalues ();
24
32
RowMajorMatrix eigenvectors = eigensolver.eigenvectors ();
25
33
26
- // Copy the eigenvalues to the output
27
- std::memcpy (eigenvalues_out, eigenvalues.data (), m * sizeof (DataType));
34
+ // Create a vector of indices and sort it by eigenvalue magnitude in decreasing
35
+ // order
36
+ std::vector<int > indices (m);
37
+ std::iota (indices.begin (), indices.end (), 0 );
38
+ std::sort (indices.begin (), indices.end (), [&eigenvalues](int i, int j) {
39
+ return std::abs (eigenvalues (i)) > std::abs (eigenvalues (j));
40
+ });
41
+
42
+ // Sort eigenvalues and rearrange eigenvectors
43
+ Eigen::Matrix<DataType, Eigen::Dynamic, 1 > sorted_eigenvalues (m);
44
+ RowMajorMatrix sorted_eigenvectors (m, n);
45
+ for (int i = 0 ; i < m; ++i) {
46
+ sorted_eigenvalues (i) = eigenvalues (indices[i]);
47
+ sorted_eigenvectors.col (i) = eigenvectors.col (indices[i]);
48
+ }
49
+
50
+ // Copy the sorted eigenvalues to the output
51
+ std::memcpy (eigenvalues_out, sorted_eigenvalues.data (), m * sizeof (DataType));
28
52
29
- // Copy the eigenvectors to the output
30
- std::memcpy (eigenvectors_out, eigenvectors.data (), m * n * sizeof (DataType));
53
+ // Copy the sorted eigenvectors to the output
54
+ std::memcpy (eigenvectors_out, sorted_eigenvectors.data (),
55
+ m * n * sizeof (DataType));
31
56
}
32
57
33
58
template <typename DataType>
@@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
40
65
uint64_t num_eigenvectors_dims = dim_sizes[2 ];
41
66
42
67
uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
43
- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
68
+ std::vector<uint64_t > operand_dims (operand_dims_ptr,
69
+ operand_dims_ptr + num_operand_dims);
44
70
45
71
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3 ];
46
- std::vector<uint64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
72
+ std::vector<uint64_t > eigenvalues_dims (
73
+ eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
47
74
48
75
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4 ];
49
- std::vector<uint64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
76
+ std::vector<uint64_t > eigenvectors_dims (
77
+ eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
50
78
51
79
uint64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
52
80
uint64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
53
81
54
- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
82
+ auto leading_dimensions =
83
+ std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
55
84
56
85
uint64_t batch_items = 1 ;
57
86
for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
@@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
61
90
DataType *eigenvalues = (DataType *)out[0 ];
62
91
DataType *eigenvectors = (DataType *)out[1 ];
63
92
64
- uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
65
- uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
66
- uint64_t inner_stride = m * n * sizeof (DataType);
93
+ uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ];
94
+ uint64_t eigenvectors_stride =
95
+ eigenvectors_dims[eigenvectors_dims.size () - 1 ] *
96
+ eigenvectors_dims[eigenvectors_dims.size () - 2 ];
97
+ uint64_t inner_stride = m * n;
67
98
68
99
for (uint64_t i = 0 ; i < batch_items; i++) {
69
100
single_matrix_eigh_cpu_custom_call<DataType>(
70
101
eigenvalues + i * eigenvalues_stride,
71
- eigenvectors + i * eigenvectors_stride,
72
- operand + i * inner_stride / sizeof (DataType),
73
- m, n);
102
+ eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
103
+ n);
74
104
}
75
105
}
0 commit comments