# Density Modeling and Analysis

NGC-Learn offers some support for density modeling/estimation, which can be particularly useful for analyzing how the internal cell populations of neuronal models self-organize (e.g., how the distributed representations of a model might cluster into distinct groups/categories) or for drawing samples from the underlying generative model implied by a particular neuronal structure (e.g., sampling a trained predictive coding generative model).

In particular, within `ngclearn.utils.density`, one can find implementations of mixture models -- such as a mixture-of-Bernoulli or a mixture-of-Gaussians -- which might be employed to carry out such tasks. In this small lesson, we will demonstrate how to set up a Gaussian mixture model (GMM), fit it to some synthetic latent code data, plot the distribution it learns overlaid on the data samples, and examine the kinds of patterns one may sample from the learnt GMM.
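To ground what a mixture-of-Gaussians represents: the density it assigns to a point is a weighted sum of Gaussian component densities. Below is a small, self-contained sketch of this computation; the mixing weights, means, and diagonal scales are arbitrary illustrative values, not NGC-Learn code:

```python
from jax import numpy as jnp
from jax.scipy.special import logsumexp

def gmm_log_density(x, pis, mus, sigmas):
    ## log p(x) = logsumexp_k [ log pi_k + log N(x; mu_k, diag(sigma_k^2)) ]
    d = x.shape[-1]
    log_norm = -0.5 * d * jnp.log(2.0 * jnp.pi) - jnp.sum(jnp.log(sigmas), axis=1)
    quad = -0.5 * jnp.sum(((x[None, :] - mus) / sigmas) ** 2, axis=1)
    return logsumexp(jnp.log(pis) + log_norm + quad)

## two illustrative 2D components with equal mixing weights
pis = jnp.asarray([0.5, 0.5])
mus = jnp.asarray([[0.0, 0.0], [3.0, 3.0]])
sigmas = jnp.ones((2, 2))
val = float(gmm_log_density(jnp.asarray([0.0, 0.0]), pis, mus, sigmas))
print(val)  ## log-density at the first component's mean
```

Fitting a GMM amounts to choosing these weights, means, and (co)variances so as to maximize this log-density summed over a dataset.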
## Setting Up a Gaussian Mixture Model

Let's say you have a two-dimensional dataset of neural code vectors collected from another model you have simulated -- here, we will artificially synthesize this kind of data from an "unobserved" trio of multivariate Gaussians (as was done in the t-SNE tutorial) and pretend that it is a set of collected vector measurements. Furthermore, after considering that your data might follow a multi-modal distribution (and reasonably assuming that multivariate Gaussians can capture most of its inherent structure/shape), you decide to fit a GMM to these codes so that you can later sample from their underlying multi-modal distribution.

The following Python code will set up an in-built NGC-Learn GMM density estimator for you (including the data generator):
```python
from jax import numpy as jnp, random
from ngclearn.utils.density import GMM ## assumed import path, per the module named above

def gen_data(dkey, n_samp_per_mode): ## data generator (or proxy stochastic data source)
    ... ## (body elided in this excerpt; it samples from a trio of multivariate Gaussians)

key = random.PRNGKey(1234) ## seed value is arbitrary
dkey, *skey = random.split(key, 3)
X, y, params = gen_data(key, n_samp_per_mode=200) ## X is your "vector dataset"

n_iter = 100 ## maximum number of iterations to fit GMM to data
n_components = 3 ## number of mixture components w/in GMM
model = GMM(K=n_components, max_iter=n_iter, key=skey[0])
model.init(X) ## initialize the GMM to dataset X
```

The above will construct a GMM with three components (or latent variables of its own), configured to use a maximum of `100` iterations to fit itself to the data. Note that the call to `init()` will "shape" the GMM according to the dimensionality of the data and pre-initialize its parameters (i.e., choosing random data vectors to initialize its means).
To fit the GMM itself to your dataset `X`, you will then write the following:

```python
model.fit(X, tol=1e-3, verbose=True) ## set verbose to `False` to silence the fitting process
```
which should print to I/O something akin to:
```console
0: Mean-diff = 1.4143142700195312
1: Mean-diff = 0.15272194147109985
2: Mean-diff = 0.1888418346643448
3: Mean-diff = 0.18062230944633484
4: Mean-diff = 0.15196363627910614
5: Mean-diff = 0.1135818138718605
6: Mean-diff = 0.06951556354761124
7: Mean-diff = 0.03664496913552284
8: Mean-diff = 0.026161763817071915
9: Mean-diff = 0.022674376145005226
10: Mean-diff = 0.021674498915672302
11: Mean-diff = 0.02205687016248703
12: Mean-diff = 0.023379826918244362
13: Mean-diff = 0.02553001046180725
14: Mean-diff = 0.028586825355887413
...
<shortened for brevity>
...
32: Mean-diff = 0.06849467754364014
33: Mean-diff = 0.06256962567567825
34: Mean-diff = 0.05789890140295029
35: Mean-diff = 0.05557262524962425
36: Mean-diff = 0.05545869469642639
37: Mean-diff = 0.056351397186517715
38: Mean-diff = 0.057266443967819214
39: Mean-diff = 0.05742649361491203
40: Mean-diff = 0.05546746402978897
41: Mean-diff = 0.04826011508703232
42: Mean-diff = 0.03320707008242607
43: Mean-diff = 0.016994504258036613
44: Mean-diff = 0.007737572770565748
45: Mean-diff = 0.0035514419432729483
46: Mean-diff = 0.0016557337949052453
47: Mean-diff = 0.0007792692049406469
Converged after 48 iterations.
```
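The `Mean-diff` values in the printout track how far the mixture means move between successive update steps; fitting halts once this shift falls below the `tol` threshold handed to `.fit()`. A minimal sketch of such a stopping check (the function name and exact distance measure here are illustrative assumptions, not NGC-Learn internals):

```python
from jax import numpy as jnp

def mean_shift(old_mus, new_mus):
    ## average Euclidean movement of the K component means after one update
    return float(jnp.mean(jnp.linalg.norm(new_mus - old_mus, axis=1)))

old_mus = jnp.asarray([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
new_mus = jnp.asarray([[0.0, 0.1], [1.0, 1.0], [2.0, 2.0]])
shift = mean_shift(old_mus, new_mus)
print(shift > 1e-3)  ## True -> means still moving, so fitting would continue
```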

In the above instance, notice that our GMM converged before exhausting its iteration limit, reaching a good log likelihood after `48` iterations. We can further calculate our final model's log likelihood over the dataset `X` with the following in-built function:
```python
# Calculate the GMM log likelihood
_, logPX = model.calc_log_likelihood(X) ## 1st output is the log-likelihood per data pattern
print(f"log[p(X)] = {logPX} nats")
```

which will print out the following:

```console
log[p(X)] = -423.30889892578125 nats
```
(If you add a log-likelihood measurement before you call `.fit()`, you will see that your original log-likelihood is around `-1046.91 nats`.)
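As an aside, these likelihoods are reported in nats (natural logarithm); dividing by ln 2 converts them to bits, should you prefer that unit. For instance, using the value printed above:

```python
import math

logPX_nats = -423.30889892578125 ## the fitted model's likelihood from above
logPX_bits = logPX_nats / math.log(2.0)
print(f"log2[p(X)] = {logPX_bits:.2f} bits")
```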
Now, to check whether our GMM actually captured the underlying multi-modal distribution of our dataset, we may visualize the final GMM with the following plotting code: