Skip to content

Commit b79c450

Browse files
author
Alexander Ororbia
committed
removed old weight_distribution.py, other cleanup/revisions throughout
1 parent 6e6561e commit b79c450

File tree

23 files changed

+432
-702
lines changed

23 files changed

+432
-702
lines changed
-4.01 KB
Loading
-4.25 KB
Loading

docs/source/ngclearn.utils.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,6 @@ ngclearn.utils.surrogate\_fx module
9090
:undoc-members:
9191
:show-inheritance:
9292

93-
ngclearn.utils.weight\_distribution module
94-
------------------------------------------
95-
96-
.. automodule:: ngclearn.utils.weight_distribution
97-
:members:
98-
:undoc-members:
99-
:show-inheritance:
100-
10193
Module contents
10294
---------------
10395

docs/tutorials/neurocog/density_modeling.md

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Density Modeling and Analysis
22

33
NGC-Learn offers some support for density modeling/estimation, which can be particularly useful in analyzing how internal properties of neuronal models' self-organized cell populations (e.g., how the distributed representations of a model might cluster into distinct groups/categories) or to draw samples from the underlying generative model implied by a particular neuronal structure (e.g., sampling a trained predictive coding generative model).
4-
Particularly, within `ngclearn.utils.density`, one can find implementations of mixture models -- such as a mixture-of-Bernoulli or a mixture-of-Gaussians -- which might be employed to carry out such tasks. In this small lesson, we will demonstrate how to set up a Gaussian mixture model (GMM), fit it to some synthetic latent code data, and plot out the distribution it learns overlaid over the data samples as well as examine the kinds of patterns one may sample from the learnt GMM.
4+
Particularly, within `ngclearn.utils.density`, one can find implementations of mixture models -- such as mixtures of Bernoullis, Gaussians, and exponentials -- which might be employed to carry out such tasks. In this small lesson, we will demonstrate how to set up a Gaussian mixture model (GMM), fit it to some synthetic latent code data, and plot out the distribution it learns overlaid over the data samples as well as examine the kinds of patterns one may sample from the learnt GMM.
55

6-
## Setting Up a Gaussian Mixture Model
6+
## Setting Up a Gaussian Mixture Model (GMM)
77

88
Let's say you have a two-dimensional dataset of neural code vectors collected from another model you have simulated -- here, we will artificially synthesize this kind of data in this lesson from an "unobserved" trio of multivariate Gaussians (as was done in the t-SNE tutorial) and pretend that this is a set of collected vector measurements. Furthermore, you decide that, after consideration that your data might follow a multi-modal distribution (and reasonably asssuming that multivariate Gaussians might capture most of the inherent structure/shape), you want to fit a GMM to these codes to later on sample from their underlying multi-modal distribution.
99

@@ -63,44 +63,30 @@ model.fit(X, tol=1e-3, verbose=True) ## set verbose to `False` to silence the fi
6363
which should print to I/O something akin to:
6464

