Skip to content

Conversation

@lucianopaz
Copy link
Member

This was born from pymc-devs/pytensor#1424. I started to write down the reparametrization tricks I could find or figure out for most of pytensor's basic RandomVariable Ops

@codecov
Copy link

codecov bot commented Jan 15, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 90.80%. Comparing base (c68c56e) to head (da92d5a).
⚠️ Report is 10 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #8056      +/-   ##
==========================================
- Coverage   91.42%   90.80%   -0.62%     
==========================================
  Files         117      121       +4     
  Lines       19154    19443     +289     
==========================================
+ Hits        17512    17656     +144     
- Misses       1642     1787     +145     

see 16 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@node_rewriter([GammaRV, InvGammaRV])
def gamma_reparametrization(fgraph, node):
rng, size, shape, scale = node.inputs
return scale * node.op(shape, 1.0, rng=rng, size=size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't pull the shape parameter out of the RV?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual reparameterization trick in this case requires some effort (Marsaglia-Tsang method), jax implementation here:

https://github.com/jax-ml/jax/blob/c6568036b83b39d556c68d647a99196e63062612/jax/_src/random.py#L1298

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll write the implementation for that then. Thanks @jessegrabowski

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski, I wrote the generator code. I still need to test that the samples it returns are distributed like the reference Gamma, but maybe you could have a look and comment on the scans

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've fixed the gamma reparametrization and it now matches the generation we get from Gamma.dist. I can't take gradients through the reparametrized version because of pymc-devs/pytensor#555

@node_rewriter([BernoulliRV])
def bernoulli_reparametrization(fgraph, node):
rng, size, p = node.inputs
return switch(
Copy link
Member

@ricardoV94 ricardoV94 Jan 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting here but in general you may be making non equivalent graphs, by using the size argument. Remember it may be None, meaning it will be implied by the parameters shape. In that case you may be rewriting something like normal([0, 1, 2], size=None) as [0, 1, 2] + normal(size=None)

There's a non default PyTensor rewrite that makes size explicit that you may want to use prior to these rewrites. Then make sure these rewrites fail if size is None isinstance(size.type, NoneTypeT). If you do have size it will be all you need

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I'll have to ask you how to add that rewrite into the database and make it go before any reparametrization rewrites

)
from pytensor.tensor.slinalg import cholesky

reparametrization_trick_db = RewriteDatabaseQuery(include=["random_reparametrization_trick"])
Copy link
Member

@ricardoV94 ricardoV94 Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works but it's a bit expensive every time you extend it. You should start with an actual database like SequenceDB, then you can register rewrite phases in it. pymc/logprob/rewriting.py may be a good example.

I'm even surprise that the RewriteDatabaseQuery has a register argument. Feels like bad design

@register_random_reparametrization
@node_rewriter([DirichletRV])
def dirichlet_reparametrization(fgraph, node):
raise NotImplementedError("DirichletRV is not reparametrizable")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doesn't need to block this PR but I know tfp has something for the dirichlet that is one time differentiable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It comes from the Gamma, so we need to sort that out first

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They might have changed and they just define dirichlet with the normalized batch of gamma variates, which would be differentiable if the gamma also are?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I'll have a look at that

Comment on lines +163 to +172
next_rng, U = UniformRV()(
zeros_like(alpha),
ones_like(alpha),
rng=rng,
).owner.outputs
next_rng, x = NormalRV()(
zeros_like(c),
ones_like(c),
rng=next_rng,
).owner.outputs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's wasteful, but to avoid the current grad limitations you can sample all of these once outside of the scan then pass them in as sequences

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because nothing about U or x actually depends on the state of the scan

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants