Skip to content

Commit 519896e

Browse files
author
Alexander Ororbia
committed
minor edits to emm
1 parent 02e906e commit 519896e

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

ngclearn/utils/density/exponentialMixture.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _calc_exponential_pdf_vals(X, p):
3636
ll = jnp.exp(log_ll) ## likelihood
3737
return log_ll, ll
3838

39-
@jit
39+
#@jit
4040
def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
4141
## calc new rates, responsibilities, and priors given current stats
4242
N = X.shape[0] ## get number of samples
@@ -45,10 +45,20 @@ def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
4545
r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
4646
_pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
4747
## calc weighted rates (weighted by responsibilities)
48+
49+
Znum = jnp.sum(r, axis=0, keepdims=True)
50+
#print(Znum.shape)
51+
Zden = jnp.matmul(r.T, X)
52+
rates = Znum.T/Zden
53+
#print(Zden.shape)
54+
#exit()
55+
"""
4856
Z = jnp.sum(r, axis=0, keepdims=True) ## calc partition function
49-
M = (Z > 0.) * 1.
50-
Z = Z * M + (1. - M) ## we mask out any zero partition function values
51-
rates = jnp.matmul(r.T, X) / Z.T
57+
Ndata = jnp.matmul(r.T, X)
58+
M = (Ndata > 0.) * 1.
59+
Ndata = Ndata * M + (1. - M) ## we mask out division-by-0 cases
60+
rates = Z.T / Ndata
61+
"""
5262
return rates, _pi, r
5363

5464
@partial(jit, static_argnums=[1])
@@ -107,7 +117,7 @@ def init(self, X):
107117
for j in range(self.K):
108118
ptr = ptrs[j]
109119
self.key, *skey = random.split(self.key, 3)
110-
eps = random.uniform(skey[0], minval=0., maxval=0.5, shape=(1, dim)) ## jitter initial rate params
120+
eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, dim)) ## jitter initial rate params
111121
self.rate.append(eps)
112122

113123
def calc_log_likelihood(self, X):

0 commit comments

Comments
 (0)