6565
```console
66-
0: Mean-diff = 1.4143142700195312
67-
1: Mean-diff = 0.15272194147109985
68-
2: Mean-diff = 0.1888418346643448
69-
3: Mean-diff = 0.18062230944633484
70-
4: Mean-diff = 0.15196363627910614
71-
5: Mean-diff = 0.1135818138718605
72-
6: Mean-diff = 0.06951556354761124
73-
7: Mean-diff = 0.03664496913552284
74-
8: Mean-diff = 0.026161763817071915
75-
9: Mean-diff = 0.022674376145005226
76-
10: Mean-diff = 0.021674498915672302
77-
11: Mean-diff = 0.02205687016248703
78-
12: Mean-diff = 0.023379826918244362
79-
13: Mean-diff = 0.02553001046180725
80-
14: Mean-diff = 0.028586825355887413
66+
0: Mean-diff = 1.4147894382476807 log(p(X)) = -1706.0753173828125 nats
67+
1: Mean-diff = 0.14663299918174744 log(p(X)) = -1386.569091796875 nats
68+
2: Mean-diff = 0.18331432342529297 log(p(X)) = -1359.6962890625 nats
69+
3: Mean-diff = 0.17693905532360077 log(p(X)) = -1309.736083984375 nats
70+
4: Mean-diff = 0.1494818776845932 log(p(X)) = -1250.130615234375 nats
71+
5: Mean-diff = 0.11344392597675323 log(p(X)) = -1221.0008544921875 nats
72+
6: Mean-diff = 0.07362686842679977 log(p(X)) = -1204.680419921875 nats
73+
7: Mean-diff = 0.03828870505094528 log(p(X)) = -1192.706298828125 nats
74+
8: Mean-diff = 0.025705577805638313 log(p(X)) = -1188.51123046875 nats
75+
9: Mean-diff = 0.021316207945346832 log(p(X)) = -1187.055908203125 nats
76+
10: Mean-diff = 0.019372563809156418 log(p(X)) = -1186.157470703125 nats
77+
11: Mean-diff = 0.018868334591388702 log(p(X)) = -1185.443115234375 nats
8178
...
8279
<shortened for brevity>
8380
...
84-
32: Mean-diff = 0.06849467754364014
85-
33: Mean-diff = 0.06256962567567825
86-
34: Mean-diff = 0.05789890140295029
87-
35: Mean-diff = 0.05557262524962425
88-
36: Mean-diff = 0.05545869469642639
89-
37: Mean-diff = 0.056351397186517715
90-
38: Mean-diff = 0.057266443967819214
91-
39: Mean-diff = 0.05742649361491203
92-
40: Mean-diff = 0.05546746402978897
93-
41: Mean-diff = 0.04826011508703232
94-
42: Mean-diff = 0.03320707008242607
95-
43: Mean-diff = 0.016994504258036613
96-
44: Mean-diff = 0.007737572770565748
97-
45: Mean-diff = 0.0035514419432729483
98-
46: Mean-diff = 0.0016557337949052453
99-
47: Mean-diff = 0.0007792692049406469
100-
Converged after 48 iterations.
81+
46: Mean-diff = 0.017377303913235664 log(p(X)) = -1062.2596435546875 nats
82+
47: Mean-diff = 0.007906327955424786 log(p(X)) = -1060.440185546875 nats
83+
48: Mean-diff = 0.003615213558077812 log(p(X)) = -1060.09130859375 nats
84+
49: Mean-diff = 0.0016773870447650552 log(p(X)) = -1060.0233154296875 nats
85+
50: Mean-diff = 0.0007852672133594751 log(p(X)) = -1060.0093994140625 nats
86+
Converged after 51 iterations.
10187
```
10288

103-
In the above instance, notice that our GMM converged early, reaching a good log likelihood in `48` iterations. We can further calculate our final model's log likelihood over the dataset `X` with the following in-built function:
89+
In the above instance, notice that our GMM converged early, reaching a good, stable log likelihood in `51` iterations. We can further calculate our final model's log likelihood over the dataset `X` with the following in-built function:
10490

10591
```python
10692
# Calculate the GMM log likelihood
@@ -111,10 +97,10 @@ print(f"log[p(X)] = {logPX} nats")
11197
which will print out the following:
11298

11399
```console
114-
log[p(X)] = -423.30889892578125 nats
100+
log[p(X)] = -1060.006591796875 nats
115101
```
116102

117-
(If you add a log-likelihood measurement before you call `.fit()`, you will see that your original log-likelihood is around `-1046.91 nats`.)
103+
(If you add a log-likelihood measurement before you call `.fit()`, you will see that your original log-likelihood is around `-1060.01 nats`.)
118104
Now, to visualize if our GMM actually capture the underlying multi-modal distribution of our dataset, we may visualize the final GMM with the following plotting code:
119105

120106
```python

ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ngclearn import compilable #from ngcsimlib.parser import compilable
33
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
44
from ngcsimlib.logger import info
5-
import ngclearn.utils.weight_distribution as dist
5+
from ngclearn.utils.distribution_generator import DistributionGenerator
66
from ngclearn.components.synapses.convolution.ngcconv import conv2d
77

