Skip to content

Commit d684612

Browse files
authored
Merge pull request #623 from CUQI-DTU/callback_update
Enable callback in Gibbs and update callback interface to pass more information to the callback method
2 parents b30b3a9 + 9390c33 commit d684612

File tree

9 files changed

+156
-67
lines changed

9 files changed

+156
-67
lines changed

cuqi/experimental/mcmc/_cwmh.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@ class CWMH(ProposalBasedSampler):
3131
initial_point : ndarray
3232
Initial parameters. *Optional*
3333
34-
callback : callable, *Optional*
35-
If set this function will be called after every sample.
36-
The signature of the callback function is
37-
`callback(sample, sample_index)`, where `sample` is the current sample
38-
and `sample_index` is the index of the sample.
39-
An example is shown in demos/demo31_callback.py.
34+
callback : callable, optional
35+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
36+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
4037
4138
kwargs : dict
4239
Additional keyword arguments to be passed to the base class

cuqi/experimental/mcmc/_gibbs.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class HybridGibbs:
5858
will call its step method in each Gibbs step.
5959
Default is 1 for all variables.
6060
61+
callback : callable, optional
62+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
63+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
64+
6165
Example
6266
-------
6367
.. code-block:: python
@@ -103,7 +107,7 @@ class HybridGibbs:
103107
104108
"""
105109

106-
def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None):
110+
def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, callback=None):
107111

108112
# Store target and allow conditioning to reduce to a single density
109113
self.target = target() # Create a copy of target distribution (to avoid modifying the original)
@@ -120,6 +124,9 @@ def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampl
120124
# Initialize sampler (after target is set)
121125
self._initialize()
122126

127+
# Set the callback function
128+
self.callback = callback
129+
123130
def _initialize(self):
124131
""" Initialize sampler """
125132

@@ -158,13 +165,15 @@ def sample(self, Ns) -> 'HybridGibbs':
158165
The number of samples to draw.
159166
160167
"""
161-
162-
for _ in tqdm(range(Ns), "Sample: "):
168+
for idx in tqdm(range(Ns), "Sample: "):
163169

164170
self.step()
165171

166172
self._store_samples()
167173

174+
# Call callback function if specified
175+
self._call_callback(idx, Ns)
176+
168177
return self
169178

170179
def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
@@ -192,6 +201,9 @@ def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
192201

193202
self._store_samples()
194203

204+
# Call callback function if specified
205+
self._call_callback(idx, Nb)
206+
195207
return self
196208

197209
def get_samples(self) -> Dict[str, Samples]:
@@ -263,6 +275,11 @@ def tune(self, skip_len, update_count):
263275
self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)
264276

