Skip to content

Commit 9526198

Browse files
authored
fix(exla): batched eigh (#1591)
1 parent 45066ad commit 9526198

File tree

3 files changed

+117
-19
lines changed

3 files changed

+117
-19
lines changed

exla/c_src/exla/custom_calls/eigh.h

+47-17
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22

33
#include "Eigen/Eigenvalues"
44

5+
#include <algorithm>
56
#include <iostream>
7+
#include <numeric>
8+
#include <vector>
69

710
template <typename DataType>
8-
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
9-
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
11+
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
12+
DataType *eigenvectors_out,
13+
DataType *in, uint64_t m, uint64_t n) {
14+
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
15+
Eigen::RowMajor>
16+
RowMajorMatrix;
1017

1118
// Map the input matrix
1219
Eigen::Map<RowMajorMatrix> input(in, m, n);
@@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
2027
}
2128

2229
// Get the eigenvalues and eigenvectors
23-
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
30+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues =
31+
eigensolver.eigenvalues();
2432
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
2533

26-
// Copy the eigenvalues to the output
27-
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
34+
// Create a vector of indices and sort it based on eigenvalues in decreasing
35+
// order
36+
std::vector<int> indices(m);
37+
std::iota(indices.begin(), indices.end(), 0);
38+
std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) {
39+
return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j));
40+
});
41+
42+
// Sort eigenvalues and rearrange eigenvectors
43+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> sorted_eigenvalues(m);
44+
RowMajorMatrix sorted_eigenvectors(m, n);
45+
for (int i = 0; i < m; ++i) {
46+
sorted_eigenvalues(i) = eigenvalues(indices[i]);
47+
sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]);
48+
}
49+
50+
// Copy the sorted eigenvalues to the output
51+
std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), m * sizeof(DataType));
2852

29-
// Copy the eigenvectors to the output
30-
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
53+
// Copy the sorted eigenvectors to the output
54+
std::memcpy(eigenvectors_out, sorted_eigenvectors.data(),
55+
m * n * sizeof(DataType));
3156
}
3257

3358
template <typename DataType>
@@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
4065
uint64_t num_eigenvectors_dims = dim_sizes[2];
4166

4267
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
43-
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
68+
std::vector<uint64_t> operand_dims(operand_dims_ptr,
69+
operand_dims_ptr + num_operand_dims);
4470

4571
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
46-
std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
72+
std::vector<uint64_t> eigenvalues_dims(
73+
eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
4774

4875
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
49-
std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
76+
std::vector<uint64_t> eigenvectors_dims(
77+
eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
5078

5179
uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
5280
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
5381

54-
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
82+
auto leading_dimensions =
83+
std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
5584

5685
uint64_t batch_items = 1;
5786
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
@@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
6190
DataType *eigenvalues = (DataType *)out[0];
6291
DataType *eigenvectors = (DataType *)out[1];
6392

64-
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
65-
uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
66-
uint64_t inner_stride = m * n * sizeof(DataType);
93+
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
94+
uint64_t eigenvectors_stride =
95+
eigenvectors_dims[eigenvectors_dims.size() - 1] *
96+
eigenvectors_dims[eigenvectors_dims.size() - 2];
97+
uint64_t inner_stride = m * n;
6798

6899
for (uint64_t i = 0; i < batch_items; i++) {
69100
single_matrix_eigh_cpu_custom_call<DataType>(
70101
eigenvalues + i * eigenvalues_stride,
71-
eigenvectors + i * eigenvectors_stride,
72-
operand + i * inner_stride / sizeof(DataType),
73-
m, n);
102+
eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
103+
n);
74104
}
75105
}

exla/lib/exla/defn.ex

+4-2
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,16 @@ defmodule EXLA.Defn do
404404
data: %Expr{
405405
args: [
406406
%{data: %{op: :eigh, args: [tensor, _opts]}},
407-
{eigenvecs_expr, eigenvals_expr},
407+
{%{type: {evec_type_kind, _}} = eigenvecs_expr,
408+
%{type: {eval_type_kind, _}} = eigenvals_expr},
408409
_callback
409410
]
410411
}
411412
},
412413
%{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
413414
cache
414-
) do
415+
)
416+
when evec_type_kind != :c and eval_type_kind != :c do
415417
# We match only on platform: :host for MLIR, as we want to support
416418
# eigh-on-cpu as a custom call only in this case
417419
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

exla/test/exla/nx_linalg_doctest_test.exs

+66
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,70 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
2525
@invalid_type_error_doctests ++
2626
[:moduledoc]
2727
doctest Nx.LinAlg, except: @excluded_doctests
28+
29+
describe "eigh" do
30+
test "properties for matrices with different eigenvalues" do
31+
# Generate real Hermitian matrices with different eigenvalues
32+
# from random matrices based on the relation A = Q.Λ.Q^*
33+
# where Λ is the diagonal matrix of eigenvalues and Q is unitary matrix.
34+
35+
key = Nx.Random.key(System.unique_integer())
36+
37+
for type <- [f: 32, c: 64], reduce: key do
38+
key ->
39+
# Unitary matrix from a random matrix
40+
{base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type)
41+
{q, _} = Nx.LinAlg.qr(base)
42+
43+
# Different eigenvalues from random values
44+
evals_test =
45+
[100, 10, 1]
46+
|> Enum.map(fn magnitude ->
47+
sign =
48+
if :rand.uniform() - 0.5 > 0 do
49+
1
50+
else
51+
-1
52+
end
53+
54+
rand = :rand.uniform() * magnitude * 0.1 + magnitude
55+
rand * sign
56+
end)
57+
|> Nx.tensor(type: type)
58+
59+
evals_test_diag =
60+
evals_test
61+
|> Nx.make_diagonal()
62+
|> Nx.reshape({1, 3, 3})
63+
|> Nx.tile([2, 1, 1])
64+
65+
# Hermitian matrix with different eigenvalues
66+
# using A = A^* = Q^*.Λ.Q.
67+
a =
68+
q
69+
|> Nx.LinAlg.adjoint()
70+
|> Nx.dot([2], [0], evals_test_diag, [1], [0])
71+
|> Nx.dot([2], [0], q, [1], [0])
72+
73+
# Eigenvalues and eigenvectors
74+
assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8)
75+
76+
assert_all_close(evals_test, evals[0], atol: 1.0e-8)
77+
assert_all_close(evals_test, evals[1], atol: 1.0e-8)
78+
79+
evals =
80+
evals
81+
|> Nx.vectorize(:x)
82+
|> Nx.make_diagonal()
83+
|> Nx.devectorize(keep_names: false)
84+
85+
# Eigenvalue equation
86+
evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0])
87+
a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0])
88+
89+
assert_all_close(a, a_evecs, atol: 1.0e-8)
90+
key
91+
end
92+
end
93+
end
2894
end

0 commit comments

Comments
 (0)