Skip to content

Commit 4606a1c

Browse files
author
Alexander Ororbia
committed
patched some tests/syn/neuron components, added sketch of bmm density
1 parent 157102e commit 4606a1c

File tree

9 files changed

+221
-14
lines changed

9 files changed

+221
-14
lines changed

ngclearn/utils/density/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
## point to supported density estimator models
2-
from .gmm import GMM
2+
from .gmm import GMM ## Gaussian mixture
3+
from .bmm import BMM ## Bernoulli mixture
4+

ngclearn/utils/density/bmm.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from jax import numpy as jnp, random, jit, scipy
2+
from functools import partial
3+
import time, sys
4+
import numpy as np
5+
6+
########################################################################################################################
7+
## internal routines for mixture model
8+
########################################################################################################################
9+
10+
@partial(jit, static_argnums=[3])
11+
def _log_bernoulli_pdf(X, p):
12+
"""
13+
Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under a given parameter
14+
probability `p`.
15+
16+
Args:
17+
X: a design matrix (dataset) to compute the log likelihood of
18+
19+
mu: a parameter mean vector
20+
21+
Returns:
22+
the log likelihood (scalar) of this design matrix X
23+
"""
24+
D = mu.shape[1] * 1. ## get dimensionality
25+
## x log(mu_k) + (1-x) log(1 - mu_k)
26+
vec_ll = X * jnp.log(p) + (1. - X) * jnp.log(1. - p) ## binary cross-entropy (log Bernoulli)
27+
log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
28+
return log_ll
29+
30+
@jit
31+
def _calc_bernoulli_pdf_vals(X, p):
32+
log_ll = _log_bernoulli_pdf(X, p) ## get log-likelihood
33+
ll = jnp.exp(log_ll) ## likelihood
34+
return log_ll, ll
35+
36+
@jit
37+
def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
38+
## calc new means, responsibilities, and priors given current stats
39+
N = X.shape[0] ## get number of samples
40+
## calc responsibilities
41+
r = (pi * weights)
42+
r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
43+
_pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
44+
## calc weighted means (weighted by responsibilities)
45+
means = jnp.matmul(r.T, X) / jnp.sum(r, axis=0, keepdims=True).T
46+
return means, _pi, r
47+
48+
@partial(jit, static_argnums=[1])
49+
def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
50+
log_pi = jnp.log(pi) ## calc log(prior)
51+
lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
52+
return lats
53+
54+
@partial(jit, static_argnums=[1])
55+
def _sample_component(dkey, n_samples, mu): ## samples a component (of mixture)
56+
eps = random.bernoulli(dkey, p=mu, shape=(n_samples, mu.shape[1])) ## draw Bernoulli samples
57+
return x_s
58+
59+
########################################################################################################################
60+
61+
class BMM: ## Bernoulli mixture model (mixture-of-Bernoullis)
62+
"""
63+
Implements a Bernoulli mixture model (BMM) -- or mixture of Bernoullis (MoB).
64+
Adaptation of parameters is conducted via the Expectation-Maximization (EM)
65+
learning algorithm and leverages full covariance matrices in the component
66+
multivariate Bernoulli distributions.
67+
68+
Note this is a (JAX) wrapper model that houses the sklearn implementation for learning.
69+
The sampling process has been rewritten to utilize GPU matrix computation.
70+
71+
Args:
72+
K: the number of components/latent variables within this BMM
73+
74+
max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
75+
76+
init_kmeans: <Unsupported>
77+
"""
78+
79+
def __init__(self, K, max_iter=50, init_kmeans=False, key=None):
80+
self.K = K
81+
self.max_iter = int(max_iter)
82+
self.init_kmeans = init_kmeans ## Unsupported currently
83+
self.mu = [] ## component mean parameters
84+
self.pi = None ## prior weight parameters
85+
#self.z_weights = None # variables for parameterizing weights for SGD
86+
self.key = random.PRNGKey(time.time_ns()) if key is None else key
87+
88+
def init(self, X):
89+
"""
90+
Initializes this BMM in accordance to a supplied design matrix.
91+
92+
Args:
93+
X: the design matrix to initialize this BMM to
94+
95+
"""
96+
dim = X.shape[1]
97+
self.key, *skey = random.split(self.key, 3)
98+
self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
99+
ptrs = random.permutation(skey[0], X.shape[0])
100+
for j in range(self.K):
101+
ptr = ptrs[j]
102+
#self.key, *skey = random.split(self.key, 3)
103+
self.mu.append(X[ptr:ptr+1,:] * 0 + (1./(dim * 1.)))
104+
105+
def calc_log_likelihood(self, X):
106+
"""
107+
Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under the current
108+
parameters of this Bernoulli mixture.
109+
110+
Args:
111+
X: the design matrix to estimate log likelihood values over under this BMM
112+
113+
Returns:
114+
(column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
115+
"""
116+
ll = 0.
117+
for j in range(self.K):
118+
log_ll_j, ll_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
119+
ll = ll_j + ll
120+
log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
121+
complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
122+
return log_ll, complete_ll
123+
124+
def _E_step(self, X): ## Expectation (E) step, co-routine
125+
weights = []
126+
for j in range(self.K):
127+
log_ll_j, ll_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
128+
weights.append( ll_j )
129+
weights = jnp.concat(weights, axis=1)
130+
return weights ## data-dependent weights (intermediate responsibilities)
131+
132+
def _M_step(self, X, weights): ## Maximization (M) step, co-routine
133+
means, pi, r = _calc_priors_and_means(X, weights, self.pi)
134+
self.pi = pi ## store new prior parameters
135+
# calc weighted covariances
136+
for j in range(self.K):
137+
#r_j = r[:, j:j + 1]
138+
mu_j = means[j:j + 1, :]
139+
self.mu[j] = mu_j ## store new mean(j) parameter
140+
return means, r
141+
142+
def fit(self, X, tol=1e-3, verbose=False):
143+
"""
144+
Run full fitting process of this BMM.
145+
146+
Args:
147+
X: the dataset to fit this BMM to
148+
149+
tol: the tolerance value for detecting convergence (via difference-of-means); will engage in early-stopping
150+
if tol >= 0. (Default: 1e-3)
151+
152+
verbose: if True, this function will print out per-iteration measurements to I/O
153+
"""
154+
means_prev = jnp.concat(self.mu, axis=0)
155+
for i in range(self.max_iter):
156+
self.update(X) ## carry out one E-step followed by an M-step
157+
means = jnp.concat(self.mu, axis=0)
158+
dom = jnp.linalg.norm(means - means_prev) ## norm of difference-of-means
159+
if verbose:
160+
print(f"{i}: Mean-diff = {dom}")
161+
#print(jnp.linalg.norm(means - means_prev))
162+
if tol >= 0. and dom < tol:
163+
print(f"Converged after {i + 1} iterations.")
164+
break
165+
means_prev = means
166+
167+
def update(self, X):
168+
"""
169+
Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)
170+
171+
Args:
172+
X: the dataset / design matrix to fit this BMM to
173+
"""
174+
r_w = self._E_step(X) ## carry out E-step
175+
means, respon = self._M_step(X, r_w) ## carry out M-step
176+
177+
def sample(self, n_samples, mode_j=-1):
178+
"""
179+
Draw samples from the current underlying BMM model
180+
181+
Args:
182+
n_samples: the number of samples to draw from this BMM
183+
184+
mode_j: if >= 0, will only draw samples from a specific component of this BMM
185+
(Default = -1), ignoring the Categorical prior over latent variables/components
186+
187+
Returns:
188+
Design matrix of samples drawn under the distribution defined by this BMM
189+
"""
190+
## sample prior
191+
self.key, *skey = random.split(self.key, 3)
192+
if mode_j >= 0: ## sample from a particular mode / component
193+
mu_j = self.mu[mode_j]
194+
Xs = _sample_component(skey[0], n_samples=n_samples, mu=mu_j)
195+
else: ## sample from full mixture distribution
196+
## sample components/latents
197+
lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
198+
## then sample chosen component Bernoulli
199+
Xs = []
200+
for j in range(self.K):
201+
freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
202+
self.key, *skey = random.split(self.key, 3)
203+
x_s = _sample_component(skey[0], n_samples=freq_j, mu=self.mu[j])
204+
Xs.append(x_s)
205+
Xs = jnp.concat(Xs, axis=0)
206+
return Xs

