-
Notifications
You must be signed in to change notification settings - Fork 7
feat(stochtrace): Add LOO-based estimators of the diagonal #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+808
−252
Merged
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
ae1e74d
fix(stochtrace): Support pytrees in LOO estimators
sethaxen b910a7f
feat(backend): Add np.array_max
sethaxen 146632e
feat(backend): Add np.median
sethaxen e4f58b4
feat(stochtrace): Add XDiag
sethaxen 5528550
feat(stochtrace): Add XNysDiag
sethaxen 53f4f9c
feat(backend)!: Generate Hermitian matrix with random eigenvectors
sethaxen 490df1d
test: Use hermitian_matrix_from_eigenvalues in old tests
sethaxen e51123a
feat(backend): Remove old symmetric generator
sethaxen c5fda52
test(stochtrace): Use Hermitian generator in LOO diag tests
sethaxen aa7ac92
test(stochtrace): Use Hermitian generator
sethaxen 161ae92
test(stochtrace): Add helper functions for experiments
sethaxen 3822f36
test(stochtrace): Use helpers in existing x(nys)trace tests
sethaxen 973b918
test(stochtrace): Rename helpers file to conftest
sethaxen e7da6c8
test(stochtrace): Move shared nystrom fixture to conftest
sethaxen be3c045
test(stochtrace): Refactor xdiag tests to use cases
sethaxen b774331
test(stochtrace): Refactor xnysdiag tests to use cases
sethaxen 5698dc3
test(stochtrace): Remove xnysdiag pytest tests
sethaxen dadf7c3
refactor(test_util): Move eigenvalue helpers from conftest to test_util
sethaxen 80b7dd8
refactor(test_util): Add hermitian_matrix_eigvals_decaying/step helpers
sethaxen 3e93650
test(stochtrace): Replace nystrom fixture with direct parametrize
sethaxen 27e7445
test(stochtrace): Remove unnecessary __init__.py
sethaxen 51106e1
Merge remote-tracking branch 'upstream/main' into xdiag_xnysdiag
sethaxen 58bbee3
test(stochtrace): loosen atol in nystrom non-essential-vectors test
sethaxen df8d08a
refactor(test_util): make decay base configurable in hermitian_matrix…
sethaxen de2ad3e
fix(test_util): preserve complex dtype in hermitian_matrix_eigvals_de…
sethaxen 8cdeef2
test(stochtrace): restructure xtrace tests to match xdiag pattern
sethaxen f383d42
test(stochtrace): restructure xnystrace tests to match xnysdiag pattern
sethaxen a8d27ec
test(stochtrace): make xnystrace docstring standalone
sethaxen ca3302e
test(stochtrace): cover heterogeneous pytrees in xtrace/xnystrace exa…
sethaxen fe8bd2e
test(stochtrace): cover heterogeneous pytrees in xdiag/xnysdiag exact…
sethaxen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
tests/test_stochtrace/test_leave_one_out/test_estimator_loo.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| """Test estimator_leave_one_out and estimator_leave_one_out_mean_and_sem.""" | ||
|
|
||
| from matfree import stochtrace | ||
| from matfree.backend import np, prng | ||
|
|
||
|
|
||
| def test_estimator_leave_one_out_pytree_output(): | ||
| """Assert that estimator_leave_one_out handles pytree-returning integrands.""" | ||
| n, num_samples = 7, 5 | ||
| key = prng.prng_key(1) | ||
|
|
||
| def pytree_integrand(_matvec, samples, *_params): | ||
| # samples: array of shape (num_samples, n) | ||
| return {"a": samples[:, :3], "b": samples[:, 3:]} | ||
|
|
||
| sampler = stochtrace.sampler_normal(np.ones(n), num=num_samples) | ||
| estimate = stochtrace.estimator_leave_one_out(pytree_integrand, sampler) | ||
| result = estimate(lambda v: v, key) | ||
|
|
||
| assert isinstance(result, dict) | ||
| assert result["a"].shape == (3,) | ||
| assert result["b"].shape == (4,) | ||
|
|
||
|
|
||
| def test_estimator_leave_one_out_mean_and_sem_pytree_output(): | ||
| """Assert that estimator_leave_one_out_mean_and_sem handles pytree-returning integrands.""" | ||
| n, num_samples = 7, 5 | ||
| key = prng.prng_key(1) | ||
|
|
||
| def pytree_integrand(_matvec, samples, *_params): | ||
| return {"a": samples[:, :3], "b": samples[:, 3:]} | ||
|
|
||
| sampler = stochtrace.sampler_normal(np.ones(n), num=num_samples) | ||
| estimate = stochtrace.estimator_leave_one_out_mean_and_sem( | ||
| pytree_integrand, sampler | ||
| ) | ||
| mean, sem = estimate(lambda v: v, key) | ||
|
|
||
| assert isinstance(mean, dict) | ||
| assert mean["a"].shape == (3,) | ||
| assert mean["b"].shape == (4,) | ||
| assert isinstance(sem, dict) | ||
| assert sem["a"].shape == (3,) | ||
| assert sem["b"].shape == (4,) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| """Tests for leave_one_out_xdiag.""" | ||
|
|
||
| from matfree import stochtrace, test_util | ||
| from matfree.backend import func, linalg, np, prng, testing | ||
|
|
||
|
|
||
| @testing.parametrize("n", [10, 20]) | ||
| def test_xdiag_error_num_samples_more_than_dimension(n): | ||
| """Assert that num_samples greater than the dimension raises a ValueError.""" | ||
| key = prng.prng_key(1) | ||
| A = np.eye(n) | ||
|
|
||
| def matvec(v, A): | ||
| return A @ v | ||
|
|
||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| sampler = stochtrace.sampler_normal(np.ones(n), num=n + 1) | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
| message = f"Number of samples num={n + 1} exceeds the acceptable range." | ||
| message = f"{message} Expected: 1 <= num <= {n}." | ||
| with testing.raises(ValueError, match=message): | ||
| estimate(matvec, key, A) | ||
|
|
||
|
|
||
| @testing.parametrize("n, num_samples", [(10, 5), (21, 11)]) | ||
| @testing.parametrize( | ||
| "dtype_op, dtype_sample", | ||
| [(float, float), (complex, complex), (complex, float), (float, complex)], | ||
| ) | ||
| def test_xdiag_exact_when_num_samples_more_than_half_dimension( | ||
| n, num_samples, dtype_op, dtype_sample | ||
| ): | ||
| """Assert exact diagonal computation when num_samples is large.""" | ||
| key = prng.prng_key(1) | ||
| A = np.tril(np.ones((n, n), dtype=dtype_op)) | ||
| expected = linalg.diagonal(A) | ||
|
|
||
| def matvec(v, A): | ||
| return A @ v | ||
|
|
||
| sampler = stochtrace.sampler_normal(np.ones(n, dtype=dtype_sample), num=num_samples) | ||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
| test_util.assert_allclose(estimate(matvec, key, A), expected) | ||
|
|
||
|
|
||
| @testing.parametrize("n, rank", [(50, 10), (100, 30)]) | ||
| @testing.parametrize("dtype", [float, complex]) | ||
| def test_xdiag_low_rank_operator(n, rank, dtype): | ||
| """Assert that the diagonal of a low-rank operator is computed exactly.""" | ||
| key = prng.prng_key(5) | ||
| key_mat1, key_mat2, key = prng.split(key, 3) | ||
| A = prng.normal(key_mat1, shape=(n, rank), dtype=dtype) | ||
| B = prng.normal(key_mat2, shape=(rank, n), dtype=dtype) | ||
| expected = linalg.diagonal(A @ B) | ||
|
|
||
| def matvec(v, A, B): | ||
| return A @ (B @ v) | ||
|
|
||
| sampler = stochtrace.sampler_normal(np.ones(n, dtype=dtype), num=rank + 1) | ||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
| test_util.assert_allclose(estimate(matvec, key, A, B), expected) | ||
|
|
||
|
|
||
| @testing.parametrize("dtype", [float, complex]) | ||
| def test_xdiag_fast_spectral_decay(dtype): | ||
| """Assert that a diagonal with fast spectral decay is estimated accurately. | ||
|
|
||
| Reproduces the setup of the experiment 'exp' from Fig 16.1 of Ethan Epperly's thesis. | ||
| """ | ||
| rdtype = np.abs(dtype(0)).dtype | ||
| n = 1000 | ||
| num_rep = 10 | ||
| key = prng.prng_key(1) | ||
| key_mat, key = prng.split(key) | ||
| U = linalg.qr_reduced(prng.normal(key_mat, shape=(n, n), dtype=dtype))[0] | ||
| d = 0.7 ** np.arange(n).astype(rdtype) | ||
| expected = linalg.abs2(U) @ d # diag(U diag(d) U^H)_j = sum_k d_k |U_jk|^2 | ||
|
|
||
| sampler = stochtrace.sampler_signs(np.ones(n, dtype=dtype), num=35) | ||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
|
|
||
| def matvec(v, d, U): | ||
| return U @ (d * (U.T.conj() @ v)) | ||
|
|
||
| key_ests = prng.split(key, num_rep) | ||
| received = func.vmap(lambda key: estimate(matvec, key, d, U))(key_ests) | ||
| max_abs_err = np.array_max(np.abs(received - expected), axis=1) | ||
| max_rel_err = max_abs_err / np.array_max(np.abs(expected)) | ||
| assert float(np.median(max_rel_err)) < 1e-2 | ||
|
|
||
|
|
||
| @testing.parametrize("dtype", [float, complex]) | ||
| def test_xdiag_large_spectral_drop(dtype): | ||
| """Assert that a diagonal with a large spectral drop is estimated accurately. | ||
|
|
||
| Reproduces the setup of the experiment 'step' from Fig 16.1 of Ethan Epperly's thesis. | ||
| """ | ||
| rdtype = np.abs(dtype(0)).dtype | ||
| n = 1000 | ||
| m = 50 | ||
| num_rep = 10 | ||
| key = prng.prng_key(4) | ||
| key_mat, key = prng.split(key) | ||
| U = linalg.qr_reduced(prng.normal(key_mat, shape=(n, n), dtype=dtype))[0] | ||
| large_eigenvalues = np.ones(m, dtype=rdtype) | ||
| small_eigenvalues = np.ones(n - m, dtype=rdtype) * 1e-3 | ||
| d = np.concatenate([large_eigenvalues, small_eigenvalues]) | ||
| expected = linalg.abs2(U) @ d | ||
|
|
||
| sampler = stochtrace.sampler_signs(np.ones(n, dtype=dtype), num=m + 10) | ||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
|
|
||
| def matvec(v, d, U): | ||
| return U @ (d * (U.T.conj() @ v)) | ||
|
|
||
| key_ests = prng.split(key, num_rep) | ||
| received = func.vmap(lambda key: estimate(matvec, key, d, U))(key_ests) | ||
| max_abs_err = np.array_max(np.abs(received - expected), axis=1) | ||
| norm_expected = np.array_max(np.abs(expected)) | ||
| median_max_rel_err = np.median(max_abs_err / norm_expected) | ||
| assert float(median_max_rel_err) < 5e-2 | ||
|
|
||
|
|
||
| def test_xdiag_pytrees_supported(): | ||
| """Assert that the XDiag algorithm supports pytrees.""" | ||
| n1 = 100 | ||
| n2 = 50 | ||
| key_mat1, key_mat2, key_est = prng.split(prng.prng_key(1), 3) | ||
| A = prng.normal(key_mat1, shape=(n1, n1)) | ||
| B = prng.normal(key_mat2, shape=(n2, n2)) | ||
|
|
||
| def matvec(v, A, B): | ||
| return {"fx": A @ v["fx"], "fy": B @ v["fy"]} | ||
|
|
||
| integrand = stochtrace.leave_one_out_xdiag() | ||
| x_like = {"fx": np.ones(n1), "fy": np.ones(n2)} | ||
| sampler = stochtrace.sampler_sphere(x_like, num=n1 + n2 - 1) | ||
| estimate = stochtrace.estimator_leave_one_out(integrand, sampler) | ||
|
|
||
| received = estimate(matvec, key_est, A, B) | ||
| expected = {"fx": linalg.diagonal(A), "fy": linalg.diagonal(B)} | ||
| assert isinstance(received, dict) | ||
| assert set(received.keys()) == {"fx", "fy"} | ||
| assert np.allclose(received["fx"], expected["fx"], rtol=1e-4) | ||
| assert np.allclose(received["fy"], expected["fy"], rtol=1e-4) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.