
changes to rethinking.quap.py #1178


Open · wants to merge 9 commits into main

Conversation

@jeffpollock9 (Contributor)

Hello, as per #1161 I had a look at the quap function. I've added a few noddy changes to it; I hope they're helpful and not annoying (if the latter, please let me know and I will revert).

@ColCarroll is there anything in particular you would like me to attempt to improve further in the quap function? I've updated my version of find_map (gist here: https://gist.github.com/jeffpollock9/0cc298641722277cf30e438f04eb8609), so please let me know if there is anything you'd like to take from it. I wasn't sure of the best way to change quap to use this stuff without changing the API, which I would prefer not to do unless someone else is happy with that.

google-cla bot added the label "cla: yes" (Declares that the user has signed CLA) on Nov 24, 2020
@ColCarroll (Collaborator)

Thanks for this! It looks fine as is, but right now is also a great time to change the API there, since it is not publicly exposed or even used in the notebooks yet.

Would it be possible to:

  • Rename find_map to _bfgs_optimize, and just return opt out of the final control_dependencies
  • Add a find_MAP calling the above and using your logic to return a dictionary
  • Add a quap calling the above and using my logic to return a JDNamed

What do you think?
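For concreteness, a rough sketch of that three-layer structure (my paraphrase; the helper signatures and bodies are illustrative, not actual code from either version) might look like:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def _bfgs_optimize(neg_log_prob_fn, initial_position):
    # Run BFGS on the (flattened, unconstrained) parameters and return
    # the raw optimizer results: position, inverse Hessian estimate, etc.
    return tfp.optimizer.bfgs_minimize(
        lambda x: tfp.math.value_and_gradient(neg_log_prob_fn, x),
        initial_position=initial_position)


def find_MAP(neg_log_prob_fn, initial_position, names):
    # Dictionary-returning wrapper: one MAP estimate per named parameter.
    opt = _bfgs_optimize(neg_log_prob_fn, initial_position)
    return dict(zip(names, tf.unstack(opt.position)))


def quap(neg_log_prob_fn, initial_position, names):
    # JointDistributionNamed-returning wrapper: independent normals
    # centered at the MAP, with scales from the inverse-Hessian diagonal.
    opt = _bfgs_optimize(neg_log_prob_fn, initial_position)
    scales = tf.sqrt(tf.linalg.diag_part(opt.inverse_hessian_estimate))
    return tfd.JointDistributionNamed({
        name: tfd.Normal(loc=loc, scale=scale)
        for name, loc, scale in zip(
            names, tf.unstack(opt.position), tf.unstack(scales))})
```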

@jeffpollock9 (Contributor, Author)

Yes, that sounds much better than the ideas I had for this - thanks!

I'll try a few ideas along those lines and see how it looks. I'd also like to support Laplace approximations from this sort of code in some way, as per #570, but I'm not sure whether that would be well placed in the rethinking examples.

@ColCarroll (Collaborator)

I was playing with this yesterday, and another problem is that the Laplace approximation is basis-dependent, so the (approximate) Hessian you get back from L-BFGS depends on the bijectors being used. It seems like there's some choice, then (see this, for example), over what gets returned.
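To make the basis dependence concrete (a toy example of mine, not from the PR): the Hessian of the log density in the constrained basis differs from the Hessian of the Jacobian-adjusted log density in the unconstrained basis, so the implied Laplace covariance depends on the bijector.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

dist = tfd.Gamma(concentration=2.0, rate=10.0)
bij = tfb.Exp()  # one possible unconstraining bijector


def hessian(f, x):
    # Second derivative of a scalar function at a scalar point,
    # via nested gradient tapes.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = f(x)
        g = inner.gradient(y, x)
    return outer.gradient(g, x)


x = tf.constant(0.1)  # a point in the constrained space
z = bij.inverse(x)    # the same point in the unconstrained space

# Hessian of the log density in the constrained basis ...
h_x = hessian(dist.log_prob, x)

# ... versus the Jacobian-adjusted log density in the unconstrained basis.
log_prob_z = lambda z: (dist.log_prob(bij.forward(z))
                        + bij.forward_log_det_jacobian(z, event_ndims=0))
h_z = hessian(log_prob_z, z)

# h_x and h_z differ, so the Laplace covariance is basis-dependent.
```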

@jeffpollock9 (Contributor, Author)

Thanks for the reference, I'll try to go through it soon.

I think the MAP estimate also has problems with being basis-dependent. I recall a lot of discussion on the Stan forums as to whether or not running Stan in optimizing mode should include the bijector Jacobian adjustment, but I can't find it now and don't think there was a definitive answer :(

My favourite document on this stuff is https://mc-stan.org/users/documentation/case-studies/mle-params.html; thought I'd link it on the off chance it's useful to anyone. This is the summary:

The moral of the story is that full Bayesian inference is insensitive to parameterization as long as the appropriate Jacobian adjustment is applied. In contrast, posterior modes (i.e., maximum a posteriori estimates) do change under reparameterization, and thus are no true Bayesian quantity.

I guess I'll think a little harder about how the Laplace stuff could work and carry on as I was with the MAP stuff, although any comments on any of this would be great.
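In the spirit of that case study, here is a small reparameterization example of my own showing the mode moving: Beta(3, 2) has mode 2/3 on the unit interval, but maximizing the Jacobian-adjusted density on the logit scale and mapping back gives 3/5.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

dist = tfd.Beta(3.0, 2.0)
bij = tfb.Sigmoid()

# Mode in the constrained basis: (a - 1) / (a + b - 2) = 2/3.

# Mode in the unconstrained (logit) basis, with the Jacobian adjustment:
# log p(z) = log Beta(sigmoid(z); 3, 2) + log|sigmoid'(z)|, whose
# maximizer maps back to a / (a + b) = 3/5.
def neg_log_prob(z):
    x = bij.forward(z)
    return -tf.reduce_sum(
        dist.log_prob(x) + bij.forward_log_det_jacobian(z, event_ndims=0))


opt = tfp.optimizer.bfgs_minimize(
    lambda z: tfp.math.value_and_gradient(neg_log_prob, z),
    initial_position=tf.zeros([1]))

print(bij.forward(opt.position))  # ==> ~[0.6], not [0.667]
```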

@ColCarroll (Collaborator)

Ah, thanks for the link!

Yes, I imagine the docstring would include a discussion and some of these references, that there should probably be a bijector argument defaulting to default_event_space_bijector, and that the docs would carry this nice warning about the output being "no true Bayesian quantity."
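i.e. something along these lines (a hypothetical signature, just to pin down the default; I'm assuming the experimental_default_event_space_bijector method here):

```python
def laplace_approximation(joint_dist, data=None, bijector=None):
    """...docstring with the references above and the caveat that the
    result is basis-dependent and 'no true Bayesian quantity'..."""
    if bijector is None:
        bijector = joint_dist.experimental_default_event_space_bijector()
    ...
```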

Does it make sense to you to merge the changes on this PR into quap.py, and you can think about a PR into the main repo with a design closer to the one in the gist?

@jeffpollock9 (Contributor, Author)

Yes, that sounds good to me.

@jeffpollock9 (Contributor, Author)

Hi @ColCarroll, I've tried to merge all of the ideas together into one function and have uploaded it here: https://gist.github.com/jeffpollock9/2a6950164711d8f813561b8253c9cfa9. I wasn't sure where to put it, as I'm hoping it could perhaps be added to the TFP API (shall I open a separate PR with docs + tests?) instead of the rethinking examples. In that case, perhaps we could leave the quap function as is, or merge the small changes (mostly to the docs) in this PR.

This new attempt differs from both of our original functions in that it returns a transformed multivariate normal. I think this is quite nice, as the object naturally contains all the information that (I think/hope) will be useful to folks. I've not tested it much, but I think it can handle more cases than the current quap function, and it also keeps the full estimate of the MVN covariance matrix instead of just the diagonal.
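As a rough picture of the construction (a minimal single-parameter sketch of my own, not the gist's actual code): fit a full-covariance normal at the MAP in the unconstrained space, then wrap it in a TransformedDistribution so samples land back in the constrained space.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

dist = tfd.Gamma(concentration=2.0, rate=10.0)
bij = tfb.Softplus()  # stand-in unconstraining bijector


def neg_log_prob(z):
    # Negative Jacobian-adjusted log density in the unconstrained space.
    x = bij.forward(z)
    return -tf.reduce_sum(
        dist.log_prob(x) + bij.forward_log_det_jacobian(z, event_ndims=0))


opt = tfp.optimizer.bfgs_minimize(
    lambda z: tfp.math.value_and_gradient(neg_log_prob, z),
    initial_position=tf.zeros([1]))

# Full-covariance normal at the unconstrained MAP, pushed through the
# bijector; the BFGS inverse-Hessian estimate approximates the covariance.
approximation = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalTriL(
        loc=opt.position,
        scale_tril=tf.linalg.cholesky(opt.inverse_hessian_estimate)),
    bijector=bij)

samples = approximation.sample(5)  # samples live in the constrained space
```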

An example of the usage:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# laplace_approximation is the function from the gist linked above.

print(tf.__version__)
# ==> 2.5.0-dev20201201

print(tfp.__version__)
# ==> 0.12.0-dev20201201


@tfd.JointDistributionCoroutineAutoBatched
def joint_dist():
    yield tfd.Normal(loc=1.0, scale=1.0, name="a")
    yield tfd.Normal(loc=[2.0, 2.0], scale=2.0, name="b")
    yield tfd.Gamma(concentration=[2.0, 2.0, 2.0], rate=10.0, name="c")
    yield tfd.CholeskyLKJ(dimension=3, concentration=1.0, name="d")


# without conditioning
approximation = laplace_approximation(joint_dist)

names = joint_dist._flat_resolve_names()

point_estimate = approximation.bijector(approximation.distribution.mode())

for name, estimate in zip(names, point_estimate):
    print(f"{name}:")
    print(estimate.numpy())
    print("")

# ==>
# a:
# 0.9999998
#
# b:
# [2.000008  1.9999926]
#
# c:
# [0.09999622 0.10000692 0.09999667]
#
# d:
# [[ 1.000000e+00  0.000000e+00  0.000000e+00]
#  [ 4.016183e-07  1.000000e+00  0.000000e+00]
#  [ 5.092372e-01 -4.885684e-01  7.085044e-01]]

# with conditioning; also shows that sampling from the approximation works
approximation = laplace_approximation(
    joint_dist, data={"d": approximation.sample()[3]})

point_estimate = approximation.bijector(approximation.distribution.mode())

for name, estimate in zip(names, point_estimate):
    print(f"{name}:")
    print(estimate.numpy())
    print("")

# ==>
# a:
# 0.9999998
#
# b:
# [1.9999962 1.999986 ]
#
# c:
# [0.10000057 0.10000078 0.09999969]
```

If you (or anyone else) have any comments, that would be great, but no rush of course.

Thanks again.

@brianwa84 (Contributor)

I think we'd probably be open to a Laplace approximation under tfp.experimental.distributions if you want to send a PR to put something like this in there. @ColCarroll wdyt?

@ColCarroll (Collaborator)

+1 that'd be great -- happy to help review/advise, too.

@jeffpollock9 (Contributor, Author)

@brianwa84 @ColCarroll thanks - that sounds great. I'll try to send over a PR for a Laplace approximation in tfp.experimental.distributions when I have some time.
