Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/infer/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def body_fn(i, z_flat_inner_state):
active, inactive = jax.lax.cond(
i == 0,
lambda x: (x[:split_ind], x[split_ind:]),
lambda x: (x[split_ind:], x[split_ind:]),
lambda x: (x[split_ind:], x[:split_ind]),
z_flat,
)

Expand Down
34 changes: 34 additions & 0 deletions test/infer/test_ensemble_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpyro
import numpyro.distributions as dist
from numpyro.infer import AIES, ESS, MCMC
from numpyro.infer.ensemble import EnsembleSampler, EnsembleSamplerState
from numpyro.infer.initialization import init_to_uniform

numpyro.set_host_device_count(2)
# ---
Expand Down Expand Up @@ -119,3 +121,35 @@ def test_warmup(kernel_cls):
labels = labels_maker()
mcmc.warmup(random.key(2), labels)
mcmc.run(random.key(3), labels)


def test_ensemble_sampler_uses_complementary_halves():
class ToyEnsembleSampler(EnsembleSampler):
def __init__(self):
super().__init__(
potential_fn=lambda z: jnp.array(0.0),
randomize_split=False,
init_strategy=init_to_uniform,
)
self._num_chains = 4

def init_inner_state(self, rng_key):
return jnp.array(0)

def update_active_chains(self, active, inactive, inner_state):
# Encode which half was used as inactive in each sub-iteration.
return inactive + 1.0, inner_state

sampler = ToyEnsembleSampler()
state = EnsembleSamplerState(
# First sub-iteration uses second-half inactive chains [10, 11].
z=jnp.array([[0.0], [1.0], [10.0], [11.0]]),
inner_state=jnp.array(0),
rng_key=random.PRNGKey(0),
)

new_state = sampler.sample(state, model_args=(), model_kwargs={})
# Expected: first two chains get [11, 12] from second iteration using first half [0, 1] as inactive.
# Then last two chains get [12, 13] from first iteration using second half [10, 11] as inactive.
expected = jnp.array([[11.0], [12.0], [12.0], [13.0]])
assert jnp.allclose(new_state.z, expected)
Loading