
Commit cf45304

Rewrote Laplace approximation design doc as new algorithm + fixed the unconstrained space stuff (design-doc #16)
Parent commit: eb41985

- Feature Name: Laplace approximation as a new algorithm
- Start Date: 2020-03-08
- RFC PR: 16
- Stan Issue:

# Summary
[summary]: #summary

The proposal is to add a Laplace approximation on the unconstrained space as a
form of approximate inference.

# Motivation
[motivation]: #motivation

I have a Stan model that is too slow to sample. I would like to do something
better than optimization, and a Laplace approximation is a pretty standard way
of doing that.

# Guide-level explanation
[guide-level-explanation]: #guide-level-explanation

The `laplace` algorithm would work by forming a Laplace approximation to the
unconstrained posterior density.

Assuming `u` are the unconstrained variables, `c` are the constrained variables,
and `c = g(u)`, the log density sampled by Stan is:

```
log(p(u)) = log(p(g(u))) + log(det(jac(g)))
```

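As a purely illustrative example of this correction (not part of the proposal), consider a single positive-constrained parameter, for which Stan's transform is `c = g(u) = exp(u)` and the log Jacobian determinant is just `u`. The toy model and function names below are made up for the sketch:

```python
import numpy as np

# Toy model (not from the proposal): one constrained parameter c > 0 with
# an Exponential(1) prior, so log p(c) = -c up to a constant.
def log_p_constrained(c):
    return -c

# Stan samples on the unconstrained scale u, where c = g(u) = exp(u).
# The Jacobian of g is exp(u), so log(det(jac(g))) = u.
def log_p_unconstrained(u):
    return log_p_constrained(np.exp(u)) + u

print(log_p_unconstrained(0.0))  # -1.0: log density at u = 0, i.e. c = 1
```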
In the Laplace approximation, we search for a mode (a maximum) of
`log(p(u))`. Call this `u_mode`. This is not the same optimization that is
done in the `optimizing` algorithm, which searches for a mode of `log(p(g(u)))`
(the equation above without the `log(det(jac(g)))` term), so the two
optimizations generally find different modes.

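To make that difference concrete, here is a small sketch continuing the toy exponential example above (again illustrative, not proposal code): the mode of `log(p(g(u)))` corresponds to `c = 0`, while the mode of the Jacobian-adjusted `log(p(u))` is at `u = 0`, i.e. `c = 1`:

```python
import numpy as np
from scipy.optimize import minimize_scalar

# Same toy model: c > 0 with log p(c) = -c and c = exp(u).
neg_log_p_constrained = lambda c: c                # -log(p(c))
neg_log_p_unconstrained = lambda u: np.exp(u) - u  # -(log(p(g(u))) + log(det(jac(g))))

# Mode targeted without the Jacobian term (what `optimizing` does): c -> 0.
mode_c = minimize_scalar(neg_log_p_constrained, bounds=(1e-8, 10.0), method="bounded").x
# Mode targeted with the Jacobian term (what `laplace` would do): u = 0, so c = 1.
mode_u = minimize_scalar(neg_log_p_unconstrained, bounds=(-10.0, 10.0), method="bounded").x

print(mode_c, np.exp(mode_u))  # roughly 0 and 1
```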
We can form a second order Taylor expansion of `log(p(u))` around `u_mode`:

```
log(p(u)) = log(p(u_mode))
            + gradient(log(p), u_mode) * (u - u_mode)
            + 0.5 * (u - u_mode)^T * hessian(log(p), u_mode) * (u - u_mode)
            + O(||u - u_mode||^3)
```

where `gradient(log(p), u)` is the gradient of `log(p)` at `u` and
`hessian(log(p), u)` is the Hessian of `log(p)` at `u`. Because the gradient
is zero at the mode, the linear term drops out. Ignoring the third order
terms gives us a new distribution `p_approx(u)`:

```
log(p_approx(u)) = K + 0.5 * (u - u_mode)^T * hessian(log(p), u_mode) * (u - u_mode)
```

where `K` is a constant that makes this normalize. `u_approx` (`u` sampled from
`p_approx(u)`) has the distribution:

```
u_approx ~ N(u_mode, -(hessian(log(p), u_mode))^{-1})
```

Taking draws from `u_approx` gives us draws from our distribution on Stan's
unconstrained space. Once constrained, these draws can be used in the same
way that regular draws from the `sampling` algorithm are used.

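Putting the pieces together, the following Python sketch runs the whole recipe on a made-up two-parameter Gaussian target standing in for `log(p(u))`. Everything here is illustrative only; the proposal would do this inside the Stan services and compute the Hessian as described in the reference-level section:

```python
import numpy as np
from scipy.optimize import minimize

# Stand-in for Stan's unconstrained log density log(p(u)): a correlated
# Gaussian, chosen so the Laplace approximation is exact and easy to check.
mean = np.array([1.0, -0.5])
prec = np.linalg.inv(np.array([[1.0, 0.3], [0.3, 0.5]]))

def log_p(u):
    d = u - mean
    return -0.5 * d @ prec @ d

# 1. Find u_mode, the maximum of log(p(u)) on the unconstrained space.
u_mode = minimize(lambda u: -log_p(u), x0=np.zeros(2)).x

# 2. Hessian of log(p) at u_mode (here by central differences of log_p itself;
#    the proposal would instead use finite differences of gradients).
def hessian(f, x, eps=1e-4):
    n = x.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n)
            ej = np.zeros(n)
            ei[i] = eps
            ej[j] = eps
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * eps ** 2)
    return H

H = hessian(log_p, u_mode)
H = 0.5 * (H + H.T)  # symmetrize away finite-difference noise
# An `add_diag`-style adjustment would modify the diagonal of this Hessian
# approximation here if it were nearly singular.

# 3. Draw from the Laplace approximation: u_approx ~ N(u_mode, -inv(H)).
rng = np.random.default_rng(seed=1)
u_draws = rng.multivariate_normal(u_mode, -np.linalg.inv(H), size=1000)

# 4. Apply g(u) to map draws back to the constrained space; the toy target
#    has no constraints, so g is the identity here.
print(u_mode, u_draws.mean(axis=0))
```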
The `laplace` algorithm would take in the same arguments as the `optimize`
algorithm plus two additional ones:

`num_samples` - The number of draws to take from the posterior
approximation. This should be greater than one (defaults to 1000).

`add_diag` - A value to add to the diagonal of the Hessian
approximation to fix small singularities (defaults to zero).

The output is printed after the optimum.

A model can be called by:
```
./model laplace num_samples=100 data file=data.dat
```

or with the diagonal:
```
./model laplace num_samples=100 add_diag=1e-10 data file=data.dat
```

The output would mirror the other interfaces: the algorithm-specific
parameters, each with two trailing underscores appended, are printed first,
followed by the model parameters.

The three algorithm specific parameters are:

1. `log_p__` - The log density of the model itself
2. `log_g__` - The log density of the Laplace approximation
3. `rejected__` - A flag indicating whether the log density of the model could
   not be evaluated at this parameter value (i.e. the draw was rejected)

For instance, the new output might look like:

```
# stan_version_major = 2
...
# Draws from Laplace approximation:
log_p__, log_g__, rejected__, b.1,b.2
-1, -2, 0, 7.66364,5.33463
-2, -3, 0, 7.66367,5.33462
-3, -4, 1, 0, 0
```

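These columns also make simple diagnostics possible downstream. For example (not part of the proposal itself), since `log_p__ - log_g__` is the log importance ratio up to a constant, a user could form self-normalized importance weights for the approximate draws. A rough Python sketch with made-up values:

```python
import numpy as np

# Hypothetical values read from the output CSV columns.
log_p = np.array([-1.0, -2.0, -3.5, -1.5])             # log_p__
log_g = np.array([-2.0, -3.0, -3.0, -2.5])             # log_g__
draws = np.array([[7.66, 5.33], [7.67, 5.33],
                  [7.60, 5.30], [7.65, 5.34]])          # b.1, b.2

# Self-normalized importance weights (subtract the max for stability).
log_w = log_p - log_g
w = np.exp(log_w - log_w.max())
w /= w.sum()

print(w @ draws)  # importance-weighted estimate of the posterior means
```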
# Reference-level explanation
[reference-level-explanation]: #reference-level-explanation

The implementation of this would borrow heavily from the optimization code. The
difference would be that the Jacobian would be turned on for the optimization.

We will also need to implement a way of computing Hessians with finite
differences of gradients. Simple finite differences of the log density itself
were not sufficiently accurate for an example I was working on.

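For reference, "finite differences of gradients" means differencing the (autodiff) gradient once rather than differencing the log density twice. A minimal illustrative sketch, not the Stan math implementation:

```python
import numpy as np

def hessian_from_gradients(grad, x, eps=1e-6):
    """Approximate the Hessian of f at x using central differences of grad(f)."""
    n = x.size
    H = np.zeros((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = eps
        H[:, i] = (grad(x + e) - grad(x - e)) / (2 * eps)
    return 0.5 * (H + H.T)  # symmetrize to remove finite-difference noise

# Example with a known answer: f(x) = -0.5 * x^T A x has gradient -A x.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
grad = lambda x: -A @ x
print(hessian_from_gradients(grad, np.array([0.3, -0.2])))  # approximately -A
```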
# Drawbacks
[drawbacks]: #drawbacks

It is not clear to me how to handle errors evaluating the log density.

There are a few options with various drawbacks:

1. Re-sample a new point in the unconstrained space until one is accepted.

   With a poorly written model, this may never terminate.

2. Quietly reject the sample and print nothing (so it is possible that someone
   who requested 200 draws only gets 150).

   This might lead to silent errors if the user is not vigilantly checking the
   lengths of their outputs (they may compute incorrect standard errors, etc.).

3. Use the `rejected__` diagnostic output to indicate a sample was rejected and
   print zeros where the parameters would usually be.

   If the user does not check the `rejected__` column, then they will be using a
   bunch of zeros in their follow-up calculations that could mislead them
   (downstream code would need to filter on `rejected__`, as sketched below).

Similarly to divergent transitions, we can print information to the output about
how many draws were rejected in any case.

In the earlier part of the design document I assumed #3 was implemented, but I
think #2 might be better.

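Whichever option is chosen, downstream code has to be a little defensive. For example, under option #3 a user would drop rejected rows before computing summaries (illustrative Python with the made-up values from the guide-level example):

```python
import numpy as np

# Hypothetical columns from the example output above.
rejected = np.array([0, 0, 1])                  # rejected__
draws = np.array([[7.66364, 5.33463],
                  [7.66367, 5.33462],
                  [0.0, 0.0]])                  # zeros printed for the rejected draw

kept = draws[rejected == 0]
print(f"{rejected.sum()} of {len(rejected)} draws rejected")
print(kept.mean(axis=0))  # summaries over accepted draws only
```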
# Rationale and alternatives
[rationale-and-alternatives]: #rationale-and-alternatives

An alternative that seems appealing at first is printing out the mode and
Hessian so that it would be possible for people to make their own posterior
draws. This is not reasonable because the conversion from unconstrained to
constrained space is not exposed in all of the interfaces.

# Prior art
[prior-art]: #prior-art

This is a pretty common Bayesian approximation. It just is not implemented in
Stan.
