Skip to content

Commit 0ad94a5

Browse files
authored
Merge pull request #630 from jeverink/automated_sampler_suggestion
Robust automatic sampler selection/suggestion
2 parents b9a5dd3 + d5d6e28 commit 0ad94a5

File tree

7 files changed

+334
-76
lines changed

7 files changed

+334
-76
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ dist/
3535

3636
# Ignore certain files from demos
3737
*/CUQI_samples/*
38+
*.ipynb

cuqi/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from . import mcmc
33
from . import algebra
44
from . import geometry
5+
from ._recommender import SamplerRecommender

cuqi/experimental/_recommender.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import cuqi
2+
import inspect
3+
import numpy as np
4+
5+
# This alias keeps the sampler ordering table in SamplerRecommender easier to read
6+
import cuqi.experimental.mcmc as samplers
7+
8+
9+
class SamplerRecommender(object):
    """
    This class can be used to automatically choose a sampler.

    Candidate samplers are probed by trying to instantiate them on the target,
    and the recommendation is the first valid, non-excluded sampler in a fixed
    priority list (see :meth:`_create_ordering`).

    Parameters
    ----------
    target: Density or JointDistribution
        Distribution to get sampler recommendations for.

    exceptions: list[cuqi.experimental.mcmc.Sampler], *optional*
        Samplers not to be recommended. Defaults to no exceptions.

    Raises
    ------
    ValueError
        If `target` is None.
    """

    def __init__(self, target:cuqi.density.Density, exceptions = None):
        # Copy the exceptions so that neither a shared default nor the caller's
        # own list is aliased by this instance.
        self._exceptions = [] if exceptions is None else list(exceptions)
        # Assign through the property: it validates the target and builds the
        # ordering, which needs self._exceptions to be set already.
        self.target = target

    @property
    def target(self) -> cuqi.density.Density:
        """ Return the target Distribution. """
        return self._target

    @target.setter
    def target(self, value:cuqi.density.Density):
        """ Set the target Distribution. Runs validation of the target and rebuilds the ordering. """
        if value is None:
            raise ValueError("Target needs to be of type cuqi.density.Density.")
        self._target = value
        # The ordering stores the number of components and the HybridGibbs
        # sampling strategy of the target, so it must follow the new target.
        self._create_ordering()

    def _create_ordering(self):
        """
        Build the prioritized list of recommendation rules for the current target.

        Every element in the ordering consists of a tuple:
        (
            Sampler: Class
            boolean: additional conditions on the target
            parameters: additional parameters to be passed to the sampler once initialized
        )
        """
        number_of_components = np.sum(self._target.dim)

        self._ordering = [
            # Direct and Conjugate samplers
            (samplers.Direct, True, {}),
            (samplers.Conjugate, True, {}),
            (samplers.ConjugateApprox, True, {}),
            # Specialized samplers
            (samplers.LinearRTO, True, {}),
            (samplers.RegularizedLinearRTO, True, {}),
            (samplers.UGLA, True, {}),
            # Gradient-based samplers (Hamiltonian and Langevin)
            (samplers.NUTS, True, {}),
            (samplers.MALA, True, {}),
            (samplers.ULA, True, {}),
            # Gibbs and Componentwise samplers
            (samplers.HybridGibbs, True, {"sampling_strategy" : self.recommend_HybridGibbs_sampling_strategy(as_string = False)}),
            (samplers.CWMH, number_of_components <= 100, {"scale" : 0.05*np.ones(number_of_components),
                                                          "initial_point" : 0.5*np.ones(number_of_components)}),
            # Proposal based samplers
            (samplers.PCN, True, {"scale" : 0.02}),
            (samplers.MH, number_of_components <= 1000, {}),
        ]

    @property
    def ordering(self):
        """ Returns the ordered list of recommendation rules used by the recommender. """
        return self._ordering

    def _conditionals(self):
        """
        Lazily yield `(parameter name, conditional target)` pairs of a joint target.

        Each conditional is obtained by conditioning the target on a vector of
        ones for every parameter except the one named in the pair.
        """
        par_names = self.target.get_parameter_names()
        for par_name in par_names:
            conditional_params = {other_name: np.ones(self.target.dim[i])
                                  for i, other_name in enumerate(par_names)
                                  if other_name != par_name}
            yield par_name, self.target(**conditional_params)

    def valid_samplers(self, as_string = True):
        """
        Finds all possible samplers that can be used for sampling from the target distribution.

        A sampler counts as valid if it can be instantiated with the target as
        its only argument. HybridGibbs is handled separately and counts as valid
        if a valid sampling strategy exists.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the sampler as a string instead of the sampler class. *Optional*

        Returns
        -------
        list
            Names (or classes) of the valid samplers.
        """

        all_samplers = [(name, cls)
                        for name, cls in inspect.getmembers(cuqi.experimental.mcmc, inspect.isclass)
                        if issubclass(cls, cuqi.experimental.mcmc.Sampler)]
        valid_samplers = []

        for name, sampler in all_samplers:
            try:
                sampler(self.target)
            except Exception:
                # Best-effort probe: a sampler that cannot be constructed on
                # this target is simply not valid. Exception (rather than a bare
                # except) lets KeyboardInterrupt and SystemExit propagate.
                continue
            valid_samplers.append(name if as_string else sampler)

        # Need a separate case for HybridGibbs
        if self.valid_HybridGibbs_sampling_strategy() is not None:
            valid_samplers.append(cuqi.experimental.mcmc.HybridGibbs.__name__ if as_string else cuqi.experimental.mcmc.HybridGibbs)

        return valid_samplers

    def valid_HybridGibbs_sampling_strategy(self, as_string = True):
        """
        Find all possible sampling strategies to be used with the HybridGibbs sampler.
        Returns None if the target is not a JointDistribution or if no valid
        sampler exists for at least one conditional distribution.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the samplers in the sampling strategy as a string instead of the sampler classes. *Optional*

        Returns
        -------
        dict or None
            Maps each parameter name to the list of valid samplers for its conditional.
        """

        if not isinstance(self.target, cuqi.distribution.JointDistribution):
            return None

        valid_samplers = dict()
        for par_name, conditional in self._conditionals():
            candidates = SamplerRecommender(conditional).valid_samplers(as_string)
            if len(candidates) == 0:
                return None

            valid_samplers[par_name] = candidates

        return valid_samplers

    def recommend(self, as_string = False):
        """
        Suggests a possible sampler that can be used for sampling from the target distribution.

        The first entry of the ordering whose condition holds, that is valid for
        the target and that is not listed in the exceptions is returned.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*

        Raises
        ------
        ValueError
            If no sampler could be suggested.
        """

        valid_samplers = self.valid_samplers(as_string = False)

        for suggestion, flag, values in self._ordering:
            if flag and (suggestion in valid_samplers) and (suggestion not in self._exceptions):
                # Sampler found
                if as_string:
                    return suggestion.__name__
                return suggestion(self.target, **values)

        # No sampler can be suggested
        raise ValueError("Cannot suggest any sampler. Either the provided distribution is incorrectly defined or there are too many exceptions provided.")

    def recommend_HybridGibbs_sampling_strategy(self, as_string = False):
        """
        Suggests a possible sampling strategy to be used with the HybridGibbs sampler.
        Returns None if the target is not a JointDistribution.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the samplers in the sampling strategy as a string instead of instantiating samplers. *Optional*

        Returns
        -------
        dict or None
            Maps each parameter name to the sampler recommended for its conditional.

        Raises
        ------
        ValueError
            Propagated from :meth:`recommend` if no sampler can be suggested for
            one of the conditional distributions.
        """

        if not isinstance(self.target, cuqi.distribution.JointDistribution):
            return None

        suggested_samplers = dict()
        for par_name, conditional in self._conditionals():
            recommender = SamplerRecommender(conditional, exceptions = self._exceptions.copy())
            sampler = recommender.recommend(as_string = as_string)

            # recommend() raises rather than returning None; the guard is kept
            # in case an overriding implementation signals failure with None.
            if sampler is None:
                return None

            suggested_samplers[par_name] = sampler

        return suggested_samplers

cuqi/experimental/mcmc/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,3 @@
120120
from ._conjugate import Conjugate
121121
from ._conjugate_approx import ConjugateApprox
122122
from ._direct import Direct
123-
from ._utilities import find_valid_samplers

cuqi/experimental/mcmc/_utilities.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

tests/zexperimental/test_mcmc.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -830,64 +830,6 @@ def test_conjugate_wrong_equation_for_conjugate_parameter_supported_cases(target
830830
cuqi.experimental.mcmc.ConjugateApprox(target=posterior)
831831
else:
832832
cuqi.experimental.mcmc.Conjugate(target=posterior)
833-
def test_find_valid_samplers_linearGaussianGaussian():
834-
target = cuqi.testproblem.Deconvolution1D(dim=2).posterior
835-
836-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
837-
838-
assert(set(valid_samplers) == set(['CWMH', 'LinearRTO', 'MALA', 'MH', 'NUTS', 'PCN', 'ULA']))
839-
840-
def test_find_valid_samplers_nonlinearGaussianGaussian():
841-
posterior = cuqi.testproblem.Poisson1D(dim=2).posterior
842-
843-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(posterior)
844-
845-
print(set(valid_samplers) == set(['CWMH', 'MH', 'PCN']))
846-
847-
def test_find_valid_samplers_conjugate_valid():
848-
""" Test that conjugate sampler is valid for Gaussian-Gamma conjugate pair when parameter is defined as the precision."""
849-
x = cuqi.distribution.Gamma(1,1)
850-
y = cuqi.distribution.Gaussian(np.zeros(2), cov=lambda x : 1/x) # Valid on precision only, e.g. cov=lambda x : 1/x
851-
target = cuqi.distribution.JointDistribution(y, x)(y = 1)
852-
853-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
854-
855-
assert(set(valid_samplers) == set(['CWMH', 'Conjugate', 'MH']))
856-
857-
def test_find_valid_samplers_conjugate_invalid():
858-
""" Test that conjugate sampler is invalid for Gaussian-Gamma conjugate pair when parameter is defined as the covariance."""
859-
x = cuqi.distribution.Gamma(1,1)
860-
y = cuqi.distribution.Gaussian(np.zeros(2), cov=lambda x : x) # Invalid if defined via covariance as cov=lambda x : x
861-
target = cuqi.distribution.JointDistribution(y, x)(y = 1)
862-
863-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
864-
865-
assert(set(valid_samplers) == set(['CWMH', 'MH']))
866-
867-
def test_find_valid_samplers_direct():
868-
target = cuqi.distribution.Gamma(1,1)
869-
870-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
871-
872-
assert(set(valid_samplers) == set(['CWMH', 'Direct', 'MH']))
873-
874-
def test_find_valid_samplers_implicit_posterior():
875-
A, y_obs, _ = cuqi.testproblem.Deconvolution1D(dim=2).get_components()
876-
877-
x = cuqi.implicitprior.RegularizedGaussian(np.zeros(2), 1, constraint="nonnegativity")
878-
y = cuqi.distribution.Gaussian(A@x, 1)
879-
target = cuqi.distribution.JointDistribution(y, x)(y = y_obs)
880-
881-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
882-
883-
assert(set(valid_samplers) == set(['RegularizedLinearRTO']))
884-
885-
def test_find_valid_samplers_implicit_prior():
886-
target = cuqi.implicitprior.RegularizedGaussian(np.zeros(2), 1, constraint="nonnegativity")
887-
888-
valid_samplers = cuqi.experimental.mcmc.find_valid_samplers(target)
889-
890-
assert(len(set(valid_samplers)) == 0)
891833

892834
# ============ Testing of HybridGibbs ============
893835

0 commit comments

Comments
 (0)