Skip to content

Commit 1bd1659

Browse files
authored
Merge pull request #679 from CUQI-DTU/rto_initial_guess
Allow other options for initial guess in (Regularized)LinearRTO steps
2 parents f6a73b0 + a27c6da commit 1bd1659

File tree

3 files changed

+89
-9
lines changed

3 files changed

+89
-9
lines changed

cuqi/experimental/mcmc/_rto.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,48 @@ class LinearRTO(Sampler):
3636
tol : float
3737
Tolerance of the inner CGLS solver. *Optional*.
3838
39+
inner_initial_point : string or np.ndarray or cuqi.array.CUQIarray
40+
Initial point for the inner optimization problem. Can be "previous_sample" (default), "MAP", or a specific numpy or cuqi array. *Optional*.
41+
3942
callback : callable, optional
4043
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
4144
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
4245
4346
"""
44-
def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, **kwargs):
47+
def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, inner_initial_point="previous_sample", **kwargs):
4548

4649
super().__init__(target=target, initial_point=initial_point, **kwargs)
4750

4851
# Other parameters
4952
self.maxit = maxit
5053
self.tol = tol
54+
self.inner_initial_point = inner_initial_point
5155

5256
def _initialize(self):
5357
self._precompute()
58+
self._compute_map()
59+
60+
@property
61+
def inner_initial_point(self):
62+
if isinstance(self._inner_initial_point, str):
63+
if self._inner_initial_point == "previous_sample":
64+
return self.current_point
65+
elif self._inner_initial_point == "map":
66+
return self._map
67+
else:
68+
return self._inner_initial_point
69+
70+
@inner_initial_point.setter
71+
def inner_initial_point(self, value):
72+
is_correct_string = (isinstance(value, str) and
73+
(value.lower() == "previous_sample" or
74+
value.lower() == "map"))
75+
if is_correct_string:
76+
self._inner_initial_point = value.lower()
77+
elif isinstance(value, (np.ndarray, cuqi.array.CUQIarray)):
78+
self._inner_initial_point = value
79+
else:
80+
raise ValueError("Invalid value for inner_initial_point. Choose either 'previous_sample', 'MAP', or provide a numpy array/cuqi array.")
5481

5582
@property
5683
def prior(self):
@@ -78,6 +105,10 @@ def models(self):
78105
elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
79106
return self.target.models
80107

108+
def _compute_map(self):
109+
sim = CGLS(self.M, self.b_tild, self.current_point, self.maxit, self.tol)
110+
self._map, _ = sim.solve()
111+
81112
def _precompute(self):
82113
L1 = [likelihood.distribution.sqrtprec for likelihood in self.likelihoods]
83114
L2 = self.prior.sqrtprec
@@ -114,7 +145,7 @@ def M(x, flag):
114145

115146
def step(self):
116147
y = self.b_tild + np.random.randn(len(self.b_tild))
117-
sim = CGLS(self.M, y, self.current_point, self.maxit, self.tol)
148+
sim = CGLS(self.M, y, self.inner_initial_point, self.maxit, self.tol)
118149
self.current_point, _ = sim.solve()
119150
acc = 1
120151
return acc
@@ -203,12 +234,15 @@ class RegularizedLinearRTO(LinearRTO):
203234
solver : string
204235
Options are "FISTA" (default for a single constraint or regularization), "ADMM" (default and the only option for multiple constraints or regularizations), "ScipyLinearLSQ" and "ScipyMinimizer". Note "ScipyLinearLSQ" and "ScipyMinimizer" can only be used with `RegularizedGaussian` of a single `box` or `nonnegativity` constraint. *Optional*.
205236
237+
inner_initial_point : string or np.ndarray or cuqi.array.CUQIarray
238+
Initial point for the inner optimization problem. Can be "previous_sample" (default), "MAP", or a specific numpy or cuqi array. *Optional*.
239+
206240
callback : callable, optional
207241
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
208242
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
209243
210244
"""
211-
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):
245+
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, inner_initial_point="previous_sample", **kwargs):
212246

213247
super().__init__(target=target, initial_point=initial_point, **kwargs)
214248

@@ -221,13 +255,15 @@ def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10,
221255
self.inner_max_it = inner_max_it
222256
self.penalty_parameter = penalty_parameter
223257
self.solver = solver
258+
self.inner_initial_point = inner_initial_point
224259

225260
def _initialize(self):
226261
super()._initialize()
227262
if self.solver is None:
228263
self.solver = "FISTA" if callable(self.proximal) else "ADMM"
229264
if self.solver == "FISTA":
230265
self._stepsize = self._choose_stepsize()
266+
self._compute_map_regularized()
231267

232268
@property
233269
def solver(self):
@@ -272,15 +308,16 @@ def _choose_stepsize(self):
272308
def prior(self):
273309
return self.target.prior.gaussian
274310

