Skip to content

Commit c62aed4

Browse files
authored
Merge pull request #644 from CUQI-DTU/debug_rto_solver
Add option to use ScipyMinimizer in RegularizedLinearRTO and fix bug
2 parents fe056d1 + 0ecc36d commit c62aed4

File tree

4 files changed

+66
-21
lines changed

4 files changed

+66
-21
lines changed

cuqi/experimental/mcmc/_rto.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
44
import numpy as np
55
import cuqi
6-
from cuqi.solver import CGLS, FISTA, ADMM, ScipyLinearLSQ
6+
from cuqi.solver import CGLS, FISTA, ADMM, ScipyLinearLSQ, ScipyMinimizer
77
from cuqi.experimental.mcmc import Sampler
88

99

@@ -167,6 +167,7 @@ class RegularizedLinearRTO(LinearRTO):
167167
ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine Learning, 2011.
168168
Used when prior.proximal is a list of penalty terms.
169169
ScipyLinearLSQ: Wrapper for Scipy's lsq_linear for the Trust Region Reflective algorithm. Optionally used when the constraint is either "nonnegativity" or "box".
170+
ScipyMinimizer: Wrapper for Scipy's minimize. Optionally used when the constraint is either "nonnegativity" or "box".
170171
171172
Parameters
172173
------------
@@ -177,7 +178,7 @@ class RegularizedLinearRTO(LinearRTO):
177178
Initial point for the sampler. *Optional*.
178179
179180
maxit : int
180-
Maximum number of iterations of the FISTA/ADMM/ScipyLinearLSQ solver. *Optional*.
181+
Maximum number of iterations of the FISTA/ADMM/ScipyLinearLSQ/ScipyMinimizer solver. *Optional*.
181182
182183
inner_max_it : int
183184
Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
@@ -191,7 +192,7 @@ class RegularizedLinearRTO(LinearRTO):
191192
See [2] or `cuqi.solver.ADMM`
192193
193194
abstol : float
194-
Absolute tolerance of the FISTA/ScipyLinearLSQ solver. *Optional*.
195+
Absolute tolerance of the FISTA/ScipyLinearLSQ/ScipyMinimizer solver. *Optional*.
195196
196197
inner_abstol : float
197198
Tolerance parameter for ScipyLinearLSQ's inner solve of the unbounded least-squares problem. *Optional*.
@@ -200,7 +201,7 @@ class RegularizedLinearRTO(LinearRTO):
200201
If True, FISTA is used as solver, otherwise ISTA is used. *Optional*.
201202
202203
solver : string
203-
If set to "ScipyLinearLSQ", solver is set to cuqi.solver.ScipyLinearLSQ, otherwise FISTA/ISTA or ADMM is used. Note "ScipyLinearLSQ" can only be used with `RegularizedGaussian` of `box` or `nonnegativity` constraint. *Optional*.
204+
Options are "FISTA" (default for a single constraint or regularization), "ADMM" (default and the only option for multiple constraints or regularizations), "ScipyLinearLSQ" and "ScipyMinimizer". Note "ScipyLinearLSQ" and "ScipyMinimizer" can only be used with `RegularizedGaussian` of a single `box` or `nonnegativity` constraint. *Optional*.
204205
205206
callback : callable, optional
206207
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
@@ -234,11 +235,11 @@ def solver(self):
234235

235236
@solver.setter
236237
def solver(self, value):
237-
if value == "ScipyLinearLSQ":
238+
if value == "ScipyLinearLSQ" or value == "ScipyMinimizer":
238239
if (self.target.prior.preset["constraint"] == "nonnegativity" or self.target.prior.preset["constraint"] == "box"):
239240
self._solver = value
240241
else:
241-
raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")
242+
raise ValueError("ScipyLinearLSQ and ScipyMinimizer only support RegularizedGaussian with box or nonnegativity constraint.")
242243
else:
243244
self._solver = value
244245

