Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ae1e74d
fix(stochtrace): Support pytrees in LOO estimators
sethaxen Jun 8, 2026
b910a7f
feat(backend): Add np.array_max
sethaxen Jun 8, 2026
146632e
feat(backend): Add np.median
sethaxen Jun 8, 2026
e4f58b4
feat(stochtrace): Add XDiag
sethaxen Jun 9, 2026
5528550
feat(stochtrace): Add XNysDiag
sethaxen Jun 9, 2026
53f4f9c
feat(backend)!: Generate Hermitian matrix with random eigenvectors
sethaxen Jun 12, 2026
490df1d
test: Use hermitian_matrix_from_eigenvalues in old tests
sethaxen Jun 12, 2026
e51123a
feat(backend): Remove old symmetric generator
sethaxen Jun 12, 2026
c5fda52
test(stochtrace): Use Hermitian generator in LOO diag tests
sethaxen Jun 12, 2026
aa7ac92
test(stochtrace): Use Hermitian generator
sethaxen Jun 12, 2026
161ae92
test(stochtrace): Add helper functions for experiments
sethaxen Jun 12, 2026
3822f36
test(stochtrace): Use helpers in existing x(nys)trace tests
sethaxen Jun 12, 2026
973b918
test(stochtrace): Rename helpers file to conftest
sethaxen Jun 12, 2026
e7da6c8
test(stochtrace): Move shared nystrom fixture to conftest
sethaxen Jun 12, 2026
be3c045
test(stochtrace): Refactor xdiag tests to use cases
sethaxen Jun 12, 2026
b774331
test(stochtrace): Refactor xnysdiag tests to use cases
sethaxen Jun 12, 2026
5698dc3
test(stochtrace): Remove xnysdiag pytest tests
sethaxen Jun 12, 2026
dadf7c3
refactor(test_util): Move eigenvalue helpers from conftest to test_util
sethaxen Jun 12, 2026
80b7dd8
refactor(test_util): Add hermitian_matrix_eigvals_decaying/step helpers
sethaxen Jun 12, 2026
3e93650
test(stochtrace): Replace nystrom fixture with direct parametrize
sethaxen Jun 12, 2026
27e7445
test(stochtrace): Remove unnecessary __init__.py
sethaxen Jun 12, 2026
51106e1
Merge remote-tracking branch 'upstream/main' into xdiag_xnysdiag
sethaxen Jun 22, 2026
58bbee3
test(stochtrace): loosen atol in nystrom non-essential-vectors test
sethaxen Jun 22, 2026
df8d08a
refactor(test_util): make decay base configurable in hermitian_matrix…
sethaxen Jun 22, 2026
de2ad3e
fix(test_util): preserve complex dtype in hermitian_matrix_eigvals_de…
sethaxen Jun 22, 2026
8cdeef2
test(stochtrace): restructure xtrace tests to match xdiag pattern
sethaxen Jun 22, 2026
f383d42
test(stochtrace): restructure xnystrace tests to match xnysdiag pattern
sethaxen Jun 22, 2026
a8d27ec
test(stochtrace): make xnystrace docstring standalone
sethaxen Jun 22, 2026
ca3302e
test(stochtrace): cover heterogeneous pytrees in xtrace/xnystrace exa…
sethaxen Jun 22, 2026
fe8bd2e
test(stochtrace): cover heterogeneous pytrees in xdiag/xnysdiag exact…
sethaxen Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def nanmean(x, /, axis=None):
return jnp.nanmean(x, axis)


def median(x, /, axis=None):
return jnp.median(x, axis=axis)


def array_max(x, /, axis=None):
return jnp.max(x, axis=axis)


def elementwise_max(a, b, /):
return jnp.maximum(a, b)

Expand Down
161 changes: 158 additions & 3 deletions matfree/stochtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def estimator_leave_one_out(integrand: Callable, /, sampler: Callable) -> Callab

def estimate(matvec, key, *parameters):
samples = sampler(key)
return np.mean(integrand(matvec, samples, *parameters), axis=0)
Qs = integrand(matvec, samples, *parameters)
return tree.tree_map(lambda s: np.mean(s, axis=0), Qs)

return estimate

