@@ -36,7 +36,7 @@ def _calc_exponential_pdf_vals(X, p):
3636 ll = jnp .exp (log_ll ) ## likelihood
3737 return log_ll , ll
3838
39- @jit
39+ # @jit
4040def _calc_priors_and_rates (X , weights , pi ): ## M-step co-routine
4141 ## calc new rates, responsibilities, and priors given current stats
4242 N = X .shape [0 ] ## get number of samples
@@ -45,10 +45,20 @@ def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
4545 r = r / jnp .sum (r , axis = 1 , keepdims = True ) ## responsibilities
4646 _pi = jnp .sum (r , axis = 0 , keepdims = True ) / N ## calc new priors
4747 ## calc weighted rates (weighted by responsibilities)
48+
49+ Znum = jnp .sum (r , axis = 0 , keepdims = True )
50+ #print(Znum.shape)
51+ Zden = jnp .matmul (r .T , X )
52+ rates = Znum .T / Zden
53+ #print(Zden.shape)
54+ #exit()
55+ """
4856 Z = jnp.sum(r, axis=0, keepdims=True) ## calc partition function
49- M = (Z > 0. ) * 1.
50- Z = Z * M + (1. - M ) ## we mask out any zero partition function values
51- rates = jnp .matmul (r .T , X ) / Z .T
57+ Ndata = jnp.matmul(r.T, X)
58+ M = (Ndata > 0.) * 1.
59+ Ndata = Ndata * M + (1. - M) ## we mask out division-by-0 cases
60+ rates = Z.T / Ndata
61+ """
5262 return rates , _pi , r
5363
5464@partial (jit , static_argnums = [1 ])
@@ -107,7 +117,7 @@ def init(self, X):
107117 for j in range (self .K ):
108118 ptr = ptrs [j ]
109119 self .key , * skey = random .split (self .key , 3 )
110- eps = random .uniform (skey [0 ], minval = 0. , maxval = 0.5 , shape = (1 , dim )) ## jitter initial rate params
120+ eps = random .uniform (skey [0 ], minval = 0.99 , maxval = 1.01 , shape = (1 , dim )) ## jitter initial rate params
111121 self .rate .append (eps )
112122
113123 def calc_log_likelihood (self , X ):
0 commit comments