1+ import cuqi
2+ import inspect
3+ import numpy as np
4+
5+ # This import makes suggest_sampler easier to read
6+ import cuqi .experimental .mcmc as samplers
7+
8+
9+ class SamplerRecommender (object ):
10+ """
11+ This class can be used to automatically choose a sampler.
12+
13+ Parameters
14+ ----------
15+ target: Density or JointDistribution
16+ Distribution to get sampler recommendations for.
17+
18+ exceptions: list[cuqi.experimental.mcmc.Sampler], *optional*
19+ Samplers not to be recommended.
20+ """
21+
22+ def __init__ (self , target :cuqi .density .Density , exceptions = []):
23+ self ._target = target
24+ self ._exceptions = exceptions
25+ self ._create_ordering ()
26+
27+ @property
28+ def target (self ) -> cuqi .density .Density :
29+ """ Return the target Distribution. """
30+ return self ._target
31+
32+ @target .setter
33+ def target (self , value :cuqi .density .Density ):
34+ """ Set the target Distribution. Runs validation of the target. """
35+ if value is None :
36+ raise ValueError ("Target needs to be of type cuqi.density.Density." )
37+ self ._target = value
38+
39+ def _create_ordering (self ):
40+ """
41+ Every element in the ordering consists of a tuple:
42+ (
43+ Sampler: Class
44+ boolean: additional conditions on the target
45+ parameters: additional parameters to be passed to the sampler once initialized
46+ )
47+ """
48+ number_of_components = np .sum (self ._target .dim )
49+
50+ self ._ordering = [
51+ # Direct and Conjugate samplers
52+ (samplers .Direct , True , {}),
53+ (samplers .Conjugate , True , {}),
54+ (samplers .ConjugateApprox , True , {}),
55+ # Specialized samplers
56+ (samplers .LinearRTO , True , {}),
57+ (samplers .RegularizedLinearRTO , True , {}),
58+ (samplers .UGLA , True , {}),
59+ # Gradient.based samplers (Hamiltonian and Langevin)
60+ (samplers .NUTS , True , {}),
61+ (samplers .MALA , True , {}),
62+ (samplers .ULA , True , {}),
63+ # Gibbs and Componentwise samplers
64+ (samplers .HybridGibbs , True , {"sampling_strategy" : self .recommend_HybridGibbs_sampling_strategy (as_string = False )}),
65+ (samplers .CWMH , number_of_components <= 100 , {"scale" : 0.05 * np .ones (number_of_components ),
66+ "initial_point" : 0.5 * np .ones (number_of_components )}),
67+ # Proposal based samplers
68+ (samplers .PCN , True , {"scale" : 0.02 }),
69+ (samplers .MH , number_of_components <= 1000 , {}),
70+ ]
71+
72+ @property
73+ def ordering (self ):
74+ """ Returns the ordered list of recommendation rules used by the recommender. """
75+ return self ._ordering
76+
77+ def valid_samplers (self , as_string = True ):
78+ """
79+ Finds all possible samplers that can be used for sampling from the target distribution.
80+
81+ Parameters
82+ ----------
83+
84+ as_string : boolean
85+ Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*
86+
87+ """
88+
89+ all_samplers = [(name , cls ) for name , cls in inspect .getmembers (cuqi .experimental .mcmc , inspect .isclass ) if issubclass (cls , cuqi .experimental .mcmc .Sampler )]
90+ valid_samplers = []
91+
92+ for name , sampler in all_samplers :
93+ try :
94+ sampler (self .target )
95+ valid_samplers += [name if as_string else sampler ]
96+ except :
97+ pass
98+
99+ # Need a separate case for HybridGibbs
100+ if self .valid_HybridGibbs_sampling_strategy () is not None :
101+ valid_samplers += [cuqi .experimental .mcmc .HybridGibbs .__name__ if as_string else cuqi .experimental .mcmc .HybridGibbs ]
102+
103+ return valid_samplers
104+
105+
106+ def valid_HybridGibbs_sampling_strategy (self , as_string = True ):
107+ """
108+ Find all possible sampling strategies to be used with the HybridGibbs sampler.
109+ Returns None if no sampler could be suggested for at least one conditional distribution.
110+
111+ Parameters
112+ ----------
113+
114+ as_string : boolean
115+ Whether to return the name of the samplers in the sampling strategy as a string instead of instantiating samplers. *Optional*
116+
117+
118+ """
119+
120+ if not isinstance (self .target , cuqi .distribution .JointDistribution ):
121+ return None
122+
123+ par_names = self .target .get_parameter_names ()
124+
125+ valid_samplers = dict ()
126+ for par_name in par_names :
127+ conditional_params = {par_name_ : np .ones (self .target .dim [i ]) for i , par_name_ in enumerate (par_names ) if par_name_ != par_name }
128+ conditional = self .target (** conditional_params )
129+
130+ recommender = SamplerRecommender (conditional )
131+ samplers = recommender .valid_samplers (as_string )
132+ if len (samplers ) == 0 :
133+ return None
134+
135+ valid_samplers [par_name ] = samplers
136+
137+ return valid_samplers
138+
139+
140+ def recommend (self , as_string = False ):
141+ """
142+ Suggests a possible sampler that can be used for sampling from the target distribution.
143+ Return None if no sampler could be suggested.
144+
145+ Parameters
146+ ----------
147+
148+ as_string : boolean
149+ Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*
150+
151+ """
152+
153+ valid_samplers = self .valid_samplers (as_string = False )
154+
155+ for suggestion , flag , values in self ._ordering :
156+ if flag and (suggestion in valid_samplers ) and (suggestion not in self ._exceptions ):
157+ # Sampler found
158+ if as_string :
159+ return suggestion .__name__
160+ else :
161+ return suggestion (self .target , ** values )
162+
163+ # No sampler can be suggested
164+ raise ValueError ("Cannot suggest any sampler. Either the provided distribution is incorrectly defined or there are too many exceptions provided." )
165+
166+ def recommend_HybridGibbs_sampling_strategy (self , as_string = False ):
167+ """
168+ Suggests a possible sampling strategy to be used with the HybridGibbs sampler.
169+ Returns None if no sampler could be suggested for at least one conditional distribution.
170+
171+ Parameters
172+ ----------
173+
174+ target : `cuqi.distribution.JointDistribution`
175+ The target distribution get a sampling strategy for.
176+
177+ as_string : boolean
178+ Whether to return the name of the samplers in the sampling strategy as a string instead of instantiating samplers. *Optional*
179+
180+ """
181+
182+ if not isinstance (self .target , cuqi .distribution .JointDistribution ):
183+ return None
184+
185+ par_names = self .target .get_parameter_names ()
186+
187+ suggested_samplers = dict ()
188+ for par_name in par_names :
189+ conditional_params = {par_name_ : np .ones (self .target .dim [i ]) for i , par_name_ in enumerate (par_names ) if par_name_ != par_name }
190+ conditional = self .target (** conditional_params )
191+
192+ recommender = SamplerRecommender (conditional , exceptions = self ._exceptions .copy ())
193+ sampler = recommender .recommend (as_string = as_string )
194+
195+ if sampler is None :
196+ return None
197+
198+ suggested_samplers [par_name ] = sampler
199+
200+ return suggested_samplers
0 commit comments