88
from ngclearn.components.jaxComponent import JaxComponent
@@ -80,7 +80,12 @@ def __init__(
8080

8181
######################### set up compartments ##########################
8282
tmp_key, *subkeys = random.split(self.key.get(), 4)
83-
weights = dist.initialize_params(subkeys[0], filter_init, shape) ## filter tensor
83+
#weights = dist.initialize_params(subkeys[0], filter_init, shape)
84+
if self.filter_init is None:
85+
info(self.name, "is using default weight initializer!")
86+
self.filter_init = DistributionGenerator.uniform(0.025, 0.8)
87+
weights = self.filter_init(shape, subkeys[0]) ## filter tensor
88+
8489
self.batch_size = batch_size # 1
8590
## Compartment setup and shape computation
8691
_x = jnp.zeros((self.batch_size, x_size, x_size, n_in_chan))
@@ -91,10 +96,10 @@ def __init__(
9196
self.outputs = Compartment(jnp.zeros(self.out_shape))
9297
self.weights = Compartment(weights)
9398
if self.bias_init is None:
94-
info(self.name, "is using default bias value of zero (no bias "
95-
"kernel provided)!")
99+
info(self.name, "is using default bias value of zero (no bias kernel provided)!")
96100
self.biases = Compartment(
97-
dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
101+
#dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
102+
self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0
98103
)
99104

100105
@compilable

ngclearn/components/synapses/convolution/deconvSynapse.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ngclearn import compilable #from ngcsimlib.parser import compilable
33
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
44
from ngcsimlib.logger import info
5-
import ngclearn.utils.weight_distribution as dist
5+
from ngclearn.utils.distribution_generator import DistributionGenerator
66
from ngclearn.components.synapses.convolution.ngcconv import deconv2d
77

88
from ngclearn.components.jaxComponent import JaxComponent
@@ -68,8 +68,12 @@ def __init__(
6868

6969
######################### set up compartments ##########################
7070
tmp_key, *subkeys = random.split(self.key.get(), 4)
71-
weights = dist.initialize_params(subkeys[0], filter_init,
72-
shape) ## filter tensor
71+
#weights = dist.initialize_params(subkeys[0], filter_init, shape)
72+
if self.filter_init is None:
73+
info(self.name, "is using default weight initializer!")
74+
self.filter_init = DistributionGenerator.uniform(0.025, 0.8)
75+
weights = self.filter_init(shape, subkeys[0]) ## filter tensor
76+
7377
self.batch_size = batch_size # 1
7478
## Compartment setup and shape computation
7579
_x = jnp.zeros((self.batch_size, x_size, x_size, n_in_chan))
@@ -82,9 +86,10 @@ def __init__(
8286
if self.bias_init is None:
8387
info(self.name, "is using default bias value of zero (no bias "
8488
"kernel provided)!")
85-
self.biases = Compartment(dist.initialize_params(subkeys[2], bias_init,
86-
(1, shape[1]))
87-
if bias_init else 0.0)
89+
self.biases = Compartment(
90+
# dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
91+
self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0
92+
)
8893

8994
@compilable
9095
def advance_state(self):

ngclearn/components/synapses/denseSynapse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ def __init__(
7575
self.weights = Compartment(weights)
7676
## Set up (optional) bias values
7777
if self.bias_init is None:
78-
info(self.name, "is using default bias value of zero (no bias "
79-
"kernel provided)!")
78+
info(self.name, "is using default bias value of zero (no bias kernel provided)!")
8079
self.biases = Compartment(self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0)
8180
# self.biases = Compartment(initialize_params(subkeys[2], bias_init,
8281
# (1, shape[1]))

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from ngcsimlib import deprecate_args
1212

1313
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
14-
def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
15-
prior_type=None, prior_lmbda=0.,
16-
pre_wght=1., post_wght=1.):
14+
def _calc_update(
15+
pre, post, W, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1.,
16+
post_wght=1.
17+
):
1718
"""
1819
Compute a tensor of adjustments to be applied to a synaptic value matrix.
1920

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,22 @@
55

66
from ngclearn.utils.model_utils import clip, d_clip
77
import jax
8-
import jax.numpy as jnp
9-
import numpy as np
8+
#import numpy as np
109

1110
from ngclearn.components.synapses import DenseSynapse
1211
from ngclearn.utils import tensorstats
1312
from ngclearn.utils.model_utils import create_function
1413

15-
def gaussian_logpdf(event, mean, stddev):
14+
def _gaussian_logpdf(event, mean, stddev):
1615
scale_sqrd = stddev ** 2
1716
log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
1817
quadratic = (event - mean)**2 / scale_sqrd
1918
return - 0.5 * (log_normalizer + quadratic)
2019

2120

22-
def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
21+
def _compute_update(
22+
dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
23+
):
2324
learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
2425
# (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
2526
W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
@@ -37,7 +38,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
3738
sample = jnp.clip(sample, mu_out_min, mu_out_max)
3839
outputs = sample # the actual action that we take
3940
# Compute log probability density of the Gaussian
40-
log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
41+
log_prob = _gaussian_logpdf(sample, fx_mean, std).sum(-1)
4142
# Compute objective (negative REINFORCE objective)
4243
objective = (-log_prob * rewards).mean() * 1e-2
4344

@@ -65,7 +66,6 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
6566
return dW, objective, outputs
6667

6768

68-
6969
class REINFORCESynapse(DenseSynapse):
7070
"""
7171
A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse
@@ -122,8 +122,10 @@ def __init__(
122122
) -> None:
123123
# This is because we have weights mu and weight log sigma
124124
input_dim, output_dim = shape
125-
super().__init__(name, (input_dim, output_dim * 2), weight_init, None, resist_scale,
126-
p_conn, batch_size=batch_size, **kwargs)
125+
super().__init__(
126+
name, (input_dim, output_dim * 2), weight_init, None, resist_scale, p_conn,
127+
batch_size=batch_size, **kwargs
128+
)
127129

128130
## Synaptic hyper-parameters
129131
self.shape = shape ## shape of synaptic efficacy matrix
@@ -150,12 +152,8 @@ def __init__(
150152
self.learning_mask = Compartment(jnp.zeros(()))
151153
self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
152154

153-
154-
# @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
155-
# @staticmethod
156155
@compilable
157156
def evolve(self, dt):
158-
159157
# Get compartment values
160158
weights = self.weights.get()
161159
dWeights = self.dWeights.get()
@@ -173,13 +171,13 @@ def evolve(self, dt):
173171
dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx, self.mu_out_min, self.mu_out_max, self.scalar_stddev
174172
)
175173
## do a gradient ascent update/shift
176-
weights = (weights + dWeights * self.eta) * self.learning_mask + weights * (1.0 - self.learning_mask) # update the weights only where learning_mask is 1.0
174+
weights = (weights + dWeights * self.eta) * self.learning_mask + weights * (1.0 - self.learning_mask.get()) # update the weights only where learning_mask is 1.0
177175
## enforce non-negativity
178176
eps = 0.0 # 0.01 # 0.001
179177
weights = jnp.clip(weights, eps, self.w_bound - eps) # jnp.abs(w_bound))
180178
step_count += 1
181179
accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
182-
step_count = step_count * (1 - self.learning_mask) # reset the step count to 0 when we have learned
180+
step_count = step_count * (1 - self.learning_mask.get()) # reset the step count to 0 when we have learned
183181

184182
# Set updated compartment values
185183
self.weights.set(weights)
@@ -190,8 +188,6 @@ def evolve(self, dt):
190188
self.step_count.set(step_count)
191189
self.seed.set(main_seed)
192190

193-
# @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
194-
# @staticmethod
195191
@compilable
196192
def reset(self):
197193
preVals = jnp.zeros((self.batch_size, self.shape[0]))
@@ -214,7 +210,6 @@ def reset(self):
214210
self.step_count.set(step_count)
215211
self.seed.set(seed)
216212

217-
218213
@classmethod
219214
def help(cls): ## component help function
220215
properties = {

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from ngclearn.utils import tensorstats
1414

1515
# @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
16-
def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1.,
17-
prior_type=None, prior_lmbda=0.,
18-
pre_wght=1., post_wght=1.):
16+
def _calc_update(
17+
pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1.,
18+
post_wght=1.
19+
):
1920
"""
2021
Compute a tensor of adjustments to be applied to a synaptic value matrix.
2122
@@ -190,12 +191,15 @@ class HebbianPatchedSynapse(PatchedSynapse):
190191
batch_size: the size of each mini batch
191192
"""
192193

193-
def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
194-
block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
195-
optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
196-
resist_scale=1., batch_size=1, **kwargs):
197-
super().__init__(name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale,
198-
p_conn, batch_size=batch_size, **kwargs)
194+
def __init__(
195+
self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
196+
block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1., optim_type="sgd",
197+
pre_wght=1., post_wght=1., p_conn=1., resist_scale=1., batch_size=1, **kwargs
198+
):
199+
super().__init__(
200+
name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale, p_conn,
201+
batch_size=batch_size, **kwargs
202+
)
199203

200204
prior_type, prior_lmbda = prior
201205
self.prior_type = prior_type
@@ -338,23 +342,6 @@ def help(cls): ## component help function
338342
"hyperparameters": hyperparams}
339343
return info
340344

341-
342-
343-
def __repr__(self):
344-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
345-
maxlen = max(len(c) for c in comps) + 5
346-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
347-
for c in comps:
348-
stats = tensorstats(getattr(self, c).get())
349-
if stats is not None:
350-
line = [f"{k}: {v}" for k, v in stats.items()]
351-
line = ", ".join(line)
352-
else:
353-
line = "None"
354-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
355-
return lines
356-
357-
358345
if __name__ == '__main__':
359346
from ngcsimlib.context import Context
360347
with Context("Bar") as bar:

0 commit comments

Comments
 (0)