Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
ae1e74d
fix(stochtrace): Support pytrees in LOO estimators
sethaxen Jun 8, 2026
b910a7f
feat(backend): Add np.array_max
sethaxen Jun 8, 2026
146632e
feat(backend): Add np.median
sethaxen Jun 8, 2026
e4f58b4
feat(stochtrace): Add XDiag
sethaxen Jun 9, 2026
5528550
feat(stochtrace): Add XNysDiag
sethaxen Jun 9, 2026
53f4f9c
feat(backend)!: Generate Hermitian matrix with random eigenvectors
sethaxen Jun 12, 2026
490df1d
test: Use hermitian_matrix_from_eigenvalues in old tests
sethaxen Jun 12, 2026
e51123a
feat(backend): Remove old symmetric generator
sethaxen Jun 12, 2026
c5fda52
test(stochtrace): Use Hermitian generator in LOO diag tests
sethaxen Jun 12, 2026
aa7ac92
test(stochtrace): Use Hermitian generator
sethaxen Jun 12, 2026
161ae92
test(stochtrace): Add helper functions for experiments
sethaxen Jun 12, 2026
3822f36
test(stochtrace): Use helpers in existing x(nys)trace tests
sethaxen Jun 12, 2026
973b918
test(stochtrace): Rename helpers file to conftest
sethaxen Jun 12, 2026
e7da6c8
test(stochtrace): Move shared nystrom fixture to conftest
sethaxen Jun 12, 2026
be3c045
test(stochtrace): Refactor xdiag tests to use cases
sethaxen Jun 12, 2026
b774331
test(stochtrace): Refactor xnysdiag tests to use cases
sethaxen Jun 12, 2026
5698dc3
test(stochtrace): Remove xnysdiag pytest tests
sethaxen Jun 12, 2026
dadf7c3
refactor(test_util): Move eigenvalue helpers from conftest to test_util
sethaxen Jun 12, 2026
80b7dd8
refactor(test_util): Add hermitian_matrix_eigvals_decaying/step helpers
sethaxen Jun 12, 2026
3e93650
test(stochtrace): Replace nystrom fixture with direct parametrize
sethaxen Jun 12, 2026
27e7445
test(stochtrace): Remove unnecessary __init__.py
sethaxen Jun 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def nanmean(x, /, axis=None):
return jnp.nanmean(x, axis)


def median(x, /, axis=None):
return jnp.median(x, axis=axis)


def array_max(x, /, axis=None):
return jnp.max(x, axis=axis)


def elementwise_max(a, b, /):
return jnp.maximum(a, b)

Expand Down
161 changes: 158 additions & 3 deletions matfree/stochtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def estimator_leave_one_out(integrand: Callable, /, sampler: Callable) -> Callab

def estimate(matvec, key, *parameters):
samples = sampler(key)
return np.mean(integrand(matvec, samples, *parameters), axis=0)
Qs = integrand(matvec, samples, *parameters)
return tree.tree_map(lambda s: np.mean(s, axis=0), Qs)

return estimate

