
Add transforming adaptation with normalizing flows #154

Closed · wants to merge 66 commits

Conversation


@aseyboldt (Member) commented Oct 17, 2024

This PR adds an experimental new algorithm that uses a normalizing flow instead of a mass matrix.
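A mass matrix is the special case where the transformation to an approximately standard-normal space is affine; the flow generalizes this to a learned nonlinear map. A minimal numpy sketch of the idea (illustrative only, not the implementation in this PR):

```python
import numpy as np

# Diagonal mass-matrix adaptation corresponds to an elementwise affine map
#     z = (theta - mu) / sigma,
# chosen so that z is roughly standard normal. A normalizing flow replaces
# this with a learned invertible nonlinear map z = f(theta).
def affine_forward(theta, mu, sigma):
    return (theta - mu) / sigma

def affine_inverse(z, mu, sigma):
    return mu + sigma * z

# NUTS then samples in z-space, where the density transforms as
#     log p_z(z) = log p_theta(f_inv(z)) + log|det J_{f_inv}(z)|;
# for the affine map the log-det-Jacobian is simply np.sum(np.log(sigma)).
```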

Set up using pixi:

git clone https://github.com/pymc-devs/nutpie
cd nutpie
git fetch origin pull/154/head:transform
git switch transform

pixi run develop
pixi shell

This gives a shell with an appropriate Python setup.
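To sanity-check the environment (an illustrative snippet; the version attribute is assumed, a successful import is the real check):

```python
import nutpie
import jax

print(nutpie.__version__)  # should import without errors
print(jax.devices())       # lists available devices (CPU, or GPU if CUDA is set up)
```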

Usage with PyMC:

import pymc as pm
import nutpie
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)

with pm.Model() as model:
    log_sd = pm.Normal("log_sd")
    pm.Normal("y", sigma=np.exp(log_sd))

compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status updates from the optimizer
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)
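The returned trace is an ArviZ InferenceData object, so the usual diagnostics apply:

```python
import arviz as az

# Check convergence of the funnel-like model above
print(az.summary(trace, var_names=["log_sd"]))
```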

Usage with Stan:

import nutpie
import jax
import os

os.environ["TBB_CXX_TYPE"] = "clang"
jax.config.update("jax_enable_x64", True)

code = """
parameters {
    real log_sigma;
    real x;
}
model {
    log_sigma ~ normal(0, 1);
    x ~ normal(0, exp(log_sigma));
}
"""


compiled = nutpie.compile_stan_model(code=code)

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status updates from the optimizer
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)

The optimization can be quite expensive computationally (but luckily it doesn't need any extra gradient evaluations). A GPU is very helpful here; JAX should pick up a CUDA device automatically.
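To confirm which device JAX actually picked up (standard JAX calls, nothing nutpie-specific):

```python
import jax

print(jax.devices())          # e.g. [CudaDevice(id=0)] with a working CUDA setup
print(jax.default_backend())  # "gpu" or "cpu"
```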

@aseyboldt added the help wanted and normalizing-flows labels on Oct 17, 2024
@aseyboldt marked this pull request as ready for review on Feb 14, 2025
@aseyboldt mentioned this pull request on Mar 4, 2025
Review thread (on the added documentation):

> …original parameter space to a space where the posterior is closer to a standard
> normal distribution. The flow is trained during warmup.
>
> For more information about the algorithm, see the paper todo
Member: Can this link be added?


Review thread (on the added documentation):

> Currently, a lot of time is spent on compiling various parts of the normalizing
> flow, and for small models this can take a large amount of the total time.
> Hopefully, we will be able to reduce this overhead in the future.
Member: Can a recommendation be made as to the size of model that warrants consideration of using NF?

Member (Author): This doesn't really depend on the size of the model. If it is easy to sample, using a NF isn't worth it, but you can run it on small or large models.
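One practical way to check is to run the same compiled model with and without the flow and compare sampler diagnostics (a sketch reusing `compiled` and the sampling arguments from the description):

```python
import arviz as az

# Default mass-matrix adaptation...
trace_default = nutpie.sample(compiled, chains=2, tune=1000, draws=1000, seed=123)
# ...versus normalizing-flow adaptation.
trace_flow = nutpie.sample(
    compiled, transform_adapt=True, chains=2, tune=1000, draws=1000, seed=123
)

# If the default run already shows healthy effective sample sizes and no
# divergences, the flow is probably not worth its training cost.
print(az.ess(trace_default).min())
print(az.ess(trace_flow).min())
```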

Review thread (on the example code):

```python
compiled = nutpie.compile_pymc_model(
    model, backend="jax", gradient_backend="jax"
)
```
Member: Perhaps show how to use it via pm.sample as well?

Member (Author): You can't pass arguments about the normalizing flow through pm.sample right now...

Member (Author): I added a small section in the pymc-usage notes about the pm.sample API.
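For reference, the pm.sample route looks roughly like this (a sketch; it assumes pm.sample forwards nuts_sampler_kwargs to nutpie.sample, and, as noted above, the with_transform_adapt options themselves can't be set this way yet):

```python
import pymc as pm

with model:
    idata = pm.sample(
        nuts_sampler="nutpie",
        # assumption: extra keywords are forwarded to nutpie.sample
        nuts_sampler_kwargs={"transform_adapt": True},
        tune=1000,
        draws=1000,
    )
```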

@aseyboldt (Author) commented:

Closing this in favor of #180

@aseyboldt closed this on Mar 5, 2025