265277
# ------------ Private methods ------------
278+
def _call_callback(self, sample_index, num_of_samples):
279+
""" Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
280+
if self.callback is not None:
281+
self.callback(self, sample_index, num_of_samples)
282+
266283
def _initialize_samplers(self):
267284
""" Initialize samplers """
268285
for sampler in self.samplers.values():

cuqi/experimental/mcmc/_hmc.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,9 @@ class NUTS(Sampler):
3838
opt_acc_rate should be in (0, 1), however, choosing a value that is very
3939
close to 1 or 0 might lead to poor performance of the sampler.
4040
41-
callback : callable, *Optional*
42-
If set this function will be called after every sample.
43-
The signature of the callback function is
44-
`callback(sample, sample_index)`,
45-
where `sample` is the current sample and `sample_index` is the index of
46-
the sample.
47-
An example is shown in demos/demo31_callback.py.
41+
callback : callable, optional
42+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
43+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
4844
4945
Example
5046
-------

cuqi/experimental/mcmc/_langevin_algorithm.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ class ULA(Sampler): # Refactor to Proposal-based sampler?
3232
be smaller than 1/L, where L is the Lipschitz of the gradient of the log
3333
target density, logd).
3434
35-
callback : callable, *Optional*
36-
If set this function will be called after every sample.
37-
The signature of the callback function is `callback(sample, sample_index)`,
38-
where `sample` is the current sample and `sample_index` is the index of the sample.
39-
An example is shown in demos/demo31_callback.py.
35+
callback : callable, optional
36+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
37+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
4038
4139
4240
Example
@@ -164,11 +162,9 @@ class MALA(ULA): # Refactor to Proposal-based sampler?
164162
be smaller than 1/L, where L is the Lipschitz of the gradient of the log
165163
target density, logd).
166164
167-
callback : callable, *Optional*
168-
If set this function will be called after every sample.
169-
The signature of the callback function is `callback(sample, sample_index)`,
170-
where `sample` is the current sample and `sample_index` is the index of the sample.
171-
An example is shown in demos/demo31_callback.py.
165+
callback : callable, optional
166+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
167+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
172168
173169
174170
Example
@@ -288,12 +284,9 @@ class MYULA(ULA):
288284
smoothing_strength : float
289285
This parameter controls the smoothing strength of MYULA.
290286
291-
callback : callable, *Optional*
292-
If set this function will be called after every sample.
293-
The signature of the callback function is `callback(sample, sample_index)`,
294-
where `sample` is the current sample and `sample_index` is the index of
295-
the sample.
296-
An example is shown in demos/demo31_callback.py.
287+
callback : callable, optional
288+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
289+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
297290
298291
A Deblur example can be found in demos/howtos/myula.py
299292
# TODO: update demo once sampler merged
@@ -378,12 +371,9 @@ class PnPULA(MYULA):
378371
This parameter controls the smoothing strength of PnP-ULA.
379372
380373
381-
callback : callable, *Optional*
382-
If set this function will be called after every sample.
383-
The signature of the callback function is `callback(sample, sample_index)`,
384-
where `sample` is the current sample and `sample_index` is the index of
385-
the sample.
386-
An example is shown in demos/demo31_callback.py.
374+
callback : callable, optional
375+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
376+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
387377
388378
# TODO: update demo once sampler merged
389379
"""

cuqi/experimental/mcmc/_laplace_approximation.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,9 @@ class UGLA(Sampler):
4343
sampling easier but results in a worse approximation. See details in Section 3.3 of the paper.
4444
If not provided, it defaults to 1e-5.
4545
46-
callback : callable, *Optional*
47-
If set, this function will be called after every sample.
48-
The signature of the callback function is `callback(sample, sample_index)`,
49-
where `sample` is the current sample and `sample_index` is the index of the sample.
50-
An example is shown in demos/demo31_callback.py.
46+
callback : callable, optional
47+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
48+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
5149
"""
5250
def __init__(self, target=None, initial_point=None, maxit=50, tol=1e-4, beta=1e-5, **kwargs):
5351

cuqi/experimental/mcmc/_rto.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ class LinearRTO(Sampler):
3636
tol : float
3737
Tolerance of the inner CGLS solver. *Optional*.
3838
39-
callback : callable, *Optional*
40-
If set this function will be called after every sample.
41-
The signature of the callback function is `callback(sample, sample_index)`,
42-
where `sample` is the current sample and `sample_index` is the index of the sample.
43-
An example is shown in demos/demo31_callback.py.
39+
callback : callable, optional
40+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
41+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
4442
4543
"""
4644
def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, **kwargs):
@@ -204,11 +202,9 @@ class RegularizedLinearRTO(LinearRTO):
204202
solver : string
205203
If set to "ScipyLinearLSQ", solver is set to cuqi.solver.ScipyLinearLSQ, otherwise FISTA/ISTA or ADMM is used. Note "ScipyLinearLSQ" can only be used with `RegularizedGaussian` of `box` or `nonnegativity` constraint. *Optional*.
206204
207-
callback : callable, *Optional*
208-
If set this function will be called after every sample.
209-
The signature of the callback function is `callback(sample, sample_index)`,
210-
where `sample` is the current sample and `sample_index` is the index of the sample.
211-
An example is shown in demos/demo31_callback.py.
205+
callback : callable, optional
206+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
207+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
212208
213209
"""
214210
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):

cuqi/experimental/mcmc/_sampler.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,8 @@ def __init__(self, target:cuqi.density.Density=None, initial_point=None, callbac
5959
The initial point for the sampler. If not given, the sampler will choose an initial point.
6060
6161
callback : callable, optional
62-
A function that will be called after each sample is drawn. The function should take two arguments: the sample and the index of the sample.
63-
The sample is a 1D numpy array and the index is an integer. The callback function is useful for monitoring the sampler during sampling.
64-
62+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
63+
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
6564
"""
6665

