@@ -17,6 +17,24 @@ class SamplerRecommender(object):
1717
1818 exceptions: list[cuqi.experimental.mcmc.Sampler], *optional*
1919 Samplers not to be recommended.
20+
21+ Example
22+ -------
23+ .. code-block:: python
24+ import numpy as np
25+ from cuqi.distribution import Gamma, Gaussian, JointDistribution
26+ from cuqi.experimental import SamplerRecommender
27+
28+ x = Gamma(1, 1)
29+ y = Gaussian(np.zeros(2), cov=lambda x: 1 / x)
30+ target = JointDistribution(y, x)(y=1)
31+
32+ recommender = SamplerRecommender(target)
33+ valid_samplers = recommender.valid_samplers()
34+ recommended_sampler = recommender.recommend()
35+ print("Valid samplers:", valid_samplers)
36+ print("Recommended sampler:\n ", recommended_sampler)
37+
2038 """
2139
2240 def __init__ (self , target :cuqi .density .Density , exceptions = []):
@@ -28,7 +46,7 @@ def __init__(self, target:cuqi.density.Density, exceptions = []):
2846 def target (self ) -> cuqi .density .Density :
2947 """ Return the target Distribution. """
3048 return self ._target
31-
49+
3250 @target .setter
3351 def target (self , value :cuqi .density .Density ):
3452 """ Set the target Distribution. Runs validation of the target. """
@@ -73,7 +91,7 @@ def _create_ordering(self):
7391 def ordering (self ):
7492 """ Returns the ordered list of recommendation rules used by the recommender. """
7593 return self ._ordering
76-
94+
7795 def valid_samplers (self , as_string = True ):
7896 """
7997 Finds all possible samplers that can be used for sampling from the target distribution.
@@ -101,7 +119,6 @@ def valid_samplers(self, as_string = True):
101119 valid_samplers += [cuqi .experimental .mcmc .HybridGibbs .__name__ if as_string else cuqi .experimental .mcmc .HybridGibbs ]
102120
103121 return valid_samplers
104-
105122
106123 def valid_HybridGibbs_sampling_strategy (self , as_string = True ):
107124 """
@@ -131,11 +148,10 @@ def valid_HybridGibbs_sampling_strategy(self, as_string = True):
131148 samplers = recommender .valid_samplers (as_string )
132149 if len (samplers ) == 0 :
133150 return None
134-
151+
135152 valid_samplers [par_name ] = samplers
136153
137154 return valid_samplers
138-
139155
140156 def recommend (self , as_string = False ):
141157 """
@@ -194,7 +210,7 @@ def recommend_HybridGibbs_sampling_strategy(self, as_string = False):
194210
195211 if sampler is None :
196212 return None
197-
213+
198214 suggested_samplers [par_name ] = sampler
199215
200- return suggested_samplers
216+ return suggested_samplers
0 commit comments