Expand Down Expand Up @@ -146,8 +147,9 @@ def estimator_leave_one_out_mean_and_sem(
def estimate(matvec, key, *parameters):
samples = sampler(key)
Qs = integrand(matvec, samples, *parameters)
mean = np.mean(Qs, axis=0)
sem = np.std(Qs, axis=0) / np.sqrt(Qs.shape[0])
n_samples = tree.tree_leaves(Qs)[0].shape[0]
mean = tree.tree_map(lambda s: np.mean(s, axis=0), Qs)
sem = tree.tree_map(lambda s: np.std(s, axis=0) / np.sqrt(n_samples), Qs)
return mean, sem

return estimate
Expand Down Expand Up @@ -266,6 +268,88 @@ def _trace_estimate():
return integrand


def leave_one_out_xdiag() -> Callable:
"""Construct an integrand for estimating the diagonal using the XDiag algorithm (Epperly et al. 2024).

Returns
-------
integrand
An integrand function compatible with
[estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out]
whose input has the signature ``(matvec, samples, *params)`` and whose output is a
pytree with each leaf having shape ``(num_samples, n_k)``, giving one diagonal
estimate per leave-one-out sample.

Notes
-----
The number of samples must be less than or equal to the dimension of the operator.

The sum of the diagonal estimate over all entries is an unbiased estimate of the trace but
generally has higher variance than the estimate produced by
[leave_one_out_xnystrace][matfree.stochtrace.leave_one_out_xnystrace].

References
----------
- Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation.
SIAM J Matrix Anal A. 45.1: 1-23.
doi: [10.1137/23M1548323](https://doi.org/10.1137/23M1548323)
arXiv: [2301.07825](https://arxiv.org/abs/2301.07825)
- Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
"""

def integrand(matvec, samples, *params):
sample0 = tree.tree_map(lambda s: s[0], samples)
_, unflatten = tree.ravel_pytree(sample0)

Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
n, num_samples = Omega.shape

if num_samples > n:
raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

def matvec_flat(v):
return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

if 2 * num_samples >= n:
B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
diag_B = linalg.diagonal(B_mat)
return func.vmap(unflatten)(
np.ones((num_samples, 1), dtype=diag_B.dtype) * diag_B
)

matvec_flat_transpose = func.linear_transpose(matvec_flat, Omega[:, 0])

def matvec_flat_adjoint(v):
(result,) = matvec_flat_transpose(v.conj())
return result.conj()

Y = func.vmap(matvec_flat, in_axes=-1, out_axes=-1)(Omega)
Q, R = linalg.qr_reduced(Y)
Z = func.vmap(matvec_flat_adjoint, in_axes=-1, out_axes=-1)(Q)

def _diag_exact():
diag_B = func.vmap(linalg.vdot, in_axes=0)(Z, Q)
return np.ones(num_samples, dtype=diag_B.dtype) * diag_B[:, None]

def _diag_estimate():
S = _qr_leave_one_out_factor(R)
QS = Q @ S
S_vd_R = func.vmap(linalg.vdot, in_axes=1)(S, R)

diag_B_hat = func.vmap(linalg.vdot, in_axes=0)(Z, Q)
diag_B_hat_loo = diag_B_hat[:, None] - QS * (Z @ S).conj()
diag_residual_loo = QS * S_vd_R * Omega.conj()
return diag_B_hat_loo + diag_residual_loo

Y_rank = np.sum(np.abs(linalg.diagonal(R)) > np.finfo_eps(R.dtype))

diag_loo = control_flow.cond(Y_rank < num_samples, _diag_exact, _diag_estimate)
return func.vmap(unflatten)(diag_loo.T)

return integrand


def leave_one_out_xnystrace(
*,
nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None,
Expand Down Expand Up @@ -378,6 +462,77 @@ def matvec_flat(v):
return integrand


def leave_one_out_xnysdiag(
*, nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None
) -> Callable:
"""Construct an integrand for estimating the diagonal of a positive semi-definite operator using the XNysDiag algorithm (Epperly et al. 2025).

Parameters
----------
nystrom
A callable with signature ``(matvec_flat, Omega) -> (nystrom_left, downdate, shift)``.
Usually the return value of
[`nystrom_shifted_cholesky`][matfree.stochtrace.nystrom_shifted_cholesky]
or [`nystrom_eigh`][matfree.stochtrace.nystrom_eigh] (default: `nystrom_eigh`).

Returns
-------
integrand
An integrand function compatible with
[estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out]
whose input has the signature ``(matvec, samples, *params)`` and whose output is a
pytree with each leaf having shape ``(num_samples, n_k)``, giving one diagonal
estimate per leave-one-out sample.
The `matvec` must be a positive semi-definite operator.

Notes
-----
The number of samples must be less than or equal to the dimension of the operator.
The output diagonal is real-valued (PSD operators have real diagonal).

The sum of the diagonal estimate over all entries equals the corresponding
[leave_one_out_xnystrace][matfree.stochtrace.leave_one_out_xnystrace]
trace estimate exactly for the same operator and samples (when ``apply_resphering=False``).

References
----------
- Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
"""
if nystrom is None:
nystrom = nystrom_eigh()

def integrand(matvec, samples, *params):
sample0 = tree.tree_map(lambda s: s[0], samples)
_, unflatten = tree.ravel_pytree(sample0)

Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
n, num_samples = Omega.shape

if num_samples > n:
raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

def matvec_flat(v):
return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

if num_samples == n:
B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
diag_B = linalg.diagonal(B_mat)
diag_all = np.ones((num_samples, 1), dtype=diag_B.dtype) * diag_B
return tree.tree_map(lambda x: x.real, func.vmap(unflatten)(diag_all))

F, Z, shift = nystrom(matvec_flat, Omega)
Z_vd_Omega = func.vmap(linalg.vdot, in_axes=1)(Z, Omega)

diag_B_hat = np.sum(linalg.abs2(F), axis=1) - shift
diag_B_hat_loo = diag_B_hat[:, None] - linalg.abs2(Z)
diag_res_loo = Z * Z_vd_Omega * Omega.conj()
diag_loo = (diag_B_hat_loo + diag_res_loo).T
return tree.tree_map(lambda x: x.real, func.vmap(unflatten)(diag_loo))

return integrand


def _qr_leave_one_out_factor(R):
r"""Compute the downdate factor for a QR decomposition leaving out a single column.

Expand Down
41 changes: 25 additions & 16 deletions matfree/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,32 @@
from matfree.backend import linalg, np, prng, tree


def symmetric_matrix_from_eigenvalues(eigvals, /):
"""Generate a symmetric matrix with prescribed eigenvalues."""
(n,) = eigvals.shape

# Need _some_ matrix to start with
A = np.reshape(np.arange(1.0, n**2 + 1.0), (n, n))
A = A / linalg.matrix_norm(A, which="fro")
X = A.T @ A + np.eye(n)

# QR decompose. We need the orthogonal matrix.
# Treat Q as a stack of eigenvectors.
Q, _R = linalg.qr_reduced(X)
def hermitian_matrix_from_eigenvalues(eigvals, /, key, *, dtype=None):
"""Generate a Hermitian matrix with prescribed real eigenvalues.

# Treat Q as eigenvectors, and 'D' as eigenvalues.
# return Q D Q.T.
# This matrix will be dense, symmetric, and have a given spectrum.
return Q @ (eigvals[:, None] * Q.T)
For real dtype the result is symmetric; for complex dtype it is Hermitian.
"""
(n,) = eigvals.shape
if dtype is None:
dtype = eigvals.dtype
eigvals = eigvals.real
Q, _ = linalg.qr_reduced(prng.normal(key, shape=(n, n), dtype=dtype))
return (Q * eigvals) @ Q.T.conj()


def hermitian_matrix_eigvals_decaying(n, /, key, *, dtype=None):
"""Hermitian matrix whose eigenvalues decay geometrically (0.7^k)."""
eigvals = 0.7 ** np.arange(n)
Comment on lines +19 to +21

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about including 0.7 in the arguments? We can also do this in the future as soon as we need to change it, but while we're at it...

Suggested change
def hermitian_matrix_eigvals_decaying(n, /, key, *, dtype=None):
"""Hermitian matrix whose eigenvalues decay geometrically (0.7^k)."""
eigvals = 0.7 ** np.arange(n)
def hermitian_matrix_eigvals_decaying(n, /, key, *, base=0.7, dtype=None):
"""Hermitian matrix whose eigenvalues decay geometrically (x^k)."""
eigvals = x ** np.arange(n)

rdtype = np.zeros((), dtype=dtype).real.dtype
return hermitian_matrix_from_eigenvalues(eigvals, key, dtype=rdtype)


def hermitian_matrix_eigvals_step(
n, /, key, *, num_flat=50, drop_value=1e-3, dtype=None
):
"""Hermitian matrix whose eigenvalues are flat then drop sharply."""
eigvals = np.concatenate([np.ones(num_flat), np.ones(n - num_flat) * drop_value])
return hermitian_matrix_from_eigenvalues(eigvals, key, dtype=dtype)


def asymmetric_matrix_from_singular_values(vals, /, nrows, ncols):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bounds/test_bai_golub.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Tests for Bai and Golub's log-determinant bounds."""

from matfree import bounds, test_util
from matfree.backend import linalg, np
from matfree.backend import linalg, np, prng


def test_logdet():
"""Test that Bai and Golub's log-determinant bound is correct."""
# Set up a test-problem.
eigvals = np.asarray([1.0, 2.0, 3.0, 4.0])
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))

