Skip to content

Commit e21bca4

Browse files
author
Alexander Ororbia
committed
made patches to bmm
1 parent a84a5d8 commit e21bca4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ngclearn/utils/density/bmm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
## internal routines for mixture model
88
########################################################################################################################
99

10-
@partial(jit, static_argnums=[3])
10+
@jit
1111
def _log_bernoulli_pdf(X, p):
1212
"""
1313
Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under a given parameter
@@ -21,7 +21,7 @@ def _log_bernoulli_pdf(X, p):
2121
Returns:
2222
the log likelihood (scalar) of this design matrix X
2323
"""
24-
D = mu.shape[1] * 1. ## get dimensionality
24+
#D = X.shape[1] * 1. ## get dimensionality
2525
## x log(mu_k) + (1-x) log(1 - mu_k)
2626
vec_ll = X * jnp.log(p) + (1. - X) * jnp.log(1. - p) ## binary cross-entropy (log Bernoulli)
2727
log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
@@ -99,8 +99,10 @@ def init(self, X):
9999
ptrs = random.permutation(skey[0], X.shape[0])
100100
for j in range(self.K):
101101
ptr = ptrs[j]
102-
#self.key, *skey = random.split(self.key, 3)
103-
self.mu.append(X[ptr:ptr+1,:] * 0 + (1./(dim * 1.)))
102+
self.key, *skey = random.split(self.key, 3)
103+
#self.mu.append(X[ptr:ptr+1,:] * 0 + (1./(dim * 1.)))
104+
eps = random.uniform(skey[0], minval=0., maxval=0.9, shape=(1, dim)) ## jitter initial prob params
105+
self.mu.append(eps)
104106

105107
def calc_log_likelihood(self, X):
106108
"""

0 commit comments

Comments
 (0)