From 30e087e2eb1d9613b632e1e383289ddc1b87c96c Mon Sep 17 00:00:00 2001 From: Jack Kohm Date: Sat, 16 May 2026 17:18:38 -0400 Subject: [PATCH] Fix EnsembleSampler second split to use complementary inactive chains The second sub-iteration in EnsembleSampler.sample incorrectly updated the second half of chains against itself instead of the first half. Add a deterministic regression test that fails under the previous split logic and verifies complementary-half updates. --- numpyro/infer/ensemble.py | 2 +- test/infer/test_ensemble_mcmc.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 0f6ceb329..88fd7042a 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -203,7 +203,7 @@ def body_fn(i, z_flat_inner_state): active, inactive = jax.lax.cond( i == 0, lambda x: (x[:split_ind], x[split_ind:]), - lambda x: (x[split_ind:], x[split_ind:]), + lambda x: (x[split_ind:], x[:split_ind]), z_flat, ) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 033095c59..cfe7b299a 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -10,6 +10,8 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import AIES, ESS, MCMC +from numpyro.infer.ensemble import EnsembleSampler, EnsembleSamplerState +from numpyro.infer.initialization import init_to_uniform numpyro.set_host_device_count(2) # --- @@ -119,3 +121,35 @@ def test_warmup(kernel_cls): labels = labels_maker() mcmc.warmup(random.key(2), labels) mcmc.run(random.key(3), labels) + + +def test_ensemble_sampler_uses_complementary_halves(): + class ToyEnsembleSampler(EnsembleSampler): + def __init__(self): + super().__init__( + potential_fn=lambda z: jnp.array(0.0), + randomize_split=False, + init_strategy=init_to_uniform, + ) + self._num_chains = 4 + + def init_inner_state(self, rng_key): + return jnp.array(0) + + def update_active_chains(self, active, inactive, inner_state): + # Encode which half was used as inactive in each sub-iteration. + return inactive + 1.0, inner_state + + sampler = ToyEnsembleSampler() + state = EnsembleSamplerState( + # First sub-iteration uses second-half inactive chains [10, 11]. + z=jnp.array([[0.0], [1.0], [10.0], [11.0]]), + inner_state=jnp.array(0), + rng_key=random.PRNGKey(0), + ) + + new_state = sampler.sample(state, model_args=(), model_kwargs={}) + # Expected: first two chains get [11, 12] from second iteration using first half [0, 1] as inactive. + # Then last two chains get [12, 13] from first iteration using second half [10, 11] as inactive. + expected = jnp.array([[11.0], [12.0], [12.0], [13.0]]) + assert jnp.allclose(new_state.z, expected)