| 1 | +# Density Modeling and Analysis |
| 2 | + |
| 3 | +NGC-Learn offers some support for density modeling/estimation, which can be particularly useful for analyzing how the internal properties of neuronal models organize (e.g., how the distributed representations of a model might cluster into distinct groups/categories) or for drawing samples from the underlying generative model implied by a particular neuronal structure (e.g., sampling a predictive coding generative model). In particular, within `ngclearn.utils.density`, one can find implementations of mixture models, e.g., a Gaussian mixture model (GMM), that might be employed to carry out such tasks. In this small lesson, we will demonstrate how to set up a GMM, fit it to some synthetic latent code data, plot the distribution it learns overlaid on the data samples, and examine the kinds of patterns one may sample from the learnt GMM. |
| 4 | + |
| 5 | +## Setting Up a Gaussian Mixture Model |
| 6 | + |
| 7 | +Let's say you have a two-dimensional dataset of neural code vectors collected from another model you have simulated, and that you want to fit a GMM to these codes so that you can later sample from their underlying multi-modal distribution. In this lesson, we will artificially synthesize this kind of data from an "unobserved" trio of multivariate Gaussians (as was done in the t-SNE tutorial). |
| 8 | + |
| 9 | +The following Python code sets up a GMM density estimator for you (along with the data generator): |
| 10 | + |
| 11 | +```python |
| 12 | +from jax import numpy as jnp, random |
| 13 | +from ngclearn.utils.density.gmm import GMM ## pull out the mixture model density estimator |
| 14 | + |
| 15 | +def gen_data(dkey, n_samp_per_mode): ## data generator (or proxy stochastic data generating process) |
| 16 | +    scale = 0.3 |
| 17 | +    mu1 = jnp.asarray([[2.1, 3.2]]) * scale |
| 18 | +    cov1 = jnp.eye(mu1.shape[1]) * 0.78 * scale * 0.5 |
| 19 | +    mu2 = jnp.asarray([[-2.8, 2.0]]) * scale |
| 20 | +    cov2 = jnp.eye(mu2.shape[1]) * 0.52 * scale * 0.5 |
| 21 | +    mu3 = jnp.asarray([[1.2, -2.7]]) * scale |
| 22 | +    cov3 = jnp.eye(mu3.shape[1]) * 1.2 * scale * 0.5 |
| 23 | +    params = (mu1, cov1, mu2, cov2, mu3, cov3) |
| 24 | + |
| 25 | +    dkey, *subkeys = random.split(dkey, 7) |
| 26 | +    samp1 = random.multivariate_normal(subkeys[0], mu1, cov1, shape=(n_samp_per_mode,)) ## use a distinct key per mode |
| 27 | +    samp2 = random.multivariate_normal(subkeys[1], mu2, cov2, shape=(n_samp_per_mode,)) |
| 28 | +    samp3 = random.multivariate_normal(subkeys[2], mu3, cov3, shape=(n_samp_per_mode,)) |
| 29 | +    X = jnp.concatenate((samp1, samp2, samp3), axis=0) |
| 30 | +    y1 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[1., 0., 0.]]) |
| 31 | +    y2 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 1., 0.]]) |
| 32 | +    y3 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 0., 1.]]) |
| 33 | +    lab = jnp.concatenate((y1, y2, y3), axis=0) ## one-hot codes |
| 34 | +    return X, lab, params |
| 35 | + |
| 36 | +## set up the GMM density estimator |
| 37 | +key = random.PRNGKey(69) |
| 38 | +key, dkey, data_key = random.split(key, 3) ## separate keys for the GMM and the data generator |
| 39 | +X, y, params = gen_data(data_key, n_samp_per_mode=200) |
| 40 | + |
| 41 | +n_iter = 30 |
| 42 | +n_components = 3 |
| 43 | +model = GMM(K=n_components, max_iter=n_iter, key=dkey) |
| 44 | +model.init(X) ## initialize the GMM to dataset X |
| 45 | +``` |
| 46 | + |
| 47 | +The above will construct a GMM with three components (or latent variables of its own), configured to use a maximum of `30` iterations to fit itself to data. Note that the call to `init()` will "shape" the GMM according to the dimensionality of the data and pre-initialize its parameters (i.e., it chooses random data vectors to initialize its means). |
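
For intuition about what initialization and fitting involve, the following is a minimal, generic sketch of expectation-maximization (E-M) for a GMM in JAX. It is not NGC-Learn's implementation: the helper names `init_params` and `em_step`, the identity starting covariances, and the small diagonal jitter are all assumptions made for illustration.

```python
from jax import numpy as jnp, random
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

def init_params(key, X, K): ## hypothetical helper, not part of ngclearn
    ## pick K distinct data vectors as the initial means
    idx = random.choice(key, X.shape[0], shape=(K,), replace=False)
    mu = X[idx] ## (K, D)
    Sigma = jnp.tile(jnp.eye(X.shape[1]), (K, 1, 1)) ## (K, D, D), assumed identity start
    pi = jnp.ones((K,)) / K ## uniform mixing weights
    return pi, mu, Sigma

def em_step(X, pi, mu, Sigma): ## hypothetical helper, one E-M iteration
    K = mu.shape[0]
    ## E-step: responsibilities r[n, k] = p(component k | x_n)
    log_joint = jnp.stack(
        [multivariate_normal.logpdf(X, mu[k], Sigma[k]) for k in range(K)], axis=1
    ) + jnp.log(pi) ## (N, K)
    r = jnp.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
    ## M-step: re-estimate weights, means, and covariances
    Nk = r.sum(axis=0) ## effective number of points per component
    new_pi = Nk / X.shape[0]
    new_mu = (r.T @ X) / Nk[:, None]
    diff = X[:, None, :] - new_mu[None, :, :] ## (N, K, D)
    new_Sigma = jnp.einsum("nk,nkd,nke->kde", r, diff, diff) / Nk[:, None, None]
    new_Sigma = new_Sigma + 1e-6 * jnp.eye(X.shape[1]) ## jitter for numerical stability
    return new_pi, new_mu, new_Sigma
```

Repeating `em_step` until the means stop moving by more than some tolerance gives a stopping rule in the spirit of the `tol` argument used in this lesson, though the library's exact criterion may differ.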
| 48 | + |
| 49 | +To fit the GMM itself to your dataset `X`, you will then write the following: |
| 50 | + |
| 51 | +```python |
| 52 | +## estimate GMM parameters over dataset via E-M |
| 53 | +model.fit(X, tol=1e-3, verbose=True) ## set verbose to `False` to silence the fitting process |
| 54 | +``` |
| 55 | + |
| 56 | +which should print to the console something akin to: |
| 57 | + |
| 58 | +```console |
| 59 | +0: Mean-diff = 0.8029823303222656 |
| 60 | +1: Mean-diff = 0.1899024397134781 |
| 61 | +2: Mean-diff = 0.18127720057964325 |
| 62 | +3: Mean-diff = 0.15023663640022278 |
| 63 | +4: Mean-diff = 0.13917091488838196 |
| 64 | +5: Mean-diff = 0.10519692301750183 |
| 65 | +6: Mean-diff = 0.05732756853103638 |
| 66 | +7: Mean-diff = 0.03420640528202057 |
| 67 | +8: Mean-diff = 0.01907791942358017 |
| 68 | +9: Mean-diff = 0.009763183072209358 |
| 69 | +10: Mean-diff = 0.004887263756245375 |
| 70 | +11: Mean-diff = 0.0024237236939370632 |
| 71 | +12: Mean-diff = 0.0011952449567615986 |
| 72 | +13: Mean-diff = 0.0005875130300410092 |
| 73 | +Converged after 14 iterations. |
| 74 | +``` |
| 75 | + |
| 76 | +In the above instance, notice that our GMM converged early, stopping after `14` of its allotted `30` iterations once the reported mean difference fell below the tolerance `tol=1e-3`. We can further calculate our model's log likelihood over the dataset `X` with the following in-built function: |
| 77 | + |
| 78 | +```python |
| 79 | +# Calculate the GMM log likelihood |
| 80 | +_, logPX = model.calc_log_likelihood(X) ## 1st output is log-likelihood per data pattern |
| 81 | +print(f"log[p(X)] = {logPX} nats") |
| 82 | +``` |
| 83 | + |
| 84 | +which will print out something like the following: |
| 85 | + |
| 86 | +```console |
| 87 | +log[p(X)] = -423.30889892578125 nats |
| 88 | +``` |
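
As a rough guide to what such a quantity measures, the sketch below computes a GMM log likelihood by hand in JAX using a log-sum-exp over components. The function name `gmm_log_likelihood` and its argument layout are assumptions for illustration, and aggregating the per-pattern values by summation is also an assumption; the library routine may aggregate differently.

```python
from jax import numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

def gmm_log_likelihood(X, pi, mu, Sigma): ## hypothetical helper, not part of ngclearn
    ## log[pi_k * N(x_n | mu_k, Sigma_k)] for every pattern n and component k
    log_joint = jnp.stack(
        [jnp.log(pi[k]) + multivariate_normal.logpdf(X, mu[k], Sigma[k])
         for k in range(pi.shape[0])], axis=1
    ) ## (N, K)
    per_pattern = logsumexp(log_joint, axis=1) ## log p(x_n), marginalizing out the component
    return per_pattern, per_pattern.sum() ## (per-pattern values, total in nats)
```

Working in log space with `logsumexp` avoids the underflow that multiplying many small densities would cause.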
| 89 | + |
| 90 | +Now, to check whether our GMM actually captures the underlying multi-modal distribution of our dataset, we may visualize the final GMM with the following plotting code: |
| 91 | + |
| 92 | +```python |
| 93 | +import matplotlib.pyplot as plt |
| 94 | +x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 |
| 95 | +y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 |
| 96 | +xx, yy = jnp.meshgrid(jnp.linspace(x_min, x_max, 100), jnp.linspace(y_min, y_max, 100)) |
| 97 | +Xspace = jnp.c_[xx.ravel(), yy.ravel()] |
| 98 | +Z, _ = model.calc_log_likelihood(Xspace) # Get log likelihood (LL) |
| 99 | +Z = -Z ## flip sign of LL (to get negative LL) |
| 100 | +Z = Z.reshape(xx.shape) |
| 101 | + |
| 102 | +plt.figure(figsize=(8, 6)) |
| 103 | +plt.scatter(X[:, 0], X[:, 1], c="blue", s=10, alpha=0.7, label='Latent Codes') |
| 104 | +plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8) |
| 105 | +plt.colorbar(label='Negative Log Likelihood') |
| 106 | + |
| 107 | +plt.title('GMM Distribution Plot') |
| 108 | +plt.xlabel('Latent Dimension 1') |
| 109 | +plt.ylabel('Latent Dimension 2') |
| 110 | +plt.legend() |
| 111 | +plt.grid(True) |
| 112 | +plt.savefig("gmm_fit.jpg") #plt.show() |
| 113 | + |
| 114 | +plt.close() |
| 115 | +``` |
| 116 | + |
| 117 | +which should produce a plot similar to the one below: |
| 118 | + |
| 119 | +<img src="../../images/tutorials/neurocog/gmm_fit.jpg" width="400" /> |
| 120 | + |
| 121 | + |
| 122 | +To draw samples from our fitted/learnt GMM, we may next call its in-built synthesizing routine as follows: |
| 123 | + |
| 124 | +```python |
| 125 | +## Examine GMM samples |
| 126 | +Xs = model.sample(n_samples=200 * 3) ## draw 600 samples from fitted GMM |
| 127 | +``` |
| 128 | + |
| 129 | +and then visualize the collected batch of samples with the following plotting code: |
| 130 | + |
| 131 | +```python |
| 132 | + |
| 133 | +plt.figure(figsize=(8, 6)) |
| 134 | +plt.scatter(Xs[:, 0], Xs[:, 1], c="green", s=10, alpha=0.7, label='Sample Points') |
| 135 | +plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8) |
| 136 | +plt.colorbar(label='Negative Log-Likelihood') |
| 137 | +plt.title('GMM Samples') |
| 138 | +plt.xlabel('Latent Dimension 1') |
| 139 | +plt.ylabel('Latent Dimension 2') |
| 140 | +plt.grid(True) #plt.show() |
| 141 | +plt.savefig("gmm_samples.jpg") |
| 142 | + |
| 143 | +plt.close() |
| 144 | +``` |
| 145 | + |
| 146 | +which will produce a plot similar to the one below: |
| 147 | + |
| 148 | +<img src="../../images/tutorials/neurocog/gmm_samples.jpg" width="400" /> |
| 149 | + |
| 150 | +Notice that the green-colored data points roughly adhere to the contours of the GMM distribution and look much like the original (blue-colored) dataset we collected. In this example scenario, we see that we can successfully learn the density of our latent code dataset, facilitating some level of downstream distributional analysis and generative model sampling. |
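
Sampling from a mixture model is commonly done by ancestral sampling: first draw a component index from the mixing weights, then draw a point from that component's Gaussian. The sketch below shows this in JAX; the function `sample_gmm` and its arguments are hypothetical and are not claimed to match how the library's `sample()` routine works internally.

```python
from jax import numpy as jnp, random

def sample_gmm(key, pi, mu, Sigma, n_samples): ## hypothetical helper, not part of ngclearn
    z_key, x_key = random.split(key)
    ## step 1: choose a component for each sample according to the mixing weights
    z = random.categorical(z_key, jnp.log(pi), shape=(n_samples,)) ## (n_samples,)
    ## step 2: draw each sample from its chosen component's Gaussian
    x = random.multivariate_normal(x_key, mu[z], Sigma[z]) ## (n_samples, D)
    return x, z
```

For example, `sample_gmm(random.PRNGKey(0), pi, mu, Sigma, 600)` would return 600 points together with the component index that generated each one, given mixing weights `pi` of shape `(K,)`, means `mu` of shape `(K, D)`, and covariances `Sigma` of shape `(K, D, D)`.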