Skip to content

Commit 52c8f8f

Browse files
authored
Merge pull request #660 from jeverink/Fix-HybridGibbs-ordering
Fix hybrid gibbs ordering
2 parents 0ad94a5 + 1fc97e7 commit 52c8f8f

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

cuqi/experimental/mcmc/_gibbs.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class HybridGibbs:
4242
fully stateful at this point. This means samplers like NUTS will lose
4343
their internal state between Gibbs steps.
4444
45+
The order in which the conditionals are sampled is the order of the
46+
variables in the sampling strategy, unless a different sampling order
47+
is specified by the parameter `scan_order`
48+
4549
Parameters
4650
----------
4751
target : cuqi.distribution.JointDistribution
@@ -58,6 +62,11 @@ class HybridGibbs:
5862
will call its step method in each Gibbs step.
5963
Default is 1 for all variables.
6064
65+
scan_order : list or str, *optional*
66+
Order in which the conditional distributions are sampled.
67+
If set to "random", use a random ordering at each step.
68+
If not specified, it will be the order in the sampling_strategy.
69+
6170
callback : callable, optional
6271
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
6372
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
@@ -107,7 +116,7 @@ class HybridGibbs:
107116
108117
"""
109118

110-
def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, callback=None):
119+
def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):
111120

112121
# Store target and allow conditioning to reduce to a single density
113122
self.target = target() # Create a copy of target distribution (to avoid modifying the original)
@@ -121,6 +130,13 @@ def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampl
121130
# Store parameter names
122131
self.par_names = self.target.get_parameter_names()
123132

133+
# Store the scan order
134+
self._scan_order = scan_order
135+
136+
# Check that the parameters of the target align with the sampling_strategy and scan_order
137+
if set(self.par_names) != set(self.scan_order):
138+
raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")
139+
124140
# Initialize sampler (after target is set)
125141
self._initialize()
126142

@@ -148,6 +164,16 @@ def _initialize(self):
148164
# Validate all targets for samplers.
149165
self.validate_targets()
150166

167+
@property
168+
def scan_order(self):
169+
if self._scan_order is None:
170+
return list(self.samplers.keys())
171+
if self._scan_order == "random":
172+
arr = list(self.samplers.keys())
173+
np.random.shuffle(arr) # Shuffle works in-place
174+
return arr
175+
return self._scan_order
176+
151177
# ------------ Public methods ------------
152178
def validate_targets(self):
153179
""" Validate each of the conditional targets used in the Gibbs steps """
@@ -217,7 +243,7 @@ def step(self):
217243
""" Sequentially go through all parameters and sample them conditionally on each other """
218244

219245
# Sample from each conditional distribution
220-
for par_name in self.par_names:
246+
for par_name in self.scan_order:
221247

222248
# Set target for current parameter
223249
self._set_target(par_name)

tests/test_implicit_priors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_regression_increasing():
210210
posterior = joint(y=y_obs)
211211

212212
sampling_strategy = {
213+
'd': cuqi.experimental.mcmc.Conjugate(),
213214
'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False),
214-
'd': cuqi.experimental.mcmc.Conjugate()
215215
}
216216
sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)
217217

@@ -241,8 +241,8 @@ def test_regression_convex():
241241
posterior = joint(y=y_obs)
242242

243243
sampling_strategy = {
244-
'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False),
245-
'd': cuqi.experimental.mcmc.Conjugate()
244+
'd': cuqi.experimental.mcmc.Conjugate(),
245+
'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False)
246246
}
247247
sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)
248248

tests/zexperimental/test_mcmc.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1574,4 +1574,26 @@ def test_all_samplers_that_should_be_tested_for_callback_are_in_the_tested_list(
15741574
for cls in callback_testing_sampler_classes:
15751575
assert cls in tested_classes, f"Sampler {cls} is not tested for callback."
15761576

1577-
# ============= End testing sampler callback =============
1577+
# ============= End testing sampler callback =============
1578+
def test_gibbs_random_scan_order():
1579+
target = HybridGibbs_target_1()
1580+
sampling_strategy={
1581+
"x": cuqi.experimental.mcmc.LinearRTO(),
1582+
"s": cuqi.experimental.mcmc.Conjugate(),
1583+
}
1584+
1585+
sampler = cuqi.experimental.mcmc.HybridGibbs(target, sampling_strategy, scan_order='random')
1586+
np.random.seed(0)
1587+
scan_order1 = sampler.scan_order
1588+
scan_order2 = sampler.scan_order
1589+
assert scan_order1 != scan_order2
1590+
1591+
def test_gibbs_scan_order():
1592+
target = HybridGibbs_target_1()
1593+
sampling_strategy={
1594+
"x": cuqi.experimental.mcmc.LinearRTO(),
1595+
"s": cuqi.experimental.mcmc.Conjugate(),
1596+
}
1597+
1598+
sampler = cuqi.experimental.mcmc.HybridGibbs(target, sampling_strategy, scan_order=['x', 's'])
1599+
assert sampler.scan_order == ['x', 's']

0 commit comments

Comments
 (0)