Skip to content

Commit c6ced65

Browse files
committed
Refactor QR convergence criterion
1 parent 1031eb2 commit c6ced65

File tree

2 files changed

+22
-40
lines changed

2 files changed

+22
-40
lines changed

matrix_functions.py

+19-37
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from dataclasses import asdict
1515
from fractions import Fraction
1616
from math import isfinite
17-
from typing import NamedTuple
1817

1918
import torch
2019
from matrix_functions_types import (
@@ -258,20 +257,13 @@ def _eigh_eigenvalue_decomposition(
258257
return L.to(device=current_device), Q.to(device=current_device)
259258

260259

261-
class Criterion(NamedTuple):
262-
"""Named tuple for result of convergence criterion."""
263-
264-
converged: bool
265-
approximate_eigenvalues: Tensor
266-
267-
268-
def _approximate_eigenvalues_criterion_below_or_equal_tolerance(
269-
A: Tensor, Q: Tensor, tolerance: float
270-
) -> Criterion:
271-
"""Evaluates if a criterion using approximate eigenvalues is below or equal to the tolerance.
260+
def _estimated_eigenvalues_criterion_below_or_equal_tolerance(
261+
estimated_eigenvalues: Tensor, tolerance: float
262+
) -> bool:
263+
"""Evaluates if a criterion using estimated eigenvalues is below or equal to the tolerance.
272264
273265
Let Q^T A Q =: B be the estimate of the eigenvalues of the matrix A, where Q is the matrix containing the last computed eigenvectors.
274-
The approximate eigenvalues criterion is defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
266+
The criterion based on the estimated eigenvalues is defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
275267
The tolerance hyperparameter should therefore be in the interval [0.0, 1.0].
276268
277269
This convergence criterion can be motivated by considering A' = Q diag(B) Q^T as an approximation of A.
@@ -285,17 +277,13 @@ def _approximate_eigenvalues_criterion_below_or_equal_tolerance(
285277
tolerance (float): The tolerance for the criterion.
286278
287279
Returns:
288-
Criterion: Named tuple with the fields 'converged' and 'approximate_eigenvalues':
289-
converged (bool): whether the criterion is below or equal to the tolerance.
290-
approximate_eigenvalues (Tensor): the approximate eigenvalues of A as a vector.
280+
bool: True if the criterion is below or equal to the tolerance, False otherwise.
291281
292282
"""
293-
approximate_eigenvalues = Q.T @ A @ Q
294-
norm = torch.linalg.norm(approximate_eigenvalues)
295-
diagonal_norm = torch.linalg.norm(approximate_eigenvalues.diag())
283+
norm = torch.linalg.norm(estimated_eigenvalues)
284+
diagonal_norm = torch.linalg.norm(estimated_eigenvalues.diag())
296285
off_diagonal_norm = torch.sqrt(norm**2 - diagonal_norm**2)
297-
converged = bool(off_diagonal_norm <= tolerance * norm)
298-
return Criterion(converged, approximate_eigenvalues.diag())
286+
return bool(off_diagonal_norm <= tolerance * norm)
299287

300288

301289
def _qr_algorithm(
@@ -304,13 +292,12 @@ def _qr_algorithm(
304292
max_iterations: int = 1,
305293
tolerance: float = 0.01,
306294
) -> tuple[Tensor, Tensor]:
307-
"""
308-
Approximately compute the eigendecomposition of a symmetric matrix by performing the QR algorithm.
295+
"""Approximately compute the eigendecomposition of a symmetric matrix by performing the QR algorithm.
309296
310297
Given an initial estimate of the eigenvectors Q of matrix A, a power iteration and a QR decomposition are performed in each iteration, i.e. Q, _ <- QR(A @ Q).
311298
When the initial estimate is the zero matrix, the eigendecomposition is computed using _eigh_eigenvalue_decomposition.
312299
313-
Note that if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
300+
Note that if the criterion based on the estimated eigenvalues is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
314301
315302
Args:
316303
A (Tensor): The symmetric input matrix.
@@ -320,35 +307,30 @@ def _qr_algorithm(
320307
(Default: 0.01)
321308
322309
Returns:
323-
tuple[Tensor, Tensor]: The approximate eigenvalues and eigenvectors of the input matrix A.
310+
tuple[Tensor, Tensor]: The estimated eigenvalues and eigenvectors of the input matrix A.
324311
325312
"""
326313
if not eigenvectors_estimate.any():
327314
return _eigh_eigenvalue_decomposition(A)
328315

329316
# Perform orthogonal/simultaneous iterations (QR algorithm).
330317
Q = eigenvectors_estimate
318+
estimated_eigenvalues = Q.T @ A @ Q
331319
iteration = 0
332-
# NOTE: This will skip the QR iterations if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate.
320+
# NOTE: This will skip the QR iterations if the criterion is already below or equal to the tolerance given the initial eigenvectors_estimate.
333321
while (
334322
iteration < max_iterations
335-
and not (
336-
criterion := _approximate_eigenvalues_criterion_below_or_equal_tolerance(
337-
A, Q, tolerance
338-
)
339-
).converged
323+
and not _estimated_eigenvalues_criterion_below_or_equal_tolerance(
324+
estimated_eigenvalues, tolerance
325+
)
340326
):
341327
power_iteration = A @ Q
342328
Q = torch.linalg.qr(power_iteration).Q
343329
iteration += 1
330+
estimated_eigenvalues = Q.T @ A @ Q
344331

345332
# Ensure consistent ordering of estimated eigenvalues and eigenvectors.
346-
estimated_eigenvalues = (
347-
criterion.approximate_eigenvalues
348-
if iteration == 0 # Re-use approximate eigenvalues if iterations were skipped.
349-
else torch.einsum("ij, ik, kj -> j", Q, A, Q)
350-
)
351-
estimated_eigenvalues, indices = estimated_eigenvalues.sort(stable=True)
333+
estimated_eigenvalues, indices = estimated_eigenvalues.diag().sort(stable=True)
352334
Q = Q[:, indices]
353335

354336
return estimated_eigenvalues, Q

matrix_functions_types.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def __post_init__(self) -> None:
5353
class QREigendecompositionConfig(EigendecompositionConfig):
5454
"""Configuration for eigenvalue decomposition via QR algorithm.
5555
56-
Determines whether the QR algorithm has converged based on the approximate eigenvalues Q^T A Q =: B, where Q is the last computed eigenvectors and A is the current Kronecker factor.
57-
The approximate eigenvalues update criterion is then defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
56+
Determines whether the QR algorithm has converged based on the estimated eigenvalues Q^T A Q =: B, where Q is the matrix containing the last computed eigenvectors and A is the current Kronecker factor.
57+
The convergence criterion based on the estimated eigenvalues is then defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
5858
The tolerance hyperparameter should therefore be in the interval [0.0, 1.0].
5959
60-
Note that if the approximate eigenvalues criterion is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
60+
Note that if the criterion based on the estimated eigenvalues is already below or equal to the tolerance given the initial eigenvectors_estimate, the QR iterations will be skipped.
6161
6262
This convergence criterion can be motivated by considering A' = Q diag(B) Q^T as an approximation of A.
6363
We have ||A - A'||_F = ||A - Q diag(B) Q^T||_F = ||Q^T A Q - diag(B)||_F = ||B - diag(B)||_F.

0 commit comments

Comments
 (0)