275-
def step(self):
276-
y = self.b_tild + np.random.randn(len(self.b_tild))
311+
def _compute_map_regularized(self):
312+
self._map = self._customized_step(self.b_tild, self.initial_point)
277313

314+
def _customized_step(self, y, x0):
278315
if self.solver == "FISTA":
279316
sim = FISTA(self.M, y, self.proximal,
280-
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
317+
x0, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
281318
elif self.solver == "ADMM":
282319
sim = ADMM(self.M, y, self.proximal,
283-
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
320+
x0, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
284321
elif self.solver == "ScipyLinearLSQ":
285322
A_op = sp.sparse.linalg.LinearOperator((sum([llh.distribution.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
286323
matvec=lambda x: self.M(x, 1),
@@ -297,10 +334,17 @@ def step(self):
297334
bounds = [(self.target.prior._box_bounds[0][i], self.target.prior._box_bounds[1][i]) for i in range(self.target.prior.dim)]
298335
# Note that the objective function is defined as 0.5*||Mx-y||^2,
299336
# and the corresponding gradient (gradfunc) is given by M^T(Mx-y).
300-
sim = ScipyMinimizer(lambda x: 0.5*np.sum((self.M(x, 1)-y)**2), self.current_point, gradfunc=lambda x: self.M(self.M(x, 1) - y, 2), bounds=bounds, tol=self.abstol, options={"maxiter": self.maxit})
337+
sim = ScipyMinimizer(lambda x: 0.5*np.sum((self.M(x, 1)-y)**2), x0, gradfunc=lambda x: self.M(self.M(x, 1) - y, 2), bounds=bounds, tol=self.abstol, options={"maxiter": self.maxit})
301338
else:
302339
raise ValueError("Choice of solver not supported.")
340+
341+
sol, _ = sim.solve()
342+
return sol
343+
344+
def step(self):
345+
y = self.b_tild + np.random.randn(len(self.b_tild))
346+
347+
self.current_point = self._customized_step(y, self.inner_initial_point)
303348

304-
self.current_point, _ = sim.solve()
305349
acc = 1
306350
return acc

cuqi/experimental/mcmc/_sampler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ def target(self, value):
148148
if self._target is not None:
149149
self.validate_target()
150150

151+
@property
152+
def current_point(self):
153+
""" The current point of the sampler. """
154+
return self._current_point
155+
156+
@current_point.setter
157+
def current_point(self, value):
158+
""" Set the current point of the sampler. """
159+
self._current_point = value
160+
151161
# ------------ Public methods ------------
152162
def get_samples(self) -> Samples:
153163
""" Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """

tests/zexperimental/test_mcmc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,6 +1453,32 @@ def test_RegularizedLinearRTO_ScipyLinearLSQ_option_invalid():
14531453
with pytest.raises(ValueError, match="ScipyLinearLSQ"):
14541454
sampler = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, solver = "ScipyLinearLSQ")
14551455

1456+
def test_RegularizedLinearRTO_inner_initial_point_setting():
1457+
# Define LinearModel and data
1458+
A, y_obs, _ = cuqi.testproblem.Deconvolution1D().get_components()
1459+
1460+
# Define Bayesian Problem
1461+
x = cuqi.implicitprior.NonnegativeGMRF(np.zeros(A.domain_dim), 100)
1462+
y = cuqi.distribution.Gaussian(A@x, 0.01**2)
1463+
posterior = cuqi.distribution.JointDistribution(x, y)(y=y_obs)
1464+
1465+
# Set up RegularizedLinearRTO with three different inner_initial_point options
1466+
sampler1 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, maxit=10, inner_initial_point="previous_sample", tol=1e-8)
1467+
sampler2 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, maxit=10, inner_initial_point="MAP", tol=1e-8)
1468+
sampler3 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, maxit=10, inner_initial_point=np.ones(A.domain_dim), tol=1e-8)
1469+
1470+
# Sample with fixed seed
1471+
np.random.seed(0)
1472+
sampler1.sample(5)
1473+
np.random.seed(0)
1474+
sampler2.sample(5)
1475+
np.random.seed(0)
1476+
sampler3.sample(5)
1477+
1478+
assert np.allclose(sampler1.inner_initial_point, sampler1.current_point, rtol=1e-5)
1479+
assert np.allclose(sampler2.inner_initial_point, sampler2._map, rtol=1e-5)
1480+
assert np.allclose(sampler3.inner_initial_point, np.ones(A.domain_dim), rtol=1e-5)
1481+
14561482
def test_RegularizedLinearRTO_ScipyLinearLSQ_against_ScipyMinimizer_and_against_FISTA():
14571483
# Define LinearModel and data
14581484
A, y_obs, _ = cuqi.testproblem.Deconvolution1D().get_components()

0 commit comments

Comments
 (0)