Skip to content

Commit d8885e7

Browse files
committed
Add docs for samplers and improve API
1 parent cd215d4 commit d8885e7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sbijax/_src/abc/smc_abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _resample(self, rng_key, particles, log_weights, n_samples):
227227
def _new_log_weights(
228228
self, new_particles, old_particles, old_log_weights, cov_chol_factor
229229
):
230-
prior_log_density = self.prior_log_density_fn(new_particles)
230+
prior_log_density = self.prior.log_prob(new_particles)
231231
K = self._kernel(old_particles, cov_chol_factor)
232232

233233
def _particle_weight(partcl):
@@ -248,7 +248,7 @@ def _kernel(self, mus, cov_chol_factor):
248248
return tfd.MultivariateNormalTriL(loc=mus, scale_tril=cov_chol_factor)
249249

250250
def _perturb(self, rng_key, mus, cov_chol_factor):
251-
_, unravel_fn = ravel_pytree(self.prior_sampler_fn(seed=jr.PRNGKey(0)))
251+
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(0)))
252252
samples = self._kernel(mus, cov_chol_factor).sample(seed=rng_key)
253253
samples = jax.vmap(unravel_fn)(samples)
254254
return samples

0 commit comments

Comments
 (0)