Skip to content

Commit f601a12

Browse files
authored
Logdet example in README (#98)
1 parent c63ef37 commit f601a12

4 files changed

Lines changed: 30 additions & 18 deletions

File tree

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Imports:
2626
```python
2727
>>> import jax
2828
>>> import jax.numpy as jnp
29-
>>> from matfree import hutch, montecarlo
29+
>>> from matfree import hutch, montecarlo, slq
3030

3131
>>> a = jnp.reshape(jnp.arange(12.), (6, 2))
3232
>>> key = jax.random.PRNGKey(1)
@@ -85,6 +85,24 @@ Here is how to use it:
8585
[220. 286.]
8686

8787

88+
```
89+
90+
### Determinants
91+
92+
93+
Estimate log-determinants as such:
94+
```python
95+
>>> a = jnp.reshape(jnp.arange(36.), (6, 6)) / 36
96+
>>> sample_fun = montecarlo.normal(shape=(6,))
97+
>>> matvec = lambda x: a.T @ (a @ x) + x
98+
>>> order = 3
99+
>>> logdet, _ = slq.trace_of_matfun(jnp.log, matvec, order, key=key, sample_fun=sample_fun)
100+
>>> print(jnp.round(logdet))
101+
3.0
102+
>>> # for comparison:
103+
>>> print(jnp.round(jnp.linalg.slogdet(a.T @ a + jnp.eye(6))[1]))
104+
3.0
105+
88106
```
89107

90108
## Contributing

matfree/slq.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Stochastic Lanczos quadrature."""
22

33
from matfree import decomp, montecarlo
4-
from matfree.backend import func, linalg, np, prng
4+
from matfree.backend import func, linalg, np
55

66

77
def trace_of_matfun(
@@ -11,22 +11,16 @@ def trace_of_matfun(
1111
/,
1212
*,
1313
key,
14-
num_samples_per_batch,
15-
num_batches,
16-
tangents_shape,
17-
tangents_dtype,
18-
sample_fun=prng.normal,
14+
sample_fun,
15+
num_samples_per_batch=10,
16+
num_batches=1,
1917
):
2018
"""Compute the trace of the function of a matrix.
2119
2220
For example, logdet(M) = trace(log(M)) ~ trace(U log(D) Ut) = E[v U log(D) Ut vt].
2321
"""
24-
25-
def sample(k, /):
26-
return sample_fun(k, shape=tangents_shape, dtype=tangents_dtype)
27-
2822
quadform = quadratic_form_slq(matfun, Av, order)
29-
quadform_mc = montecarlo.montecarlo(quadform, sample_fun=sample)
23+
quadform_mc = montecarlo.montecarlo(quadform, sample_fun=sample_fun)
3024

3125
quadform_batch = montecarlo.mean_vmap(quadform_mc, num_samples_per_batch)
3226
quadform_batch = montecarlo.mean_map(quadform_batch, num_batches)

tests/test_autodiff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Tests for (selected) autodiff functionality."""
22

33

4-
from matfree import slq, test_util
4+
from matfree import montecarlo, slq, test_util
55
from matfree.backend import np, prng, testing
66

77

@@ -32,14 +32,14 @@ def fun(s):
3232

3333
def _logdet(A, order, key):
3434
n, _ = np.shape(A)
35+
fun = montecarlo.normal(shape=(n,))
3536
received, num_nans = slq.trace_of_matfun(
3637
np.log,
3738
lambda v: A @ v,
3839
order,
3940
key=key,
4041
num_samples_per_batch=10,
4142
num_batches=1,
42-
tangents_shape=(n,),
43-
tangents_dtype=float,
43+
sample_fun=fun,
4444
)
4545
return received

tests/test_slq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for Lanczos functionality."""
22

3-
from matfree import slq, test_util
3+
from matfree import montecarlo, slq, test_util
44
from matfree.backend import linalg, np, prng, testing
55

66

@@ -23,15 +23,15 @@ def test_logdet(A, order):
2323
"""Assert that the log-determinant estimation matches the true log-determinant."""
2424
n, _ = np.shape(A)
2525
key = prng.PRNGKey(1)
26+
fun = montecarlo.normal(shape=(n,))
2627
received, num_nans = slq.trace_of_matfun(
2728
np.log,
2829
lambda v: A @ v,
2930
order,
3031
key=key,
3132
num_samples_per_batch=10,
3233
num_batches=1,
33-
tangents_shape=(n,),
34-
tangents_dtype=np.dtype(A),
34+
sample_fun=fun,
3535
)
3636
expected = linalg.slogdet(A)[1]
3737
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)

0 commit comments

Comments
 (0)