@@ -382,7 +382,7 @@ def compute_global_pseudo_lik(self):
382382 return pseudo_y_full , pseudo_var_full
383383
384384 def compute_full_pseudo_lik (self ):
385- nat1lik_full , nat2lik_full = vmap ( self .compute_full_pseudo_nat ) (self .obs_ind )
385+ nat1lik_full , nat2lik_full = self .compute_full_pseudo_nat (self .obs_ind ) # TODO: remove obs_ind
386386 pseudo_var_full = inv_vmap (nat2lik_full + 1e-12 * np .eye (nat2lik_full .shape [1 ]))
387387 pseudo_y_full = pseudo_var_full @ nat1lik_full
388388 return pseudo_y_full , pseudo_var_full
@@ -391,8 +391,8 @@ def compute_full_pseudo_nat(self, batch_ind):
391391 Kuf = self .kernel (self .Z .value , self .X [batch_ind ].reshape (- 1 , 1 )) # only compute log lik for observed values
392392 Kuu = self .kernel (self .Z .value , self .Z .value )
393393 Wuf = solve (Kuu , Kuf ) # conditional mapping, Kuu^-1 Kuf
394- nat1lik_full = Wuf @ self .pseudo_likelihood .nat1 [batch_ind ]. reshape ( - 1 , 1 )
395- nat2lik_full = Wuf @ np . diag ( self .pseudo_likelihood .nat2 [batch_ind ]. reshape ( - 1 )) @ transpose ( Wuf )
394+ nat1lik_full = Wuf . T [..., None ] @ self .pseudo_likelihood .nat1 [batch_ind ]
395+ nat2lik_full = Wuf . T [..., None ] @ self .pseudo_likelihood .nat2 [batch_ind ] @ Wuf . T [:, None ]
396396 return nat1lik_full , nat2lik_full
397397
398398 def compute_kl (self ):
@@ -478,6 +478,23 @@ def conditional_posterior_to_data(self, batch_ind=None, post_mean=None, post_cov
478478 self .Z .value )
479479 return mean_f .reshape (Nbatch , 1 , 1 ), cov_f .reshape (Nbatch , 1 , 1 )
480480
481+ def cavity_distribution (self , batch_ind = None , power = 1. ):
482+ """ Compute the power EP cavity for the given data points """
483+ if batch_ind is None :
484+ batch_ind = np .arange (self .num_data )
485+
486+ nat1lik_full , nat2lik_full = self .compute_full_pseudo_nat (batch_ind )
487+
488+ # then compute the cavity
489+ cavity_mean , cavity_cov = vmap (compute_cavity , [None , None , 0 , 0 , None ])(
490+ self .posterior_mean .value [..., 0 ],
491+ self .posterior_covariance .value ,
492+ nat1lik_full ,
493+ nat2lik_full ,
494+ power
495+ )
496+ return cavity_mean , cavity_cov
497+
481498
482499class MarkovGP (BaseModel ):
483500 """
0 commit comments