# Usage with PyMC models

This document shows how to use `nutpie` with PyMC models. We will use the
`pymc` package to define a simple model and sample from it using `nutpie`.

## Installation

The recommended way to install `pymc` is through the `conda` ecosystem. A good
package manager for conda packages is `pixi`. See the
[pixi documentation](https://pixi.sh) for instructions on how to install it.

We create a new project for this example:

```bash
pixi init pymc-example
```

This will create a new directory `pymc-example` with a `pixi.toml` file that
you can edit to add meta information.

We then add the `pymc`, `nutpie`, and `arviz` packages to the project:

```bash
cd pymc-example
pixi add pymc nutpie arviz
```

You can use Visual Studio Code (VSCode) or JupyterLab to write and run your
code. Both are excellent tools for working with Python and data science
projects.

### Using VSCode

1. Open VSCode.
2. Open the `pymc-example` directory created earlier.
3. Create a new file named `model.ipynb`.
4. Select the pixi kernel to run the code.

### Using JupyterLab

1. Add JupyterLab to the project by running `pixi add jupyterlab`.
2. Open JupyterLab by running `pixi run jupyter lab` in your terminal.
3. Create a new Python notebook.
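
Either way, it is a good idea to run a small cell first to confirm that the
notebook kernel sees the packages we just installed. This is only a sanity
check and assumes the packages expose the usual `__version__` attribute:

```python
# Quick sanity check: the imports should succeed and print version numbers
import pymc
import nutpie
import arviz

print("pymc:", pymc.__version__)
print("nutpie:", nutpie.__version__)
print("arviz:", arviz.__version__)
```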

## Defining and Sampling a Simple Model

We will define a simple Bayesian model using `pymc` and sample from it using
`nutpie`.

### Model Definition

In your `model.ipynb` file or Jupyter notebook, add the following code:

```python
import pymc as pm
import nutpie

coords = {"observation": range(3)}

with pm.Model(coords=coords) as model:
    # Prior distributions for the intercept and slope
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)

    # Predictor values
    x = [1, 2, 3]

    # Likelihood (sampling distribution) of the observations
    mu = intercept + slope * x
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=[1, 2, 3], dims="observation")
```

### Sampling

We can now compile the model using the numba backend and sample from it:

```python
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```
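
The defaults are usually fine, but you can also control the number of chains,
tuning steps, and posterior draws. The keyword names below are a sketch of the
current `nutpie.sample` signature; double-check them with `help(nutpie.sample)`
if your version differs:

```python
# Run 4 chains with 1000 tuning steps and 1000 posterior draws each,
# using a fixed seed for reproducibility
trace = nutpie.sample(compiled, chains=4, tune=1000, draws=1000, seed=123)
```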

While sampling, nutpie shows a progress bar for each chain. It also includes
information about how each chain is doing:

- The current number of draws
- The step size of the integrator (very small step sizes are typically a bad
  sign)
- The number of divergences (if there are divergences, nutpie is probably not
  sampling the posterior correctly)
- The number of gradient evaluations nutpie uses for each draw. Large numbers
  (100 to 1000) are a sign that the parameterization of the model is not ideal
  and the sampler is very inefficient.

After sampling, nutpie returns an `arviz.InferenceData` object that you can use
to analyze the trace.
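
These per-draw diagnostics are also stored in the trace. As a rough sketch (the
exact `sample_stats` variable names, such as `diverging` and `step_size`, can
differ between nutpie versions), you can inspect them programmatically:

```python
# Total number of divergent transitions across all chains
print(int(trace.sample_stats["diverging"].sum()))

# Average step size of each chain after tuning
print(trace.sample_stats["step_size"].mean(dim="draw"))
```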

It is also a good idea to check the effective sample size:

```python
import arviz as az
az.ess(trace)
```

and have a look at a trace plot:

```python
az.plot_trace(trace)
```

### Choosing the backend

So far, we have been using the numba backend. It is the default backend for
`nutpie` when sampling from PyMC models. It tends to have relatively long
compilation times, but samples small models very efficiently. For larger models
the `jax` backend sometimes outperforms `numba`.

First, we need to install the `jax` package:

```bash
pixi add jax
```

We can select the backend by passing the `backend` argument to
`compile_pymc_model`:

```python
compiled_jax = nutpie.compile_pymc_model(model, backend="jax")
trace = nutpie.sample(compiled_jax)
```

If you have an NVIDIA GPU, you can also use the `jax` backend on the GPU. We
will have to install the `jaxlib` package with the `cuda` option:

```bash
pixi add jaxlib --build 'cuda12'
```

Restart the kernel and check that the GPU is available:

```python
import jax

# Should list the cuda device
jax.devices()
```

Sampling again (after re-running the model definition and compilation in the
new kernel) should now use the GPU, which you can observe by checking the GPU
usage with `nvidia-smi` or `nvtop`.

### Changing the dataset without recompilation

If you want to use the same model with different datasets, you can modify the
data after compilation. Since jax does not handle changes in array shapes well,
this is only recommended with the numba backend.

First, we define the model, but put our dataset in a `pm.Data` container:

```python
with pm.Model() as model:
    x = pm.Data("x", [1, 2, 3])
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)
    mu = intercept + slope * x
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=[1, 2, 3])
```

We can now compile the model and sample from it:

```python
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```

After compilation, we can change the dataset:

```python
compiled2 = compiled.with_data(x=[4, 5, 6])
trace2 = nutpie.sample(compiled2)
```