Skip to content

Commit 157102e

Browse files
author
Alexander Ororbia
committed
wrote gmm density estimator tutorial
1 parent ebebf74 commit 157102e

File tree

6 files changed

+161
-4
lines changed

6 files changed

+161
-4
lines changed
74.8 KB
Loading
71.3 KB
Loading
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Density Modeling and Analysis
2+
3+
NGC-Learn offers some support for density modeling/estimation, which can be particularly useful in analyzing how internal properties of neuronal models organize (e.g., how the distributed representations of a model might cluster into distinct groups/categories) or to draw samples from the underlying generative model implied by a particular neuronal structure (e.g., sampling a predictive coding generative model). Particularly, within `ngclearn.utils.density`, one can find implementations of mixture models, e.g., a Gaussian mixture model (GMM), that might be implied to carry out such tasks. In this small lesson, we will demonstrate how to set up a GMM, fit it to some synthetic latent code data, and plot out the distribution it learns overlaid over the data samples as well as examine the kinds of patterns one may sample from the learnt GMM.
4+
5+
## Setting Up a Gaussian Mixture Model
6+
7+
Let's say you have a two-dimensional dataset of neural code vectors collected from another model you have simulated -- here, we will artificially synthesize this kind of data in this lesson from an "unobserved" trio of multivariate Gaussians (as was done in the t-SNE tutorial) -- and that, furthermore, you wanted to fit a GMM to these codes to later on sample from their underlying multi-modal distribution.
8+
9+
The following Python code will employ a GMM density estimator for you (including setting up the data generator):
10+
11+
```python
12+
from jax import numpy as jnp, random
13+
from ngclearn.utils.density.gmm import GMM ## pull out the mixture model density estimator
14+
15+
def gen_data(dkey, n_samp_per_mode): ## data generator (or proxy stochastic data generating process)
16+
scale = 0.3
17+
mu1 = jnp.asarray([[2.1, 3.2]]) * scale
18+
cov1 = jnp.eye(mu1.shape[1]) * 0.78 * scale * 0.5
19+
mu2 = jnp.asarray([[-2.8, 2.0]]) * scale
20+
cov2 = jnp.eye(mu2.shape[1]) * 0.52 * scale * 0.5
21+
mu3 = jnp.asarray([[1.2, -2.7]]) * scale
22+
cov3 = jnp.eye(mu3.shape[1]) * 1.2 * scale * 0.5
23+
params = (mu1,cov1 ,mu2,cov2,mu3,cov3)
24+
25+
dkey, *subkeys = random.split(dkey, 7)
26+
samp1 = random.multivariate_normal(subkeys[0], mu1, cov1, shape=(n_samp_per_mode,))
27+
samp2 = random.multivariate_normal(subkeys[0], mu2, cov2, shape=(n_samp_per_mode,))
28+
samp3 = random.multivariate_normal(subkeys[0], mu3, cov3, shape=(n_samp_per_mode,))
29+
X = jnp.concatenate((samp1, samp2, samp3), axis=0)
30+
y1 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[1., 0., 0.]])
31+
y2 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 1., 0.]])
32+
y3 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 0., 1.]])
33+
lab = jnp.concatenate((y1, y2, y3), axis=0) ## one-hot codes
34+
return X, lab, params
35+
36+
## set up the GMM density estimator
37+
key = random.PRNGKey(69)
38+
dkey, _ = random.split(key, 2)
39+
X, y, params = gen_data(key, n_samp_per_mode=200) #400)
40+
41+
n_iter = 30
42+
n_components = 3
43+
model = GMM(K=n_components, max_iter=n_iter, key=dkey)
44+
model.init(X) ## initailize the GMM to dataset X
45+
```
46+
47+
The above will construct a GMM with three components (or latent variables of its own) and be configured to use a maximum of `30` iterations to fit itself to data. Note that the call to `init()` will "shape" the GMM according to the dimensionality of the data and pre-initialize its parameters (i.e., choosing random data vectors to initialize its means).
48+
49+
To fit the GMM itself to your dataset `X`, you will then write the following:
50+
51+
```python
52+
## estimate GMM parameters over dataset via E-M
53+
model.fit(X, tol=1e-3, verbose=True) ## set verbose to `False` to silence the fitting process
54+
```
55+
56+
which should print to I/O something akin to:
57+
58+
```console
59+
0: Mean-diff = 0.8029823303222656
60+
1: Mean-diff = 0.1899024397134781
61+
2: Mean-diff = 0.18127720057964325
62+
3: Mean-diff = 0.15023663640022278
63+
4: Mean-diff = 0.13917091488838196
64+
5: Mean-diff = 0.10519692301750183
65+
6: Mean-diff = 0.05732756853103638
66+
7: Mean-diff = 0.03420640528202057
67+
8: Mean-diff = 0.01907791942358017
68+
9: Mean-diff = 0.009763183072209358
69+
10: Mean-diff = 0.004887263756245375
70+
11: Mean-diff = 0.0024237236939370632
71+
12: Mean-diff = 0.0011952449567615986
72+
13: Mean-diff = 0.0005875130300410092
73+
Converged after 14 iterations.
74+
```
75+
76+
In the above instance, notice that our GMM converged early, reaching a good log likelihood in `14` iterations. We can further calculate our model's log likelihood over the dataset `X` with the following in-built function:
77+
78+
```python
79+
# Calculate the GMM log likelihood
80+
_, logPX = model.calc_log_likelihood(X) ## 1st output is log-lieklihood per data pattern
81+
print(f"log[p(X)] = {logPX} nats")
82+
```
83+
84+
which will print out the following:
85+
86+
```console
87+
log[p(X)] = -423.30889892578125 nats
88+
```
89+
90+
Now, to visualize if our GMM actually capture the underlying multi-modal distribution of our dataset, we may visualize the final GMM with the following plotting code:
91+
92+
```python
93+
import matplotlib.pyplot as plt
94+
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
95+
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
96+
xx, yy = jnp.meshgrid(jnp.linspace(x_min, x_max, 100), jnp.linspace(y_min, y_max, 100))
97+
Xspace = jnp.c_[xx.ravel(), yy.ravel()]
98+
Z, _ = model.calc_log_likelihood(Xspace) # Get log likelihood (LL)
99+
Z = -Z ## flip sign of LL (to get negative LL)
100+
Z = Z.reshape(xx.shape)
101+
102+
plt.figure(figsize=(8, 6))
103+
plt.scatter(X[:, 0], X[:, 1], c="blue", s=10, alpha=0.7, label='Latent Codes')
104+
plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8)
105+
plt.colorbar(label='Negative Log Likelihood')
106+
107+
plt.title('GMM Distribution Plot')
108+
plt.xlabel('Latent Dimension 1')
109+
plt.ylabel('Latent Dimension 2')
110+
plt.legend()
111+
plt.grid(True)
112+
plt.savefig("gmm_fit.jpg") #plt.show()
113+
114+
plt.close()
115+
```
116+
117+
which should produce a plot similar to the one below:
118+
119+
<img src="../../images/tutorials/neurocog/gmm_fit.jpg" width="400" />
120+
121+
122+
To draw samples from our fitted/learnt GMM, we may next call its in-built synthesizing routine as follows:
123+
124+
```python
125+
## Examine GMM samples
126+
Xs = model.sample(n_samples=200 * 3) ## draw 600 samples from fitted GMM
127+
```
128+
129+
and then visualize the collected batch of samples with the following plotting code:
130+
131+
```python
132+
133+
plt.figure(figsize=(8, 6))
134+
plt.scatter(Xs[:, 0], Xs[:, 1], c="green", s=10, alpha=0.7, label='Sample Points')
135+
plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8)
136+
plt.colorbar(label='Negative Log-Likelihood')
137+
plt.title('GMM Samples')
138+
plt.xlabel('Latent Dimension 1')
139+
plt.ylabel('Latent Dimension 2')
140+
plt.grid(True) #plt.show()
141+
plt.savefig("gmm_samples.jpg")
142+
143+
plt.close()
144+
```
145+
146+
which will produce a plot similar to the one below:
147+
148+
<img src="../../images/tutorials/neurocog/gmm_samples.jpg" width="400" />
149+
150+
Notice that the green-colored data points roughly adhere to the contours of the GMM distribution and look much like the original (blue-colored) dataset we collected. In this example scenario, we see that we can successfully learn the density of our latent code dataset, facilitating some level of downstream distributional analysis and generative model sampling.

