
Add example autoregressive flow on JAX #1176


Open · wants to merge 5 commits into main

Conversation

mattwescott (Contributor)

Motivated by #1169


google-cla bot added the cla: yes label (declares that the user has signed the CLA) on Nov 23, 2020
mattwescott (Contributor, Author)

This approach is more limited than I'd realized. From @tomhennigan in google-deepmind/dm-haiku#84:

Haiku (and other OO JAX libraries) need you to use wrapped versions of JAX transforms if you are using JAX transforms that close over the usage of Modules without explicitly passing state. In Haiku this means inside your hk.transformed function you need to use hk.jit etc if you are wrapping code that makes use of a module.

For modules that only make use of parameters/rng in Haiku (e.g. if you are using hk.transform and not hk.transform_with_state) there is a "trick" you can use to handle these inner JAX transforms, which is making use of the jax.core.eval_context() context manager inside your modules. This will cause values not to be eagerly "staged out" to the inner transform (unless they have been explicitly passed in to it) and will stop parameters becoming (deeper) tracers.

Here is an updated copy of your notebook with the trick applied: https://colab.research.google.com/gist/tomhennigan/94ce417598baeb0b8db703981ea83992/autoregressive_flow_on_jax-ipynb-txt.ipynb. I've wrapped this up as a decorator (run_in_eval_context) to make the delta from your original code small (defining and using the decorator is the only meaningful difference).
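
For readers who don't want to open the notebook, the decorator mentioned above can be tiny. Here is a minimal sketch (not copied from the notebook), assuming the jax.core.eval_context() context manager that JAX exposed at the time:

```python
import functools

import jax

def run_in_eval_context(f):
  """Runs f under jax.core.eval_context().

  Values closed over by an inner JAX transform (jit, vmap, ...) are then
  not eagerly staged out into that transform, so Haiku parameters do not
  become (deeper) tracers.
  """
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    with jax.core.eval_context():
      return f(*args, **kwargs)
  return wrapper
```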

In this example, the loop in MaskedAutoregressiveFlow is now unrolled to avoid any inner transforms. There may be other bijectors with inner transformations that aren't so amenable. The eval_context trick might support more bijectors, but it seems opaque and precludes using state. Manually threading the parameters seems tedious. Supporting Haiku-wrapped transforms within tfp doesn't sit right. There must be a better way. What am I missing? Maybe Oryx supports a cleaner implementation?
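
For context, the wrapped-transform requirement from the first paragraph of the quote looks roughly like the sketch below. This is an illustration only: the MLP module and shapes are hypothetical, and hk.jit is the Haiku-wrapped version of jax.jit that Haiku shipped at the time.

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([32, 1])  # hypothetical module, for illustration
  # The lambda closes over `mlp` without explicitly passing its
  # parameters, so inside hk.transform we must use hk.jit, not jax.jit.
  jitted = hk.jit(lambda v: mlp(v))
  return jitted(x)

forward_t = hk.transform(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 4])
params = forward_t.init(rng, x)
out = forward_t.apply(params, rng, x)
```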

MilesCranmer

Also very interested in this! Thanks very much for the implementation @mattwescott.