ngclearn/utils/density/gmm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,8 @@ def sample(self, n_samples, mode_j=-1):
260260
Xs = []
261261
for j in range(self.K):
262262
freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
263-
## draw unit Gaussian noise
264263
self.key, *skey = random.split(self.key, 3)
265-
x_s = _sample_component(
264+
x_s = _sample_component( ## now physically sample component
266265
skey[0], n_samples=freq_j, mu=self.mu[j], Sigma=self.Sigma[j], assume_diag_cov=self.assume_diag_cov
267266
)
268267
Xs.append(x_s)

tests/components/neurons/graded/test_bernoulliErrorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def clamp_target(x):
3939
target_xt = jnp.array([[target_seq[0, ts]]])
4040
clamp_target(target_xt)
4141
advance_process.run(t=ts * 1., dt=dt)
42-
outs.append(a.dp.value)
42+
outs.append(a.dp.get())
4343
outs = jnp.concatenate(outs, axis=1)
4444
# print(outs)
4545
## output should equal input

tests/components/neurons/graded/test_gaussianErrorCell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def clamp_target(x):
4444
target_t = jnp.array([[target_seq[0, ts]]])
4545
clamp_target(target_t)
4646
advance_process.run(t=ts * 1., dt=dt)
47-
dmu_outs.append(a.dmu.value)
48-
L_outs.append(a.L.value)
47+
dmu_outs.append(a.dmu.get())
48+
L_outs.append(a.L.get())
4949

5050
dmu_outs = jnp.concatenate(dmu_outs, axis=1)
5151
L_outs = jnp.array(L_outs)[None] # (1, 10)
@@ -58,4 +58,4 @@ def clamp_target(x):
5858
np.testing.assert_allclose(dmu_outs, expected_dmu, atol=1e-5)
5959
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
6060

61-
# test_gaussianErrorCell()
61+
# test_gaussianErrorCell()

tests/components/synapses/hebbian/test_traceSTDPSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
#import ngclearn.utils.weight_distribution as dist
88
from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/modulated/test_MSTDPETSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
#import ngclearn.utils.weight_distribution as dist
88
from ngclearn.components.synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/test_STPDenseSynapse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
from ngclearn.utils.distribution_generator import DistributionGenerator
88
from ngclearn.components.synapses.STPDenseSynapse import STPDenseSynapse
99

1010
def test_STPDenseSynapse1():
@@ -16,7 +16,7 @@ def test_STPDenseSynapse1():
1616
# ---- build a simple Poisson cell system ----
1717
with Context(name) as ctx:
1818
a = STPDenseSynapse(
19-
name="a", shape=(1,1), resources_init=dist.constant(value=1.),key=subkeys[0]
19+
name="a", shape=(1,1), resources_init=DistributionGenerator.constant(value=1.),key=subkeys[0]
2020
)
2121

2222
advance_process = (MethodProcess("advance_proc")

tests/components/synapses/test_exponentialSynapse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
from ngclearn.utils.distribution_generator import DistributionGenerator
88
from ngclearn.components.synapses.exponentialSynapse import ExponentialSynapse
99

1010
def test_exponentialSynapse1():
@@ -19,8 +19,8 @@ def test_exponentialSynapse1():
1919
# ---- build a single exp-synapse system ----
2020
with Context(name) as ctx:
2121
a = ExponentialSynapse(
22-
name="a", shape=(1,1), tau_decay=tau_syn, g_syn_bar=2.4, syn_rest=E_rest, weight_init=dist.constant(value=1.),
23-
key=subkeys[0]
22+
name="a", shape=(1,1), tau_decay=tau_syn, g_syn_bar=2.4, syn_rest=E_rest,
23+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
2424
)
2525

2626
advance_process = (MethodProcess("advance_proc")

0 commit comments

Comments
 (0)