Skip to content

Commit 06f3af6

Browse files
authored
Add (squared) Frobenius norm to hutch.py (#96)
1 parent 672760a commit 06f3af6

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

matfree/hutch.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ def quadform(vec):
2121
return stochastic_estimate(quadform, **kwargs)
2222

2323

24+
def frobeniusnorm_squared(matvec_fun, /, **kwargs):
25+
"""Estimate the squared Frobenius norm of a matrix stochastically."""
26+
27+
def quadform(vec):
28+
Av = matvec_fun(vec)
29+
return np.vecdot(Av, Av)
30+
31+
return stochastic_estimate(quadform, **kwargs)
32+
33+
2434
def diagonal(matvec_fun, /, **kwargs):
2535
"""Estimate the diagonal of a matrix stochastically."""
2636

tests/test_hutch.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,32 @@ def fixture_key():
2020
return prng.PRNGKey(seed=1)
2121

2222

23+
@testing.parametrize("num_batches", [1_000])
24+
@testing.parametrize("num_samples_per_batch", [1_000])
25+
@testing.parametrize("dim", [1, 10])
26+
@testing.parametrize("sample_fun", [montecarlo.normal, montecarlo.rademacher])
27+
def test_frobeniusnorm_squared(
28+
fun, key, num_batches, num_samples_per_batch, dim, sample_fun
29+
):
30+
"""Assert that the Frobenius norm estimate is accurate."""
31+
# Linearise function
32+
x0 = prng.uniform(key, shape=(dim,)) # random lin. point
33+
_, jvp = func.linearize(fun, x0)
34+
J = func.jacfwd(fun)(x0)
35+
36+
# Estimate the trace
37+
fun = sample_fun(shape=np.shape(x0), dtype=np.dtype(x0))
38+
estimate = hutch.frobeniusnorm_squared(
39+
jvp,
40+
num_batches=num_batches,
41+
key=key,
42+
num_samples_per_batch=num_samples_per_batch,
43+
sample_fun=fun,
44+
)
45+
truth = np.trace(J.T @ J)
46+
assert np.allclose(estimate, truth, rtol=1e-2)
47+
48+
2349
@testing.parametrize("num_batches", [1_000])
2450
@testing.parametrize("num_samples_per_batch", [1_000])
2551
@testing.parametrize("dim", [1, 10])

0 commit comments

Comments
 (0)