# Compute the bound
trace = linalg.trace(matrix)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_decomp/test_tridiag_sym.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Test the tri-diagonalisation."""

from matfree import decomp, test_util
from matfree.backend import linalg, np, testing
from matfree.backend import linalg, np, prng, testing


@testing.parametrize("reortho", ["full", "none"])
@testing.parametrize("ndim", [12])
def test_full_rank_reconstruction_is_exact(reortho, ndim):
# Set up a test-matrix and an initial vector
eigvals = np.arange(1.0, 2.0, step=1 / ndim)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))
vector = np.flip(np.arange(1.0, 1.0 + len(eigvals)))

def matvec(s, p):
Expand Down Expand Up @@ -46,7 +46,7 @@ def matvec(s, p):
def test_mid_rank_reconstruction_satisfies_decomposition(ndim, num_matvecs, reortho):
# Set up a test-matrix and an initial vector
eigvals = np.arange(1.0, 2.0, step=1 / ndim)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))
vector = np.flip(np.arange(1.0, 1.0 + len(eigvals)))

def matvec(s, p):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_decomp/test_tridiag_sym_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
def test_adjoint_vjp_matches_jax_vjp(reortho, n, krylov_num_matvecs, seed):
"""Test that the custom VJP yields the same output as autodiff."""
# Set up a test-matrix
eigvals = prng.uniform(prng.prng_key(seed), shape=(n,)) + 1.0
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
key_eig, key_mat = prng.split(prng.prng_key(seed))
eigvals = prng.uniform(key_eig, shape=(n,)) + 1.0
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, key_mat)
params = _sym(matrix)

def matvec(s, p):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_eig/test_eigh_partial.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Tests for eigenvalue functionality."""

