diff --git a/exla/c_src/exla/custom_calls/eigh.h b/exla/c_src/exla/custom_calls/eigh.h index 5acc8af664..55cb5adfc1 100644 --- a/exla/c_src/exla/custom_calls/eigh.h +++ b/exla/c_src/exla/custom_calls/eigh.h @@ -2,11 +2,18 @@ #include "Eigen/Eigenvalues" +#include #include +#include +#include template -void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) { - typedef Eigen::Matrix RowMajorMatrix; +void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, + DataType *eigenvectors_out, + DataType *in, uint64_t m, uint64_t n) { + typedef Eigen::Matrix + RowMajorMatrix; // Map the input matrix Eigen::Map input(in, m, n); @@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig } // Get the eigenvalues and eigenvectors - Eigen::Matrix eigenvalues = eigensolver.eigenvalues(); + Eigen::Matrix eigenvalues = + eigensolver.eigenvalues(); RowMajorMatrix eigenvectors = eigensolver.eigenvectors(); - // Copy the eigenvalues to the output - std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType)); + // Create a vector of indices and sort it based on eigenvalues in decreasing + // order + std::vector indices(m); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) { + return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j)); + }); + + // Sort eigenvalues and rearrange eigenvectors + Eigen::Matrix sorted_eigenvalues(m); + RowMajorMatrix sorted_eigenvectors(m, n); + for (int i = 0; i < m; ++i) { + sorted_eigenvalues(i) = eigenvalues(indices[i]); + sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]); + } + + // Copy the sorted eigenvalues to the output + std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), m * sizeof(DataType)); - // Copy the eigenvectors to the output - std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType)); + // Copy the sorted eigenvectors to the output + std::memcpy(eigenvectors_out, sorted_eigenvectors.data(), + m * n * sizeof(DataType)); } template @@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) { uint64_t num_eigenvectors_dims = dim_sizes[2]; uint64_t *operand_dims_ptr = (uint64_t *)in[2]; - std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); + std::vector operand_dims(operand_dims_ptr, + operand_dims_ptr + num_operand_dims); uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3]; - std::vector eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims); + std::vector eigenvalues_dims( + eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims); uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4]; - std::vector eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims); + std::vector eigenvectors_dims( + eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims); uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; - auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); + auto leading_dimensions = + std::vector(operand_dims.begin(), operand_dims.end() - 2); uint64_t batch_items = 1; for (uint64_t i = 0; i < leading_dimensions.size(); i++) { @@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) { DataType *eigenvalues = (DataType *)out[0]; DataType *eigenvectors = (DataType *)out[1]; - uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType); - uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType); - uint64_t inner_stride = m * n * sizeof(DataType); + uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; + uint64_t eigenvectors_stride = + eigenvectors_dims[eigenvectors_dims.size() - 1] * + eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_eigh_cpu_custom_call( eigenvalues + i * eigenvalues_stride, - eigenvectors + i * eigenvectors_stride, - operand + i * inner_stride / sizeof(DataType), - m, n); + eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m, + n); } } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3a676b5942..3bb478258d 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -404,14 +404,16 @@ defmodule EXLA.Defn do data: %Expr{ args: [ %{data: %{op: :eigh, args: [tensor, _opts]}}, - {eigenvecs_expr, eigenvals_expr}, + {%{type: {evec_type_kind, _}} = eigenvecs_expr, + %{type: {eval_type_kind, _}} = eigenvals_expr}, _callback ] } }, %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, cache - ) do + ) + when evec_type_kind != :c and eval_type_kind != :c do # We match only on platform: :host for MLIR, as we want to support # eigh-on-cpu as a custom call only in this case {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 09d60ba8f6..3d49d2c826 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -25,4 +25,70 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do @invalid_type_error_doctests ++ [:moduledoc] doctest Nx.LinAlg, except: @excluded_doctests + + describe "eigh" do + test "properties for matrices with different eigenvalues" do + # Generate real Hermitian matrices with different eigenvalues + # from random matrices based on the relation A = Q.Λ.Q^* + # where Λ is the diagonal matrix of eigenvalues and Q is unitary matrix. + + key = Nx.Random.key(System.unique_integer()) + + for type <- [f: 32, c: 64], reduce: key do + key -> + # Unitary matrix from a random matrix + {base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type) + {q, _} = Nx.LinAlg.qr(base) + + # Different eigenvalues from random values + evals_test = + [100, 10, 1] + |> Enum.map(fn magnitude -> + sign = + if :rand.uniform() - 0.5 > 0 do + 1 + else + -1 + end + + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign + end) + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + |> Nx.tile([2, 1, 1]) + + # Hermitian matrix with different eigenvalues + # using A = A^* = Q^*.Λ.Q. + a = + q + |> Nx.LinAlg.adjoint() + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) + |> Nx.dot([2], [0], q, [1], [0]) + + # Eigenvalues and eigenvectors + assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8) + + assert_all_close(evals_test, evals[0], atol: 1.0e-8) + assert_all_close(evals_test, evals[1], atol: 1.0e-8) + + evals = + evals + |> Nx.vectorize(:x) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) + + # Eigenvalue equation + evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0]) + a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0]) + + assert_all_close(a, a_evecs, atol: 1.0e-8) + key + end + end + end end