From 1946c03606f480588c7ba16fb035fe30f917d745 Mon Sep 17 00:00:00 2001 From: Jack Kohm Date: Mon, 9 Feb 2026 21:17:49 -0700 Subject: [PATCH 1/2] Fix EnsembleSampler second split to use complementary inactive chains. The second sub-iteration in EnsembleSampler.sample incorrectly updated the second half of chains against itself instead of the first half. Add a deterministic regression test that fails under the previous split logic and verifies complementary-half updates. --- numpyro/infer/ensemble.py | 2 +- test/infer/test_ensemble_mcmc.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index d12689c84..3e630e021 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -203,7 +203,7 @@ def body_fn(i, z_flat_inner_state): active, inactive = jax.lax.cond( i == 0, lambda x: (x[:split_ind], x[split_ind:]), - lambda x: (x[split_ind:], x[split_ind:]), + lambda x: (x[split_ind:], x[:split_ind]), z_flat, ) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 1b9074408..6a8af0655 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -10,6 +10,8 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import AIES, ESS, MCMC +from numpyro.infer.ensemble import EnsembleSampler, EnsembleSamplerState +from numpyro.infer.initialization import init_to_uniform numpyro.set_host_device_count(2) # --- @@ -119,3 +121,32 @@ def test_warmup(kernel_cls): labels = labels_maker() mcmc.warmup(random.PRNGKey(2), labels) mcmc.run(random.PRNGKey(3), labels) + + +def test_ensemble_sampler_uses_complementary_halves(): + class ToyEnsembleSampler(EnsembleSampler): + def __init__(self): + super().__init__( + potential_fn=lambda z: jnp.array(0.0), + randomize_split=False, + init_strategy=init_to_uniform, + ) + self._num_chains = 4 + + def init_inner_state(self, rng_key): + return jnp.array(0) + + def update_active_chains(self, active, inactive, inner_state): + # Encode which half was used as inactive in each sub-iteration. + return inactive + 1.0, inner_state + + sampler = ToyEnsembleSampler() + state = EnsembleSamplerState( + z=jnp.array([[0.0], [1.0], [10.0], [11.0]]), + inner_state=jnp.array(0), + rng_key=random.PRNGKey(0), + ) + + new_state = sampler.sample(state, model_args=(), model_kwargs={}) + expected = jnp.array([[11.0], [12.0], [12.0], [13.0]]) + assert jnp.allclose(new_state.z, expected) From 39346fec5edf1c3d4eb9d7df41d91a38f6681e8e Mon Sep 17 00:00:00 2001 From: Jack Kohm Date: Tue, 10 Feb 2026 17:19:26 -0700 Subject: [PATCH 2/2] Clarify complementary-split regression test fixture --- test/infer/test_ensemble_mcmc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 6a8af0655..24d34a5d8 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -142,7 +142,8 @@ def update_active_chains(self, active, inactive, inner_state): sampler = ToyEnsembleSampler() state = EnsembleSamplerState( - z=jnp.array([[0.0], [1.0], [10.0], [11.0]]), + # First sub-iteration uses second-half inactive chains [10, 11]. + z=jnp.array([[0.0], [0.0], [10.0], [11.0]]), inner_state=jnp.array(0), rng_key=random.PRNGKey(0), )