Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(exla): batched eigh #1591

Merged
merged 3 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions exla/c_src/exla/custom_calls/eigh.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

#include "Eigen/Eigenvalues"

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

template <typename DataType>
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
DataType *eigenvectors_out,
DataType *in, uint64_t m, uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
RowMajorMatrix;

// Map the input matrix
Eigen::Map<RowMajorMatrix> input(in, m, n);
Expand All @@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
}

// Get the eigenvalues and eigenvectors
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues =
eigensolver.eigenvalues();
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();

// Copy the eigenvalues to the output
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
// Create a vector of indices and sort it based on eigenvalues in decreasing
// order
std::vector<int> indices(m);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) {
return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j));
});

// Sort eigenvalues and rearrange eigenvectors
Eigen::Matrix<DataType, Eigen::Dynamic, 1> sorted_eigenvalues(m);
RowMajorMatrix sorted_eigenvectors(m, n);
for (int i = 0; i < m; ++i) {
sorted_eigenvalues(i) = eigenvalues(indices[i]);
sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]);
}

// Copy the sorted eigenvalues to the output
std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), m * sizeof(DataType));

// Copy the eigenvectors to the output
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
// Copy the sorted eigenvectors to the output
std::memcpy(eigenvectors_out, sorted_eigenvectors.data(),
m * n * sizeof(DataType));
}

template <typename DataType>
Expand All @@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
uint64_t num_eigenvectors_dims = dim_sizes[2];

uint64_t *operand_dims_ptr = (uint64_t *)in[2];
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
std::vector<uint64_t> operand_dims(operand_dims_ptr,
operand_dims_ptr + num_operand_dims);

uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
std::vector<uint64_t> eigenvalues_dims(
eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);

uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
std::vector<uint64_t> eigenvectors_dims(
eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);

uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
auto leading_dimensions =
std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);

uint64_t batch_items = 1;
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
Expand All @@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
DataType *eigenvalues = (DataType *)out[0];
DataType *eigenvectors = (DataType *)out[1];

uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
uint64_t inner_stride = m * n * sizeof(DataType);
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
uint64_t eigenvectors_stride =
eigenvectors_dims[eigenvectors_dims.size() - 1] *
eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_eigh_cpu_custom_call<DataType>(
eigenvalues + i * eigenvalues_stride,
eigenvectors + i * eigenvectors_stride,
operand + i * inner_stride / sizeof(DataType),
m, n);
eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
n);
}
}
6 changes: 4 additions & 2 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,16 @@ defmodule EXLA.Defn do
data: %Expr{
args: [
%{data: %{op: :eigh, args: [tensor, _opts]}},
{eigenvecs_expr, eigenvals_expr},
{%{type: {evec_type_kind, _}} = eigenvecs_expr,
%{type: {eval_type_kind, _}} = eigenvals_expr},
_callback
]
}
},
%{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
cache
) do
)
when evec_type_kind != :c and eval_type_kind != :c do
# We match only on platform: :host for MLIR, as we want to support
# eigh-on-cpu as a custom call only in this case
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
Expand Down
66 changes: 66 additions & 0 deletions exla/test/exla/nx_linalg_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,70 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
@invalid_type_error_doctests ++
[:moduledoc]
doctest Nx.LinAlg, except: @excluded_doctests

describe "eigh" do
test "properties for matrices with different eigenvalues" do
# Generate real Hermitian matrices with different eigenvalues
# from random matrices based on the relation A = Q.Λ.Q^*
# where Λ is the diagonal matrix of eigenvalues and Q is unitary matrix.

key = Nx.Random.key(System.unique_integer())

for type <- [f: 32, c: 64], reduce: key do
key ->
# Unitary matrix from a random matrix
{base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type)
{q, _} = Nx.LinAlg.qr(base)

# Different eigenvalues from random values
evals_test =
[100, 10, 1]
|> Enum.map(fn magnitude ->
sign =
if :rand.uniform() - 0.5 > 0 do
1
else
-1
end

rand = :rand.uniform() * magnitude * 0.1 + magnitude
rand * sign
end)
|> Nx.tensor(type: type)

evals_test_diag =
evals_test
|> Nx.make_diagonal()
|> Nx.reshape({1, 3, 3})
|> Nx.tile([2, 1, 1])

# Hermitian matrix with different eigenvalues
# using A = A^* = Q^*.Λ.Q.
a =
q
|> Nx.LinAlg.adjoint()
|> Nx.dot([2], [0], evals_test_diag, [1], [0])
|> Nx.dot([2], [0], q, [1], [0])

# Eigenvalues and eigenvectors
assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8)

assert_all_close(evals_test, evals[0], atol: 1.0e-8)
assert_all_close(evals_test, evals[1], atol: 1.0e-8)

evals =
evals
|> Nx.vectorize(:x)
|> Nx.make_diagonal()
|> Nx.devectorize(keep_names: false)

# Eigenvalue equation
evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0])
a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0])

assert_all_close(a, a_evecs, atol: 1.0e-8)
key
end
end
end
end