Skip to content

Commit 5ec66cf

Browse files
committed
Merge branch 'refactor-eigendecomposition' into eigendecomposed-shampoo
2 parents 9ce5c22 + c6ced65 commit 5ec66cf

File tree

3 files changed

+39
-51
lines changed

3 files changed

+39
-51
lines changed

matrix_functions.py

+20-38
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from dataclasses import asdict
1515
from fractions import Fraction
1616
from math import isfinite
17-
from typing import NamedTuple
1817

1918
import torch
2019
from matrix_functions_types import (
@@ -193,7 +192,7 @@ def matrix_eigendecomposition(
193192
elif A.shape[0] != A.shape[1]:
194193
raise ValueError("Matrix is not square!")
195194

196-
# return the diagonal of A and the identity matrix if A is diagonal
195+
# Return the (sorted) diagonal of A and identity matrix if A is diagonal.
197196
if is_diagonal:
198197
return A.diag(), torch.eye(
199198
A.shape[0],
@@ -258,20 +257,13 @@ def _eigh_eigenvalue_decomposition(
258257
return L.to(device=current_device), Q.to(device=current_device)
259258

260259

261-
class Criterion(NamedTuple):
262-
"""Named tuple for result of convergence criterion."""
263-
264-
converged: bool
265-
approximate_eigenvalues: Tensor
266-
267-
268-
def _approximate_eigenvalues_criterion_below_or_equal_tolerance(
269-
A: Tensor, Q: Tensor, tolerance: float
270-
) -> Criterion:
271-
"""Evaluates if a criterion using approximate eigenvalues is below or equal to the tolerance.
260+
def _estimated_eigenvalues_criterion_below_or_equal_tolerance(
261+
estimated_eigenvalues: Tensor, tolerance: float
262+
) -> bool:
263+
"""Evaluates if a criterion using estimated eigenvalues is below or equal to the tolerance.
272264
273265
Let Q^T A Q =: B be the estimate of the eigenvalues of the matrix A, where Q is the matrix containing the last computed eigenvectors.
274-
The approximate eigenvalues criterion is defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
266+
The criterion based on the estimated eigenvalues is defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
275267
The tolerance hyperparameter should therefore be in the interval [0.0, 1.0].
276268
277269
This convergence criterion can be motivated by considering A' = Q diag(B) Q^T as an approximation of A.
@@ -285,17 +277,13 @@ def _approximate_eigenvalues_criterion_below_or_equal_tolerance(
285277
tolerance (float): The tolerance for the criterion.
286278
287279
Returns:
288-
Criterion: Named tuple with the fields 'converged' and 'approximate_eigenvalues':
289-
converged (bool): whether the criterion is below or equal to the tolerance.
290-
approximate_eigenvalues (Tensor): the approximate eigenvalues of A as a vector.
280+
bool: True if the criterion is below or equal to the tolerance, False otherwise.
291281
292282
"""
293-
approximate_eigenvalues = Q.T @ A @ Q
294-
norm = torch.linalg.norm(approximate_eigenvalues)
295-
diagonal_norm = torch.linalg.norm(approximate_eigenvalues.diag())
283+
norm = torch.linalg.norm(estimated_eigenvalues)
284+
diagonal_norm = torch.linalg.norm(estimated_eigenvalues.diag())
296285
off_diagonal_norm = torch.sqrt(norm**2 - diagonal_norm**2)
297-
converged = bool(off_diagonal_norm <= tolerance * norm)
298-
return Criterion(converged, approximate_eigenvalues.diag())
286+
return bool(off_diagonal_norm <= tolerance * norm)
299287

300288

301289
def _qr_algorithm(
@@ -304,13 +292,12 @@ def _qr_algorithm(
304292
max_iterations: int = 1,
305293
tolerance: float = 0.01,
306294
) -> tuple[Tensor, Tensor]:
307-
"""
308-
Approximately compute the eigendecomposition of a symmetric matrix by performing the QR algorithm.
295+
"""Approximately compute the eigendecomposition of a symmetric matrix by performing the QR algorithm.
309296
310297
Given an initial estimate of the eigenvectors Q of matrix A, a power iteration and a QR decomposition is performed each iteration, i.e. Q, _ <- QR(A @ Q).
311298
When the initial estimate is the zero matrix, the eigendecomposition is computed using _eigh_eigenvalue_decomposition.
312299
313-
Note that if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
300+
Note that if the criterion based on the estimated eigenvalues is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
314301
315302
Args:
316303
A (Tensor): The symmetric input matrix.
@@ -320,35 +307,30 @@ def _qr_algorithm(
320307
(Default: 0.01)
321308
322309
Returns:
323-
tuple[Tensor, Tensor]: The approximate eigenvalues and eigenvectors of the input matrix A.
310+
tuple[Tensor, Tensor]: The estimated eigenvalues and eigenvectors of the input matrix A.
324311
325312
"""
326313
if not eigenvectors_estimate.any():
327314
return _eigh_eigenvalue_decomposition(A)
328315

329316
# Perform orthogonal/simultaneous iterations (QR algorithm).
330317
Q = eigenvectors_estimate
318+
estimated_eigenvalues = Q.T @ A @ Q
331319
iteration = 0
332-
# NOTE: This will skip the QR iterations if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate.
320+
# NOTE: This will skip the QR iterations if the criterion is already below or equal to the tolerance given the initial eigenvectors_estimate.
333321
while (
334322
iteration < max_iterations
335-
and not (
336-
criterion := _approximate_eigenvalues_criterion_below_or_equal_tolerance(
337-
A, Q, tolerance
338-
)
339-
).converged
323+
and not _estimated_eigenvalues_criterion_below_or_equal_tolerance(
324+
estimated_eigenvalues, tolerance
325+
)
340326
):
341327
power_iteration = A @ Q
342328
Q = torch.linalg.qr(power_iteration).Q
343329
iteration += 1
330+
estimated_eigenvalues = Q.T @ A @ Q
344331

345332
# Ensure consistent ordering of estimated eigenvalues and eigenvectors.
346-
estimated_eigenvalues = (
347-
criterion.approximate_eigenvalues
348-
if iteration == 0 # Re-use approximate eigenvalues if iterations were skipped.
349-
else torch.einsum("ij, ik, kj -> j", Q, A, Q)
350-
)
351-
estimated_eigenvalues, indices = estimated_eigenvalues.sort()
333+
estimated_eigenvalues, indices = estimated_eigenvalues.diag().sort(stable=True)
352334
Q = Q[:, indices]
353335

354336
return estimated_eigenvalues, Q

matrix_functions_types.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def __post_init__(self) -> None:
5353
class QREigendecompositionConfig(EigendecompositionConfig):
5454
"""Configuration for eigenvalue decomposition via QR algorithm.
5555
56-
Determines whether the QR algorithm has converged based on the approximate eigenvalues Q^T A Q =: B, where Q is the last computed eigenvectors and A is the current Kronecker factor.
57-
The approximate eigenvalues update criterion is then defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
56+
Determines whether the QR algorithm has converged based on the estimated eigenvalues Q^T A Q =: B, where Q is the last computed eigenvectors and A is the current Kronecker factor.
57+
The convergence criterion based on the estimated eigenvalues is then defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
5858
The tolerance hyperparameter should therefore be in the interval [0.0, 1.0].
5959
60-
Note that if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
60+
Note that if the criterion based on the estimated eigenvalues is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
6161
6262
This convergence criterion can be motivated by considering A' = Q diag(B) Q^T as an approximation of A.
6363
We have ||A - A'||_F = ||A - Q diag(B) Q^T||_F = ||Q^T A Q - diag(B)||_F = ||B - diag(B)||_F.

tests/matrix_functions_test.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -820,20 +820,20 @@ def test_matrix_eigendecomposition(self) -> None:
820820
]
821821
expected_eigenvalues_list = [
822822
torch.tensor([1.0, 4.0]),
823-
torch.tensor([2.9009e-03, 1.7424e-01, 1.9828e03]),
823+
torch.tensor([2.9008677229e-03, 1.7424316704e-01, 1.9828229980e03]),
824824
]
825825
expected_eigenvectors_list = [
826826
torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
827827
torch.tensor(
828828
[
829-
[0.0460, -0.6287, 0.7763],
830-
[-0.1752, -0.7702, -0.6133],
831-
[0.9835, -0.1078, -0.1455],
829+
[0.0460073575, -0.6286827326, 0.7762997746],
830+
[-0.1751257628, -0.7701635957, -0.6133345366],
831+
[0.9834705591, -0.1077321917, -0.1455317289],
832832
]
833833
),
834834
]
835835

836-
atol = 0.05 # TODO: Ensure consistent ordering of the eigenvectors and decrease tolerance.
836+
atol = 1e-4
837837
rtol = 1e-5
838838
with self.subTest("Test with diagonal case."):
839839
torch.testing.assert_close(
@@ -876,12 +876,18 @@ def test_matrix_eigendecomposition(self) -> None:
876876
)
877877

878878
# Tests for `QREigendecompositionConfig`.
879-
initialization_strategies = {
880-
"zero": lambda A: torch.zeros_like(A),
881-
"identity": lambda A: torch.eye(A.shape[0], dtype=A.dtype, device=A.device),
882-
"exact": lambda A: matrix_eigendecomposition(A)[1],
879+
initialization_strategies_to_functions_atol = {
880+
"zero": (lambda A: torch.zeros_like(A), atol),
881+
"identity": (
882+
lambda A: torch.eye(A.shape[0], dtype=A.dtype, device=A.device),
883+
2e-3,
884+
),
885+
"exact": (lambda A: matrix_eigendecomposition(A)[1], 2e-3),
883886
}
884-
for name, initialization_fn in initialization_strategies.items():
887+
for name, (
888+
initialization_fn,
889+
atol,
890+
) in initialization_strategies_to_functions_atol.items():
885891
with self.subTest(
886892
f"Test with QREigendecompositionConfig with {name} initialization."
887893
):

0 commit comments

Comments
 (0)