Skip to content

Commit 9674b28

Browse files
committed
Add docs for samplers and improve API
1 parent d8885e7 commit 9674b28

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
pip install hatch
7474
- name: Build package
7575
run: |
76-
pip install jaxlib==0.4.24 jax==0.4.24
76+
pip install jaxlib jax
7777
- name: Run tests
7878
run: |
7979
hatch run test:test

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ description = " Simulation-based inference in JAX"
88
authors = [{name = "Simon Dirmeier", email = "sfyrbnd@pm.me"}]
99
readme = "README.md"
1010
license = "Apache-2.0"
11-
keywords = [ "sbi", "abc", "simulation-based inference", "approximate Bayesian computation"]
11+
keywords = ["sbi", "abc", "simulation-based inference", "approximate Bayesian computation"]
1212
classifiers = [
1313
"Development Status :: 4 - Beta",
1414
"Intended Audience :: Science/Research",

sbijax/_src/npe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def sample_posterior(
352352
proposal_probs = self.prior.log_prob(proposal)
353353
proposal = jax.vmap(lambda x: ravel_pytree(x)[0])(proposal)
354354
else:
355-
proposal_probs = self.prior_log_density_fn(
355+
proposal_probs = self.prior.log_prob(
356356
jax.vmap(unravel_fn)(proposal)
357357
)
358358
if check_proposal_probs:

0 commit comments

Comments
 (0)