Add rewrites with RV reparametrization tricks #8056
base: main
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```diff
@@            Coverage Diff             @@
##             main    #8056      +/-   ##
==========================================
- Coverage   91.42%   90.80%    -0.62%
==========================================
  Files         117      121        +4
  Lines       19154    19443      +289
==========================================
+ Hits        17512    17656      +144
- Misses       1642     1787      +145
```
```python
@node_rewriter([GammaRV, InvGammaRV])
def gamma_reparametrization(fgraph, node):
    rng, size, shape, scale = node.inputs
    # Scale trick: Gamma(shape, scale) == scale * Gamma(shape, 1)
    return scale * node.op(shape, 1.0, rng=rng, size=size)
```
---
We can't pull the shape parameter out of the RV?
---
The actual reparameterization trick in this case requires some effort (the Marsaglia-Tsang method); the JAX implementation is here:
https://github.com/jax-ml/jax/blob/c6568036b83b39d556c68d647a99196e63062612/jax/_src/random.py#L1298
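For reference, here is a minimal NumPy sketch of the Marsaglia-Tsang rejection sampler (an illustration only, not the JAX or PR code; the alpha < 1 boosting step is the standard trick):

```python
import numpy as np

def gamma_marsaglia_tsang(alpha, rng):
    """One Gamma(alpha, 1) draw via Marsaglia-Tsang squeeze/rejection."""
    boost = 1.0
    if alpha < 1.0:
        # Boost alpha above 1, then correct with U ** (1 / alpha)
        boost = rng.uniform() ** (1.0 / alpha)
        alpha += 1.0
    d = alpha - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    while True:
        x = rng.normal()
        v = (1.0 + c * x) ** 3
        if v <= 0.0:
            continue  # reject draws outside the support
        u = rng.uniform()
        # Log-space acceptance test for numerical stability
        if np.log(u) < 0.5 * x**2 + d - d * v + d * np.log(v):
            return boost * d * v

rng = np.random.default_rng(0)
draws = np.array([gamma_marsaglia_tsang(2.5, rng) for _ in range(10_000)])
```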
---
I'll write the implementation for that then. Thanks @jessegrabowski
---
@jessegrabowski, I wrote the generator code. I still need to test that the samples it returns are distributed like the reference Gamma, but maybe you could have a look and comment on the scans
---
I've fixed the gamma reparametrization and it now matches the generation we get from Gamma.dist. I can't take gradients through the reparametrized version because of pymc-devs/pytensor#555
```python
@node_rewriter([BernoulliRV])
def bernoulli_reparametrization(fgraph, node):
    rng, size, p = node.inputs
    return switch(
        # The diff is truncated past this point; a plausible completion
        # (an assumption, not the PR's actual code) is the uniform-threshold
        # trick Bernoulli(p) = 1[U < p]:
        uniform(0, 1, rng=rng, size=size) < p, 1, 0
    )
```
---
Commenting here, but in general you may be creating non-equivalent graphs by using the `size` argument. Remember it may be `None`, meaning the shape will be implied by the parameters' shapes. In that case you may be rewriting something like `normal([0, 1, 2], size=None)` as `[0, 1, 2] + normal(size=None)`.
There's a non-default PyTensor rewrite that makes `size` explicit, which you may want to apply prior to these rewrites. Then make sure these rewrites fail when `size` is `None` (`isinstance(size.type, NoneTypeT)`). If you do have `size`, it will be all you need.
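To make the pitfall concrete, a small NumPy sketch (an illustration with assumed values, not PyTensor code):

```python
import numpy as np

rng = np.random.default_rng(0)
loc = np.array([0.0, 1.0, 2.0])

# Original graph: size=None, so the draw's shape is implied by loc.
original = rng.normal(loc)          # three independent draws, shape (3,)

# Naive rewrite with size still None: the standard normal is a scalar draw,
# so broadcasting adds the *same* noise to every entry of loc.
rewritten = loc + rng.normal(0.0)   # not equivalent to the original graph
```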
---
Hmm. I'll have to ask you how to add that rewrite into the database and make it go before any reparametrization rewrites
```python
from pytensor.tensor.slinalg import cholesky

reparametrization_trick_db = RewriteDatabaseQuery(include=["random_reparametrization_trick"])
```
---
This works, but it's a bit expensive every time you extend it. You should start with an actual database like `SequenceDB`; then you can register rewrite phases in it. `pymc/logprob/rewriting.py` may be a good example.
I'm even surprised that `RewriteDatabaseQuery` has a register argument. Feels like bad design.
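A rough sketch of that structure (all names here are assumptions), loosely modeled on pymc/logprob/rewriting.py:

```python
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB

# One SequenceDB with ordered phases, queried once, instead of rebuilding
# a RewriteDatabaseQuery on every registration.
reparametrization_db = SequenceDB()

# Phase 1: canonicalizations (e.g. making `size` explicit) run first.
reparametrization_db.register("canonicalize", EquilibriumDB(), "basic", position=0)
# Phase 2: the reparametrization node rewrites themselves.
reparametrization_db.register("reparametrize", EquilibriumDB(), "basic", position=1)

reparametrization_rewriter = reparametrization_db.query(RewriteDatabaseQuery(include=["basic"]))
```

Registering phases with explicit positions keeps the ordering concern (canonicalize before reparametrize) in one place.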
```python
@register_random_reparametrization
@node_rewriter([DirichletRV])
def dirichlet_reparametrization(fgraph, node):
    raise NotImplementedError("DirichletRV is not reparametrizable")
```
---
It doesn't need to block this PR, but I know TFP has something for the Dirichlet that is one-time differentiable.
---
It comes from the Gamma, so we need to sort that out first
---
They might have changed it, and they now just define the Dirichlet as a normalized batch of gamma variates, which would be differentiable if the gammas also are?
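The gamma-normalization construction being referred to, as a minimal NumPy sketch (an illustration, not TFP's code):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])

# Dirichlet(alpha) as independent Gamma(alpha_i, 1) draws normalized onto the simplex:
g = rng.gamma(alpha)            # one gamma variate per concentration parameter
dirichlet_draw = g / g.sum()    # differentiable iff the gamma draws are
```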
---
Thanks, I'll have a look at that
```python
next_rng, U = UniformRV()(
    zeros_like(alpha),
    ones_like(alpha),
    rng=rng,
).owner.outputs
next_rng, x = NormalRV()(
    zeros_like(c),
    ones_like(c),
    rng=next_rng,
).owner.outputs
```
---
It's wasteful, but to avoid the current grad limitations you can sample all of these once outside of the scan and then pass them in as sequences.
---
Because nothing about U or x actually depends on the state of the scan
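A rough sketch of that suggestion (the names and the fixed draw budget are assumptions):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)
n_steps = 64  # assumed fixed budget of rejection-sampling iterations

# Draw everything outside the scan...
U = srng.uniform(0, 1, size=(n_steps,))
x = srng.normal(0, 1, size=(n_steps,))

# ...and pass the draws in as sequences; the step itself is now RNG-free,
# so it stays out of the way of the gradient (placeholder body for illustration).
def step(u_t, x_t, acc):
    return acc + u_t * x_t

outputs, updates = pytensor.scan(step, sequences=[U, x], outputs_info=[pt.zeros(())])
```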
---
This was born from pymc-devs/pytensor#1424. I started to write down the reparametrization tricks I could find or figure out for most of PyTensor's basic `RandomVariable` Ops.
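The simplest instance of such a trick, shown as a NumPy sketch (an illustration, not code from this PR):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7

# Normal(mu, sigma) == mu + sigma * Normal(0, 1): the parameters move out of
# the sampling op, so autodiff can reach them through the deterministic part.
z = rng.normal(0.0, 1.0, size=1000)  # parameter-free base draws
samples = mu + sigma * z             # reparametrized Normal(mu, sigma) draws
```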