Skip to content

Commit 1312577

Browse files
authored
Merge pull request #668 from CUQI-DTU/add_recommender_example
Add example for using recommender class
2 parents 7536b0d + 445df1c commit 1312577

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

cuqi/experimental/_recommender.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ class SamplerRecommender(object):
1717
1818
exceptions: list[cuqi.experimental.mcmc.Sampler], *optional*
1919
Samplers not to be recommended.
20+
21+
Example
22+
-------
23+
.. code-block:: python
24+
import numpy as np
25+
from cuqi.distribution import Gamma, Gaussian, JointDistribution
26+
from cuqi.experimental import SamplerRecommender
27+
28+
x = Gamma(1, 1)
29+
y = Gaussian(np.zeros(2), cov=lambda x: 1 / x)
30+
target = JointDistribution(y, x)(y=1)
31+
32+
recommender = SamplerRecommender(target)
33+
valid_samplers = recommender.valid_samplers()
34+
recommended_sampler = recommender.recommend()
35+
print("Valid samplers:", valid_samplers)
36+
print("Recommended sampler:\n", recommended_sampler)
37+
2038
"""
2139

2240
def __init__(self, target:cuqi.density.Density, exceptions = []):
@@ -28,7 +46,7 @@ def __init__(self, target:cuqi.density.Density, exceptions = []):
2846
def target(self) -> cuqi.density.Density:
2947
""" Return the target Distribution. """
3048
return self._target
31-
49+
3250
@target.setter
3351
def target(self, value:cuqi.density.Density):
3452
""" Set the target Distribution. Runs validation of the target. """
@@ -73,7 +91,7 @@ def _create_ordering(self):
7391
def ordering(self):
7492
""" Returns the ordered list of recommendation rules used by the recommender. """
7593
return self._ordering
76-
94+
7795
def valid_samplers(self, as_string = True):
7896
"""
7997
Finds all possible samplers that can be used for sampling from the target distribution.
@@ -101,7 +119,6 @@ def valid_samplers(self, as_string = True):
101119
valid_samplers += [cuqi.experimental.mcmc.HybridGibbs.__name__ if as_string else cuqi.experimental.mcmc.HybridGibbs]
102120

103121
return valid_samplers
104-
105122

106123
def valid_HybridGibbs_sampling_strategy(self, as_string = True):
107124
"""
@@ -131,11 +148,10 @@ def valid_HybridGibbs_sampling_strategy(self, as_string = True):
131148
samplers = recommender.valid_samplers(as_string)
132149
if len(samplers) == 0:
133150
return None
134-
151+
135152
valid_samplers[par_name] = samplers
136153

137154
return valid_samplers
138-
139155

140156
def recommend(self, as_string = False):
141157
"""
@@ -194,7 +210,7 @@ def recommend_HybridGibbs_sampling_strategy(self, as_string = False):
194210

195211
if sampler is None:
196212
return None
197-
213+
198214
suggested_samplers[par_name] = sampler
199215

200-
return suggested_samplers
216+
return suggested_samplers

0 commit comments

Comments
 (0)