Expand Down Expand Up @@ -146,8 +147,9 @@ def estimator_leave_one_out_mean_and_sem(
def estimate(matvec, key, *parameters):
samples = sampler(key)
Qs = integrand(matvec, samples, *parameters)
mean = np.mean(Qs, axis=0)
sem = np.std(Qs, axis=0) / np.sqrt(Qs.shape[0])
n_samples = tree.tree_leaves(Qs)[0].shape[0]
mean = tree.tree_map(lambda s: np.mean(s, axis=0), Qs)
sem = tree.tree_map(lambda s: np.std(s, axis=0) / np.sqrt(n_samples), Qs)
return mean, sem

return estimate
Expand Down Expand Up @@ -266,6 +268,88 @@ def _trace_estimate():
return integrand


def leave_one_out_xdiag() -> Callable:
"""Construct an integrand for estimating the diagonal using the XDiag algorithm (Epperly et al. 2024).

Returns
-------
integrand
An integrand function compatible with
[estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out]
whose input has the signature ``(matvec, samples, *params)`` and whose output is a
pytree with each leaf having shape ``(num_samples, n_k)``, giving one diagonal
estimate per leave-one-out sample.

Notes
-----
The number of samples must be less than or equal to the dimension of the operator.

The sum of the diagonal estimate over all entries is an unbiased estimate of the trace but
generally has higher variance than the estimate produced by
[leave_one_out_xnystrace][matfree.stochtrace.leave_one_out_xnystrace].

References
----------
- Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation.
SIAM J Matrix Anal A. 45.1: 1-23.
doi: [10.1137/23M1548323](https://doi.org/10.1137/23M1548323)
arXiv: [2301.07825](https://arxiv.org/abs/2301.07825)
- Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
"""

def integrand(matvec, samples, *params):
sample0 = tree.tree_map(lambda s: s[0], samples)
_, unflatten = tree.ravel_pytree(sample0)

Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
n, num_samples = Omega.shape

if num_samples > n:
raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

def matvec_flat(v):
return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

if 2 * num_samples >= n:
B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
diag_B = linalg.diagonal(B_mat)
return func.vmap(unflatten)(
np.ones((num_samples, 1), dtype=diag_B.dtype) * diag_B
)

matvec_flat_transpose = func.linear_transpose(matvec_flat, Omega[:, 0])

def matvec_flat_adjoint(v):
(result,) = matvec_flat_transpose(v.conj())
return result.conj()

Y = func.vmap(matvec_flat, in_axes=-1, out_axes=-1)(Omega)
Q, R = linalg.qr_reduced(Y)
Z = func.vmap(matvec_flat_adjoint, in_axes=-1, out_axes=-1)(Q)

def _diag_exact():
diag_B = func.vmap(linalg.vdot, in_axes=0)(Z, Q)
return np.ones(num_samples, dtype=diag_B.dtype) * diag_B[:, None]

def _diag_estimate():
S = _qr_leave_one_out_factor(R)
QS = Q @ S
S_vd_R = func.vmap(linalg.vdot, in_axes=1)(S, R)

diag_B_hat = func.vmap(linalg.vdot, in_axes=0)(Z, Q)
diag_B_hat_loo = diag_B_hat[:, None] - QS * (Z @ S).conj()
diag_residual_loo = QS * S_vd_R * Omega.conj()
return diag_B_hat_loo + diag_residual_loo

Y_rank = np.sum(np.abs(linalg.diagonal(R)) > np.finfo_eps(R.dtype))

diag_loo = control_flow.cond(Y_rank < num_samples, _diag_exact, _diag_estimate)
return func.vmap(unflatten)(diag_loo.T)

return integrand


def leave_one_out_xnystrace(
*,
nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None,
Expand Down Expand Up @@ -378,6 +462,77 @@ def matvec_flat(v):
return integrand


def leave_one_out_xnysdiag(
*, nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None
) -> Callable:
"""Construct an integrand for estimating the diagonal of a positive semi-definite operator using the XNysDiag algorithm (Epperly et al. 2025).

Parameters
----------
nystrom
A callable with signature ``(matvec_flat, Omega) -> (nystrom_left, downdate, shift)``.
Usually the return value of
[`nystrom_shifted_cholesky`][matfree.stochtrace.nystrom_shifted_cholesky]
or [`nystrom_eigh`][matfree.stochtrace.nystrom_eigh] (default: `nystrom_eigh`).

Returns
-------
integrand
An integrand function compatible with
[estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out]
whose input has the signature ``(matvec, samples, *params)`` and whose output is a
pytree with each leaf having shape ``(num_samples, n_k)``, giving one diagonal
estimate per leave-one-out sample.
The `matvec` must be a positive semi-definite operator.

Notes
-----
The number of samples must be less than or equal to the dimension of the operator.
The output diagonal is real-valued (PSD operators have real diagonal).

The sum of the diagonal estimate over all entries equals the corresponding
[leave_one_out_xnystrace][matfree.stochtrace.leave_one_out_xnystrace]
trace estimate exactly for the same operator and samples (when ``apply_resphering=False``).

References
----------
- Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
"""
if nystrom is None:
nystrom = nystrom_eigh()

def integrand(matvec, samples, *params):
sample0 = tree.tree_map(lambda s: s[0], samples)
_, unflatten = tree.ravel_pytree(sample0)

Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
n, num_samples = Omega.shape

if num_samples > n:
raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

def matvec_flat(v):
return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

if num_samples == n:
B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
diag_B = linalg.diagonal(B_mat)
diag_all = np.ones((num_samples, 1), dtype=diag_B.dtype) * diag_B
return tree.tree_map(lambda x: x.real, func.vmap(unflatten)(diag_all))

F, Z, shift = nystrom(matvec_flat, Omega)
Z_vd_Omega = func.vmap(linalg.vdot, in_axes=1)(Z, Omega)

diag_B_hat = np.sum(linalg.abs2(F), axis=1) - shift
diag_B_hat_loo = diag_B_hat[:, None] - linalg.abs2(Z)
diag_res_loo = Z * Z_vd_Omega * Omega.conj()
diag_loo = (diag_B_hat_loo + diag_res_loo).T
return tree.tree_map(lambda x: x.real, func.vmap(unflatten)(diag_loo))

return integrand


def _qr_leave_one_out_factor(R):
r"""Compute the downdate factor for a QR decomposition leaving out a single column.

Expand Down
44 changes: 44 additions & 0 deletions tests/test_stochtrace/test_leave_one_out/test_estimator_loo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Test estimator_leave_one_out and estimator_leave_one_out_mean_and_sem."""

from matfree import stochtrace
from matfree.backend import np, prng


def test_estimator_leave_one_out_pytree_output():
Comment thread
pnkraemer marked this conversation as resolved.
"""Assert that estimator_leave_one_out handles pytree-returning integrands."""
n, num_samples = 7, 5
key = prng.prng_key(1)

def pytree_integrand(_matvec, samples, *_params):
# samples: array of shape (num_samples, n)
return {"a": samples[:, :3], "b": samples[:, 3:]}

sampler = stochtrace.sampler_normal(np.ones(n), num=num_samples)
estimate = stochtrace.estimator_leave_one_out(pytree_integrand, sampler)
result = estimate(lambda v: v, key)

assert isinstance(result, dict)
assert result["a"].shape == (3,)
assert result["b"].shape == (4,)


def test_estimator_leave_one_out_mean_and_sem_pytree_output():
"""Assert that estimator_leave_one_out_mean_and_sem handles pytree-returning integrands."""
n, num_samples = 7, 5
key = prng.prng_key(1)

def pytree_integrand(_matvec, samples, *_params):
return {"a": samples[:, :3], "b": samples[:, 3:]}

sampler = stochtrace.sampler_normal(np.ones(n), num=num_samples)
estimate = stochtrace.estimator_leave_one_out_mean_and_sem(
pytree_integrand, sampler
)
mean, sem = estimate(lambda v: v, key)

assert isinstance(mean, dict)
assert mean["a"].shape == (3,)
assert mean["b"].shape == (4,)
assert isinstance(sem, dict)
assert sem["a"].shape == (3,)
assert sem["b"].shape == (4,)
149 changes: 149 additions & 0 deletions tests/test_stochtrace/test_leave_one_out/test_xdiag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Tests for leave_one_out_xdiag."""

from matfree import stochtrace, test_util
from matfree.backend import func, linalg, np, prng, testing


@testing.parametrize("n", [10, 20])
def test_xdiag_error_num_samples_more_than_dimension(n):
"""Assert that num_samples greater than the dimension raises a ValueError."""
key = prng.prng_key(1)
A = np.eye(n)

def matvec(v, A):
return A @ v

integrand = stochtrace.leave_one_out_xdiag()
sampler = stochtrace.sampler_normal(np.ones(n), num=n + 1)
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)
message = f"Number of samples num={n + 1} exceeds the acceptable range."
message = f"{message} Expected: 1 <= num <= {n}."
with testing.raises(ValueError, match=message):
estimate(matvec, key, A)


@testing.parametrize("n, num_samples", [(10, 5), (21, 11)])
@testing.parametrize(
"dtype_op, dtype_sample",
[(float, float), (complex, complex), (complex, float), (float, complex)],
)
def test_xdiag_exact_when_num_samples_more_than_half_dimension(
n, num_samples, dtype_op, dtype_sample
):
"""Assert exact diagonal computation when num_samples is large."""
key = prng.prng_key(1)
A = np.tril(np.ones((n, n), dtype=dtype_op))
expected = linalg.diagonal(A)

def matvec(v, A):
return A @ v

sampler = stochtrace.sampler_normal(np.ones(n, dtype=dtype_sample), num=num_samples)
integrand = stochtrace.leave_one_out_xdiag()
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)
test_util.assert_allclose(estimate(matvec, key, A), expected)


@testing.parametrize("n, rank", [(50, 10), (100, 30)])
@testing.parametrize("dtype", [float, complex])
def test_xdiag_low_rank_operator(n, rank, dtype):
"""Assert that the diagonal of a low-rank operator is computed exactly."""
key = prng.prng_key(5)
key_mat1, key_mat2, key = prng.split(key, 3)
A = prng.normal(key_mat1, shape=(n, rank), dtype=dtype)
B = prng.normal(key_mat2, shape=(rank, n), dtype=dtype)
expected = linalg.diagonal(A @ B)

def matvec(v, A, B):
return A @ (B @ v)

sampler = stochtrace.sampler_normal(np.ones(n, dtype=dtype), num=rank + 1)
integrand = stochtrace.leave_one_out_xdiag()
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)
test_util.assert_allclose(estimate(matvec, key, A, B), expected)


@testing.parametrize("dtype", [float, complex])
def test_xdiag_fast_spectral_decay(dtype):
"""Assert that a diagonal with fast spectral decay is estimated accurately.

Reproduces the setup of the experiment 'exp' from Fig 16.1 of Ethan Epperly's thesis.
"""
rdtype = np.abs(dtype(0)).dtype
n = 1000
num_rep = 10
key = prng.prng_key(1)
key_mat, key = prng.split(key)
U = linalg.qr_reduced(prng.normal(key_mat, shape=(n, n), dtype=dtype))[0]
d = 0.7 ** np.arange(n).astype(rdtype)
expected = linalg.abs2(U) @ d # diag(U diag(d) U^H)_j = sum_k d_k |U_jk|^2

sampler = stochtrace.sampler_signs(np.ones(n, dtype=dtype), num=35)
integrand = stochtrace.leave_one_out_xdiag()
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)

def matvec(v, d, U):
return U @ (d * (U.T.conj() @ v))

key_ests = prng.split(key, num_rep)
received = func.vmap(lambda key: estimate(matvec, key, d, U))(key_ests)
max_abs_err = np.array_max(np.abs(received - expected), axis=1)
max_rel_err = max_abs_err / np.array_max(np.abs(expected))
assert float(np.median(max_rel_err)) < 1e-2


@testing.parametrize("dtype", [float, complex])
def test_xdiag_large_spectral_drop(dtype):
"""Assert that a diagonal with a large spectral drop is estimated accurately.

Reproduces the setup of the experiment 'step' from Fig 16.1 of Ethan Epperly's thesis.
"""
rdtype = np.abs(dtype(0)).dtype
n = 1000
m = 50
num_rep = 10
key = prng.prng_key(4)
key_mat, key = prng.split(key)
U = linalg.qr_reduced(prng.normal(key_mat, shape=(n, n), dtype=dtype))[0]
large_eigenvalues = np.ones(m, dtype=rdtype)
small_eigenvalues = np.ones(n - m, dtype=rdtype) * 1e-3
d = np.concatenate([large_eigenvalues, small_eigenvalues])
expected = linalg.abs2(U) @ d

sampler = stochtrace.sampler_signs(np.ones(n, dtype=dtype), num=m + 10)
integrand = stochtrace.leave_one_out_xdiag()
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)

def matvec(v, d, U):
return U @ (d * (U.T.conj() @ v))

key_ests = prng.split(key, num_rep)
received = func.vmap(lambda key: estimate(matvec, key, d, U))(key_ests)
max_abs_err = np.array_max(np.abs(received - expected), axis=1)
norm_expected = np.array_max(np.abs(expected))
median_max_rel_err = np.median(max_abs_err / norm_expected)
assert float(median_max_rel_err) < 5e-2


def test_xdiag_pytrees_supported():
"""Assert that the XDiag algorithm supports pytrees."""
n1 = 100
n2 = 50
key_mat1, key_mat2, key_est = prng.split(prng.prng_key(1), 3)
A = prng.normal(key_mat1, shape=(n1, n1))
B = prng.normal(key_mat2, shape=(n2, n2))

def matvec(v, A, B):
return {"fx": A @ v["fx"], "fy": B @ v["fy"]}

integrand = stochtrace.leave_one_out_xdiag()
x_like = {"fx": np.ones(n1), "fy": np.ones(n2)}
sampler = stochtrace.sampler_sphere(x_like, num=n1 + n2 - 1)
estimate = stochtrace.estimator_leave_one_out(integrand, sampler)

received = estimate(matvec, key_est, A, B)
expected = {"fx": linalg.diagonal(A), "fy": linalg.diagonal(B)}
assert isinstance(received, dict)
assert set(received.keys()) == {"fx", "fy"}
assert np.allclose(received["fx"], expected["fx"], rtol=1e-4)
assert np.allclose(received["fy"], expected["fy"], rtol=1e-4)
Loading
Loading