Skip to content

Commit e7e2ebb

Browse files
authored
Sync index and readme (#113)
* Sync index.md and README.md * Include index.md in tests
1 parent 4afadb1 commit e7e2ebb

2 files changed

Lines changed: 44 additions & 5 deletions

File tree

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ lint:
77

88
test:
99
python -m doctest README.md
10+
python -m doctest docs/index.md
1011
pytest -x -v
1112

1213

docs/index.md

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Imports:
2626
```python
2727
>>> import jax
2828
>>> import jax.numpy as jnp
29-
>>> from matfree import hutch, montecarlo
29+
>>> from matfree import hutch, montecarlo, slq
3030

3131
>>> a = jnp.reshape(jnp.arange(12.), (6, 2))
3232
>>> key = jax.random.PRNGKey(1)
@@ -69,13 +69,12 @@ Jointly estimating traces and diagonals improves performance.
6969
Here is how to use it:
7070

7171
```python
72-
>>> keys = jax.random.split(key, num=10_000)
73-
>>> trace, diagonal = hutch.trace_and_diagonal(matvec, keys=keys, sample_fun=sample_fun)
72+
>>> trace, diagonal = hutch.trace_and_diagonal(matvec, key=key, num_levels=10_000, sample_fun=sample_fun)
7473
>>> print(jnp.round(trace))
75-
509.0
74+
507.0
7675

7776
>>> print(jnp.round(diagonal))
78-
[222. 287.]
77+
[221. 287.]
7978

8079
>>> # for comparison:
8180
>>> print(jnp.round(jnp.trace(a.T @ a)))
@@ -85,6 +84,45 @@ Here is how to use it:
8584
[220. 286.]
8685

8786

87+
```
88+
89+
Why is the argument called `num_levels`? Because under the hood,
90+
`trace_and_diagonal` implements a multilevel diagonal-estimation scheme:
91+
```python
92+
>>> _, diagonal_1 = hutch.trace_and_diagonal(matvec, key=key, num_levels=10_000, sample_fun=sample_fun)
93+
>>> init = jnp.zeros(shape=(2,), dtype=float)
94+
>>> diagonal_2 = hutch.diagonal_multilevel(matvec, init, key=key, num_levels=10_000, sample_fun=sample_fun)
95+
96+
>>> print(jnp.round(diagonal_1, 4))
97+
[220.54979 286.7476 ]
98+
99+
>>> print(jnp.round(diagonal_2, 4))
100+
[220.54979 286.7476 ]
101+
102+
>>> diagonal = hutch.diagonal_multilevel(matvec, init, key=key, num_levels=10, num_samples_per_batch=1000, num_batches_per_level=10, sample_fun=sample_fun)
103+
>>> print(jnp.round(diagonal))
104+
[219. 285.]
105+
106+
```
107+
108+
Does the multilevel scheme help? That is not always clear; but [here](https://github.com/pnkraemer/matfree/blob/main/docs/benchmarks/control_variates.py) is a benchmark.
109+
110+
### Determinants
111+
112+
113+
Estimate log-determinants as such:
114+
```python
115+
>>> a = jnp.reshape(jnp.arange(36.), (6, 6)) / 36
116+
>>> sample_fun = montecarlo.normal(shape=(6,))
117+
>>> matvec = lambda x: a.T @ (a @ x) + x
118+
>>> order = 3
119+
>>> logdet = slq.logdet(matvec, order, key=key, sample_fun=sample_fun)
120+
>>> print(jnp.round(logdet))
121+
3.0
122+
>>> # for comparison:
123+
>>> print(jnp.round(jnp.linalg.slogdet(a.T @ a + jnp.eye(6))[1]))
124+
3.0
125+
88126
```
89127

90128
## Contributing

0 commit comments

Comments
 (0)