77## internal routines for mixture model
88########################################################################################################################
99
10- @partial ( jit , static_argnums = [ 3 ])
10+ @jit
1111def _log_bernoulli_pdf (X , p ):
1212 """
1313 Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under a given parameter
@@ -21,7 +21,7 @@ def _log_bernoulli_pdf(X, p):
2121 Returns:
2222 the log likelihood (scalar) of this design matrix X
2323 """
24- D = mu .shape [1 ] * 1. ## get dimensionality
24+ # D = X .shape[1] * 1. ## get dimensionality
2525 ## x log(mu_k) + (1-x) log(1 - mu_k)
2626 vec_ll = X * jnp .log (p ) + (1. - X ) * jnp .log (1. - p ) ## binary cross-entropy (log Bernoulli)
2727 log_ll = jnp .sum (vec_ll , axis = 1 , keepdims = True ) ## get per-datapoint LL
@@ -99,8 +99,10 @@ def init(self, X):
9999 ptrs = random .permutation (skey [0 ], X .shape [0 ])
100100 for j in range (self .K ):
101101 ptr = ptrs [j ]
102- #self.key, *skey = random.split(self.key, 3)
103- self .mu .append (X [ptr :ptr + 1 ,:] * 0 + (1. / (dim * 1. )))
102+ self .key , * skey = random .split (self .key , 3 )
103+ #self.mu.append(X[ptr:ptr+1,:] * 0 + (1./(dim * 1.)))
104+ eps = random .uniform (skey [0 ], minval = 0. , maxval = 0.9 , shape = (1 , dim )) ## jitter initial prob params
105+ self .mu .append (eps )
104106
105107 def calc_log_likelihood (self , X ):
106108 """
0 commit comments