Skip to content

Commit 1d5b2dd

Browse files
author
William Wilkinson
committed
fix sparse EP cavity - energy still not working
1 parent a65f15a commit 1d5b2dd

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

newt/basemodels.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def compute_global_pseudo_lik(self):
382382
return pseudo_y_full, pseudo_var_full
383383

384384
def compute_full_pseudo_lik(self):
385-
nat1lik_full, nat2lik_full = vmap(self.compute_full_pseudo_nat)(self.obs_ind)
385+
nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(self.obs_ind) # TODO: remove obs_ind
386386
pseudo_var_full = inv_vmap(nat2lik_full + 1e-12 * np.eye(nat2lik_full.shape[1]))
387387
pseudo_y_full = pseudo_var_full @ nat1lik_full
388388
return pseudo_y_full, pseudo_var_full
@@ -391,8 +391,8 @@ def compute_full_pseudo_nat(self, batch_ind):
391391
Kuf = self.kernel(self.Z.value, self.X[batch_ind].reshape(-1, 1)) # only compute log lik for observed values
392392
Kuu = self.kernel(self.Z.value, self.Z.value)
393393
Wuf = solve(Kuu, Kuf) # conditional mapping, Kuu^-1 Kuf
394-
nat1lik_full = Wuf @ self.pseudo_likelihood.nat1[batch_ind].reshape(-1, 1)
395-
nat2lik_full = Wuf @ np.diag(self.pseudo_likelihood.nat2[batch_ind].reshape(-1)) @ transpose(Wuf)
394+
nat1lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat1[batch_ind]
395+
nat2lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat2[batch_ind] @ Wuf.T[:, None]
396396
return nat1lik_full, nat2lik_full
397397

398398
def compute_kl(self):
@@ -478,6 +478,23 @@ def conditional_posterior_to_data(self, batch_ind=None, post_mean=None, post_cov
478478
self.Z.value)
479479
return mean_f.reshape(Nbatch, 1, 1), cov_f.reshape(Nbatch, 1, 1)
480480

481+
def cavity_distribution(self, batch_ind=None, power=1.):
482+
""" Compute the power EP cavity for the given data points """
483+
if batch_ind is None:
484+
batch_ind = np.arange(self.num_data)
485+
486+
nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(batch_ind)
487+
488+
# then compute the cavity
489+
cavity_mean, cavity_cov = vmap(compute_cavity, [None, None, 0, 0, None])(
490+
self.posterior_mean.value[..., 0],
491+
self.posterior_covariance.value,
492+
nat1lik_full,
493+
nat2lik_full,
494+
power
495+
)
496+
return cavity_mean, cavity_cov
497+
481498

482499
class MarkovGP(BaseModel):
483500
"""

newt/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.):
276276
"""
277277
if batch_ind is None:
278278
batch_ind = np.arange(self.num_data)
279-
scale = 1
279+
scale = 1.
280280
else:
281281
scale = self.num_data / batch_ind.shape[0]
282282

@@ -318,7 +318,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.):
318318

319319
ep_energy = -(
320320
lZ_post
321-
+ 1 / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo))
321+
+ 1. / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo))
322322
)
323323

324324
return ep_energy

0 commit comments

Comments
 (0)