Skip to content

Commit 30e087e

Browse files
jackkohmamifalk
authored andcommitted
Fix EnsembleSampler second split to use complementary inactive chains
The second sub-iteration in EnsembleSampler.sample incorrectly updated the second half of chains against itself instead of the first half. Add a deterministic regression test that fails under the previous split logic and verifies complementary-half updates.
1 parent 151ed23 commit 30e087e

2 files changed

Lines changed: 35 additions & 1 deletion

File tree

numpyro/infer/ensemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def body_fn(i, z_flat_inner_state):
203203
active, inactive = jax.lax.cond(
204204
i == 0,
205205
lambda x: (x[:split_ind], x[split_ind:]),
206-
lambda x: (x[split_ind:], x[split_ind:]),
206+
lambda x: (x[split_ind:], x[:split_ind]),
207207
z_flat,
208208
)
209209

test/infer/test_ensemble_mcmc.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpyro
1111
import numpyro.distributions as dist
1212
from numpyro.infer import AIES, ESS, MCMC
13+
from numpyro.infer.ensemble import EnsembleSampler, EnsembleSamplerState
14+
from numpyro.infer.initialization import init_to_uniform
1315

1416
numpyro.set_host_device_count(2)
1517
# ---
@@ -119,3 +121,35 @@ def test_warmup(kernel_cls):
119121
labels = labels_maker()
120122
mcmc.warmup(random.key(2), labels)
121123
mcmc.run(random.key(3), labels)
124+
125+
126+
def test_ensemble_sampler_uses_complementary_halves():
127+
class ToyEnsembleSampler(EnsembleSampler):
128+
def __init__(self):
129+
super().__init__(
130+
potential_fn=lambda z: jnp.array(0.0),
131+
randomize_split=False,
132+
init_strategy=init_to_uniform,
133+
)
134+
self._num_chains = 4
135+
136+
def init_inner_state(self, rng_key):
137+
return jnp.array(0)
138+
139+
def update_active_chains(self, active, inactive, inner_state):
140+
# Encode which half was used as inactive in each sub-iteration.
141+
return inactive + 1.0, inner_state
142+
143+
sampler = ToyEnsembleSampler()
144+
state = EnsembleSamplerState(
145+
# First sub-iteration uses second-half inactive chains [10, 11].
146+
z=jnp.array([[0.0], [1.0], [10.0], [11.0]]),
147+
inner_state=jnp.array(0),
148+
rng_key=random.PRNGKey(0),
149+
)
150+
151+
new_state = sampler.sample(state, model_args=(), model_kwargs={})
152+
# Expected: first two chains get [11, 12] from second iteration using first half [0, 1] as inactive.
153+
# Then last two chains get [12, 13] from first iteration using second half [10, 11] as inactive.
154+
expected = jnp.array([[11.0], [12.0], [12.0], [13.0]])
155+
assert jnp.allclose(new_state.z, expected)

0 commit comments

Comments
 (0)