6766
self.target = target
@@ -209,7 +208,6 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
209208
The path to save the samples. If not specified, the samples are saved to the current working directory under a folder called 'CUQI_samples'.
210209
211210
"""
212-
213211
self._ensure_initialized()
214212

215213
# Initialize batch handler
@@ -235,7 +233,7 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
235233
batch_handler.add_sample(self.current_point)
236234

237235
# Call callback function if specified
238-
self._call_callback(self.current_point, len(self._samples)-1)
236+
self._call_callback(idx, Ns)
239237

240238
return self
241239

@@ -276,7 +274,7 @@ def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
276274
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
277275

278276
# Call callback function if specified
279-
self._call_callback(self.current_point, len(self._samples)-1)
277+
self._call_callback(idx, Nb)
280278

281279
return self
282280

@@ -367,10 +365,10 @@ def set_history(self, history: dict):
367365
raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
368366

369367
# ------------ Private methods ------------
370-
def _call_callback(self, sample, sample_index):
371-
""" Calls the callback function. Assumes input is sample and sample index"""
368+
def _call_callback(self, sample_index, num_of_samples):
369+
""" Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
372370
if self.callback is not None:
373-
self.callback(sample, sample_index)
371+
self.callback(self, sample_index, num_of_samples)
374372

375373
def _validate_initialization(self):
376374
""" Validate the initialization of the sampler by checking all state and history keys are set. """

cuqi/problem/_problem.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def sample_posterior(self, Ns, Nb=None, callback=None, experimental=False) -> cu
304304
The signature of the callback function is `callback(sample, sample_index)`,
305305
where `sample` is the current sample and `sample_index` is the index of the sample.
306306
An example is shown in demos/demo31_callback.py.
307+
Note: if the parameter `experimental` is set to True, the callback function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature in the case is: `callback(sampler, sample_index, num_of_samples)`.
307308
308309
experimental : bool, *Optional*
309310
If set to True, the sampler selection will use the samplers from the :mod:`cuqi.experimental.mcmc` module.
@@ -848,16 +849,14 @@ def _sampleGibbs(self, Ns, Nb, callback=None, experimental=False):
848849
print(f"burn-in: {Nb/Ns*100:g}%")
849850
print("")
850851

851-
if callback is not None:
852-
raise NotImplementedError("Callback not implemented for Gibbs sampler")
853-
854852
# Start timing
855853
ti = time.time()
856854

857855
# Sampling strategy
858856
sampling_strategy = self._determine_sampling_strategy(experimental=True)
859857

860-
sampler = cuqi.experimental.mcmc.HybridGibbs(self._target, sampling_strategy)
858+
sampler = cuqi.experimental.mcmc.HybridGibbs(
859+
self._target, sampling_strategy, callback=callback)
861860
sampler.warmup(Nb)
862861
sampler.sample(Ns)
863862
samples = sampler.get_samples()
@@ -876,7 +875,7 @@ def _sampleGibbs(self, Ns, Nb, callback=None, experimental=False):
876875
print("")
877876

878877
if callback is not None:
879-
raise NotImplementedError("Callback not implemented for Gibbs sampler")
878+
raise NotImplementedError("Callback not implemented for Gibbs sampler. It is only implemented for experimental Gibbs sampler.")
880879

881880
# Start timing
882881
ti = time.time()

0 commit comments

Comments
 (0)