@@ -281,15 +282,22 @@ def step(self):
281282
sim = ADMM(self.M, y, self.proximal,
282283
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
283284
elif self.solver == "ScipyLinearLSQ":
284-
A_op = sp.sparse.linalg.LinearOperator((sum([llh.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
285-
matvec=lambda x: self.M(x, 1),
286-
rmatvec=lambda x: self.M(x, 2)
287-
)
288-
sim = ScipyLinearLSQ(A_op, y, self.target.prior._box_bounds,
289-
max_iter = self.maxit,
290-
lsmr_maxiter = self.inner_max_it,
291-
tol = self.abstol,
292-
lsmr_tol = self.inner_abstol)
285+
A_op = sp.sparse.linalg.LinearOperator((sum([llh.distribution.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
286+
matvec=lambda x: self.M(x, 1),
287+
rmatvec=lambda x: self.M(x, 2)
288+
)
289+
sim = ScipyLinearLSQ(A_op, y, self.target.prior._box_bounds,
290+
max_iter = self.maxit,
291+
lsmr_maxiter = self.inner_max_it,
292+
tol = self.abstol,
293+
lsmr_tol = self.inner_abstol)
294+
elif self.solver == "ScipyMinimizer":
295+
# Adapt the bounds format, as scipy.optimize.minimize requires a bounds format
296+
# different from that in scipy.optimize.lsq_linear.
297+
bounds = [(self.target.prior._box_bounds[0][i], self.target.prior._box_bounds[1][i]) for i in range(self.target.prior.dim)]
298+
# Note that the objective function is defined as 0.5*||Mx-y||^2,
299+
# and the corresponding gradient (gradfunc) is given by M^T(Mx-y).
300+
sim = ScipyMinimizer(lambda x: 0.5*np.sum((self.M(x, 1)-y)**2), self.current_point, gradfunc=lambda x: self.M(self.M(x, 1) - y, 2), bounds=bounds, tol=self.abstol, options={"maxiter": self.maxit})
293301
else:
294302
raise ValueError("Choice of solver not supported.")
295303

cuqi/solver/_solver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,19 @@ class ScipyLSQ(object):
196196
'trf', Trust Region Reflective algorithm: for large sparse problems with bounds.
197197
'dogbox', dogleg algorithm with rectangular trust regions, for small problems with bounds.
198198
'lm', Levenberg-Marquardt algorithm as implemented in MINPACK. Doesn't handle bounds and sparse Jacobians.
199+
tol : The numerical tolerance for convergence checks.
200+
maxit : The maximum number of iterations.
201+
kwargs : Additional keyword arguments passed to scipy's least_squares. Empty by default. See documentation for scipy.optimize.least_squares
199202
"""
200-
def __init__(self, func, x0, jacfun='2-point', method='trf', loss='linear', tol=1e-6, maxit=1e4):
203+
def __init__(self, func, x0, jacfun='2-point', method='trf', loss='linear', tol=1e-6, maxit=1e4, **kwargs):
201204
self.func = func
202205
self.x0 = x0
203206
self.jacfun = jacfun
204207
self.method = method
205208
self.loss = loss
206209
self.tol = tol
207210
self.maxit = int(maxit)
211+
self.kwargs = kwargs
208212

209213
def solve(self):
210214
"""Runs optimization algorithm and returns solution and info.
@@ -215,7 +219,7 @@ def solve(self):
215219
Solution found (array_like) and optimization information (dictionary).
216220
"""
217221
solution = least_squares(self.func, self.x0, jac=self.jacfun, \
218-
method=self.method, loss=self.loss, xtol=self.tol, max_nfev=self.maxit)
222+
method=self.method, loss=self.loss, xtol=self.tol, max_nfev=self.maxit, **self.kwargs)
219223
info = {"success": solution['success'],
220224
"message": solution['message'],
221225
"func": solution['fun'],

tests/test_solver.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,27 @@ def test_ScipyLinearLSQ_with_LinearOperator():
100100
sol, _ = ScipyLinearLSQ(A_op, b).solve()
101101
assert np.allclose(sol, ref_sol, rtol=1e-10)
102102

103-
def test_ScipyLinearLSQ_against_FISTA():
103+
def test_ScipyLinearLSQ_against_ScipyMinimizer_and_against_FISTA():
104104
A = np.array([[73,71,52],[87,74,46],[72,2,7],[80,89,71]])
105105
b = np.array([49,67,68,20])
106+
107+
# solve with ScipyMinimizer
108+
def fun(x):
109+
return 0.5*np.linalg.norm(A@x-b)**2
110+
def jac(x):
111+
return A.T@(A@x-b)
112+
sol_min, _ = ScipyMinimizer(fun, np.zeros(3), gradfunc=jac, tol=1e-10, bounds=[(0,np.inf),(0,np.inf),(0,np.inf)]).solve()
113+
106114
# solve with ScipyLinearLSQ
107-
lb = np.zeros(3)
108-
ub = lb + np.inf
109-
sol_lsq, _ = ScipyLinearLSQ(A, b, (lb,ub)).solve()
115+
sol_lsq, _ = ScipyLinearLSQ(A, b, ([0,0,0],[np.inf,np.inf,np.inf]), tol=1e-10).solve()
116+
110117
# solve with FISTA
111118
rng = np.random.default_rng(seed = 1219)
112119
x0 = rng.standard_normal(3)
113120
sol_fista, _ = FISTA(A, b, lambda x, _: ProjectNonnegative(x), x0, stepsize=1e-7, maxit=100000, abstol=1e-16, adaptive=True).solve()
114121

115122
assert np.allclose(sol_lsq, sol_fista, rtol=1e-8)
123+
assert np.allclose(sol_min, sol_lsq, rtol=1e-8)
116124

117125
def test_LM():
118126
# compare to MATLAB's original code solution

tests/zexperimental/test_mcmc.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,31 @@ def test_RegularizedLinearRTO_ScipyLinearLSQ_option_invalid():
15111511
with pytest.raises(ValueError, match="ScipyLinearLSQ"):
15121512
sampler = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, solver = "ScipyLinearLSQ")
15131513

1514+
def test_RegularizedLinearRTO_ScipyLinearLSQ_against_ScipyMinimizer_and_against_FISTA():
1515+
# Define LinearModel and data
1516+
A, y_obs, _ = cuqi.testproblem.Deconvolution1D().get_components()
1517+
1518+
# Define Bayesian Problem
1519+
x = cuqi.implicitprior.NonnegativeGMRF(np.zeros(A.domain_dim), 100)
1520+
y = cuqi.distribution.Gaussian(A@x, 0.01**2)
1521+
posterior = cuqi.distribution.JointDistribution(x, y)(y=y_obs)
1522+
1523+
# Set up RegularizedLinearRTO with three solvers
1524+
sampler1 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, solver="ScipyMinimizer", maxit=1000, tol=1e-8)
1525+
sampler2 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, solver="ScipyLinearLSQ", maxit=1000, tol=1e-8)
1526+
sampler3 = cuqi.experimental.mcmc.RegularizedLinearRTO(posterior, solver="FISTA", maxit=1000, tol=1e-8)
1527+
1528+
# Sample with fixed seed
1529+
np.random.seed(0)
1530+
samples1 = sampler1.sample(5).get_samples()
1531+
np.random.seed(0)
1532+
samples2 = sampler2.sample(5).get_samples()
1533+
np.random.seed(0)
1534+
samples3 = sampler3.sample(5).get_samples()
1535+
1536+
assert np.allclose(samples1.samples.mean(), samples2.samples.mean(), rtol=1e-5)
1537+
assert np.allclose(samples1.samples.mean(), samples3.samples.mean(), rtol=1e-5)
1538+
15141539
# ============ Start testing sampler callback ============
15151540
# Samplers that should be tested for callback
15161541
callback_testing_sampler_classes = [

0 commit comments

Comments
 (0)