@@ -26,7 +26,7 @@ Imports:
 ``` python
 >>> import jax
 >>> import jax.numpy as jnp
->>> from matfree import hutch, montecarlo
+>>> from matfree import hutch, montecarlo, slq

 >>> a = jnp.reshape(jnp.arange(12.), (6, 2))
 >>> key = jax.random.PRNGKey(1)
@@ -69,13 +69,12 @@ Jointly estimating traces and diagonals improves performance.
 Here is how to use it:

 ``` python
->>> keys = jax.random.split(key, num=10_000)
->>> trace, diagonal = hutch.trace_and_diagonal(matvec, keys=keys, sample_fun=sample_fun)
+>>> trace, diagonal = hutch.trace_and_diagonal(matvec, key=key, num_levels=10_000, sample_fun=sample_fun)
 >>> print(jnp.round(trace))
-509.0
+507.0

 >>> print(jnp.round(diagonal))
-[222. 287.]
+[221. 287.]

 >>> # for comparison:
 >>> print(jnp.round(jnp.trace(a.T @ a)))
@@ -85,6 +84,45 @@ Here is how to use it:
 [220. 286.]


+```
+
+Why is the argument called `num_levels`? Because under the hood,
+`trace_and_diagonal` implements a multilevel diagonal-estimation scheme:
+``` python
+>>> _, diagonal_1 = hutch.trace_and_diagonal(matvec, key=key, num_levels=10_000, sample_fun=sample_fun)
+>>> init = jnp.zeros(shape=(2,), dtype=float)
+>>> diagonal_2 = hutch.diagonal_multilevel(matvec, init, key=key, num_levels=10_000, sample_fun=sample_fun)
+
+>>> print(jnp.round(diagonal_1, 4))
+[220.54979 286.7476 ]
+
+>>> print(jnp.round(diagonal_2, 4))
+[220.54979 286.7476 ]
+
+>>> diagonal = hutch.diagonal_multilevel(matvec, init, key=key, num_levels=10, num_samples_per_batch=1000, num_batches_per_level=10, sample_fun=sample_fun)
+>>> print(jnp.round(diagonal))
+[219. 285.]
+
+```
+Does the multilevel scheme help? That is not always clear, but [here](https://github.com/pnkraemer/matfree/blob/main/docs/benchmarks/control_variates.py) is a benchmark.
+
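As a hedged aside, the control-variate idea behind such a multilevel scheme can be sketched in a few lines of plain JAX. Everything below (the `refine_diagonal` helper, the sample counts) is illustrative and not matfree's API: at each level, the current estimate `d` is subtracted as a control variate, so the correction term stays unbiased while its magnitude shrinks as `d` approaches the true diagonal.

``` python
import jax
import jax.numpy as jnp


def refine_diagonal(matvec, d, key, num_samples, dim):
    # One level: Hutchinson-style correction E[v * (A v - d * v)] added to d.
    # For standard-normal v, E[v * (A v)] = diag(A) and E[v * (d * v)] = d,
    # so the update is an unbiased estimate of diag(A) at every level.
    vs = jax.random.normal(key, (num_samples, dim))
    corrections = jax.vmap(lambda v: v * (matvec(v) - d * v))(vs)
    return d + jnp.mean(corrections, axis=0)


a = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: a.T @ (a @ x)  # diagonal of a.T @ a is [220., 286.]

d = jnp.zeros((2,))
key = jax.random.PRNGKey(1)
for subkey in jax.random.split(key, num=5):
    d = refine_diagonal(matvec, d, subkey, num_samples=1_000, dim=2)

print(jnp.round(d))  # approximately [220., 286.]
```

With only 1000 samples per level the estimate is still noisy; the point of the sketch is the structure of the recursion, not the accuracy.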
+### Determinants
+
+Estimate log-determinants as follows:
+``` python
+>>> a = jnp.reshape(jnp.arange(36.), (6, 6)) / 36
+>>> sample_fun = montecarlo.normal(shape=(6,))
+>>> matvec = lambda x: a.T @ (a @ x) + x
+>>> order = 3
+>>> logdet = slq.logdet(matvec, order, key=key, sample_fun=sample_fun)
+>>> print(jnp.round(logdet))
+3.0
+>>> # for comparison:
+>>> print(jnp.round(jnp.linalg.slogdet(a.T @ a + jnp.eye(6))[1]))
+3.0
+
 ```

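One clarifying note on the determinant example: the `+ x` in the matvec corresponds to the dense operator `a.T @ a + I`, which is symmetric positive definite, so the log-determinant being estimated is well defined. A minimal dense cross-check (plain JAX, nothing from matfree) is:

``` python
import jax.numpy as jnp

# Dense version of the shifted operator a.T @ a + I from the example above.
a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
dense = a.T @ a + jnp.eye(6)

# The log-determinant equals the sum of log-eigenvalues; the matrix is
# symmetric positive definite, so every eigenvalue is >= 1 and the logs exist.
eigvals = jnp.linalg.eigvalsh(dense)
logdet_eig = jnp.sum(jnp.log(eigvals))

# Cross-check against slogdet, as in the README example.
sign, logdet_ref = jnp.linalg.slogdet(dense)
print(sign, logdet_eig, logdet_ref)
```

This is also what stochastic Lanczos quadrature approximates: the trace of the matrix logarithm, evaluated without ever forming the dense matrix.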
 ## Contributing