from matfree import decomp, eig, test_util
from matfree.backend import linalg, np, testing
from matfree.backend import linalg, np, prng, testing


@testing.parametrize("nrows", [10])
def test_equal_to_linalg_eigh(nrows):
eigvals = np.arange(1.0, 1.0 + nrows)
A = test_util.symmetric_matrix_from_eigenvalues(eigvals)
A = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))
v0 = np.ones((nrows,))
num_matvecs = nrows

Expand All @@ -24,7 +24,7 @@ def test_equal_to_linalg_eigh(nrows):
@testing.parametrize("num_matvecs", [8, 4, 0])
def test_shapes_as_expected_vector(nrows, num_matvecs):
eigvals = np.arange(1.0, 1.0 + nrows)
A = test_util.symmetric_matrix_from_eigenvalues(eigvals)
A = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))
v0 = np.ones((nrows,))

tridiag_sym = decomp.tridiag_sym(num_matvecs, reortho="full")
Expand All @@ -38,7 +38,7 @@ def test_shapes_as_expected_vector(nrows, num_matvecs):
@testing.parametrize("num_matvecs", [0, 2, 3])
def test_shapes_as_expected_lists_tuples(nrows, num_matvecs):
eigvals = np.arange(1.0, 1.0 + nrows)
A = test_util.symmetric_matrix_from_eigenvalues(eigvals)
A = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))
v0 = np.ones((nrows,))

# Map Pytrees to Pytrees
Expand Down
2 changes: 1 addition & 1 deletion tests/test_funm/test_funm_chebyshev.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def fun(x):
v = prng.normal(prng.prng_key(2), shape=(n,))

eigvals = np.linspace(-1 + 0.01, 1 - 0.01, num=n)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))

# Compute the solution
eigvals, eigvecs = linalg.eigh(matrix)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_funm/test_funm_lanczos_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def fun(x):
v = prng.normal(prng.prng_key(2), shape=(n,))

eigvals = np.linspace(0.01, 0.99, num=n)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
matrix = test_util.hermitian_matrix_from_eigenvalues(eigvals, prng.prng_key(1))

# Compute the solution
eigvals, eigvecs = linalg.eigh(matrix)
Expand Down
Loading
Loading