docs/tutorials/neurocog/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ work towards more advanced concepts.
7373
plotting
7474
metrics
7575
integration
76+
density_modeling

ngclearn/utils/density/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
## point to supported density estimator models
2+
from .gmm import GMM

ngclearn/utils/density/gmm.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _sample_component(dkey, n_samples, mu, Sigma, assume_diag_cov=False): ## sam
101101

102102
class GMM: ## Gaussian mixture model (mixture-of-Gaussians)
103103
"""
104-
Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians, MoG.
104+
Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians (MoG).
105105
Adaptation of parameters is conducted via the Expectation-Maximization (EM)
106106
learning algorithm and leverages full covariance matrices in the component
107107
multivariate Gaussians.
@@ -197,7 +197,7 @@ def _M_step(self, X, weights): ## Maximization (M) step, co-routine
197197
self.Sigma[j] = sigma_j ## store new covariance(j) parameter
198198
return means, r
199199

200-
def fit(self, X, tol=1e-3):
200+
def fit(self, X, tol=1e-3, verbose=False):
201201
"""
202202
Run full fitting process of this GMM.
203203
@@ -206,13 +206,18 @@ def fit(self, X, tol=1e-3):
206206
207207
tol: the tolerance value for detecting convergence (via difference-of-means); will engage in early-stopping
208208
if tol >= 0. (Default: 1e-3)
209+
210+
verbose: if True, this function will print out per-iteration measurements to I/O
209211
"""
210212
means_prev = jnp.concat(self.mu, axis=0)
211213
for i in range(self.max_iter):
212214
self.update(X) ## carry out one E-step followed by an M-step
213215
means = jnp.concat(self.mu, axis=0)
216+
dom = jnp.linalg.norm(means - means_prev) ## norm of difference-of-means
217+
if verbose:
218+
print(f"{i}: Mean-diff = {dom}")
214219
#print(jnp.linalg.norm(means - means_prev))
215-
if tol >= 0. and jnp.linalg.norm(means - means_prev) < tol:
220+
if tol >= 0. and dom < tol:
216221
print(f"Converged after {i + 1} iterations.")
217222
break
218223
means_prev = means
@@ -255,7 +260,6 @@ def sample(self, n_samples, mode_j=-1):
255260
Xs = []
256261
for j in range(self.K):
257262
freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
258-
print(freq_j)
259263
## draw unit Gaussian noise
260264
self.key, *skey = random.split(self.key, 3)
261265
x_s = _sample_component(

0 commit comments

Comments
 (0)