Skip to content

Commit 0b7a726

Browse files
authored
Merge pull request #602 from CUQI-DTU/add_solvers_to_rrto
Use scipy.optimize.lsq_linear in RegularizedLinearRTO for nonnegativity
2 parents 4d0949e + 198dfb9 commit 0b7a726

File tree

5 files changed

+136
-23
lines changed

5 files changed

+136
-23
lines changed

cuqi/experimental/mcmc/_rto.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
44
import numpy as np
55
import cuqi
6-
from cuqi.solver import CGLS, FISTA, ADMM
6+
from cuqi.solver import CGLS, FISTA, ADMM, ScipyLinearLSQ
77
from cuqi.experimental.mcmc import Sampler
88

99

@@ -168,6 +168,7 @@ class RegularizedLinearRTO(LinearRTO):
168168
Used when prior.proximal is callable.
169169
ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers."Foundations and Trends® in Machine learning, 2011.
170170
Used when prior.proximal is a list of penalty terms.
171+
ScipyLinearLSQ: Wrapper for Scipy's lsq_linear for the Trust Region Reflective algorithm. Optionally used when the constraint is either "nonnegativity" or "box".
171172
172173
Parameters
173174
------------
@@ -178,7 +179,7 @@ class RegularizedLinearRTO(LinearRTO):
178179
Initial point for the sampler. *Optional*.
179180
180181
maxit : int
181-
Maximum number of iterations of the inner FISTA/ADMM solver. *Optional*.
182+
Maximum number of iterations of the FISTA/ADMM/ScipyLinearLSQ solver. *Optional*.
182183
183184
inner_max_it : int
184185
Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
@@ -188,14 +189,20 @@ class RegularizedLinearRTO(LinearRTO):
188189
If stepsize is a float, then this stepsize is used.
189190
190191
penalty_parameter : int
191-
Penalty parameter of the inner ADMM solver. *Optional*.
192+
Penalty parameter of the ADMM solver. *Optional*.
192193
See [2] or `cuqi.solver.ADMM`
193194
194195
abstol : float
195-
Absolute tolerance of the inner FISTA solver. *Optional*.
196+
Absolute tolerance of the FISTA/ScipyLinearLSQ solver. *Optional*.
197+
198+
inner_abstol : float
199+
Tolerance parameter for ScipyLinearLSQ's inner solve of the unbounded least-squares problem. *Optional*.
196200
197201
adaptive : bool
198-
If True, FISTA is used as inner solver, otherwise ISTA is used. *Optional*.
202+
If True, FISTA is used as solver, otherwise ISTA is used. *Optional*.
203+
204+
solver : string
205+
If set to "ScipyLinearLSQ", solver is set to cuqi.solver.ScipyLinearLSQ, otherwise FISTA/ISTA or ADMM is used. Note "ScipyLinearLSQ" can only be used with `RegularizedGaussian` of `box` or `nonnegativity` constraint. *Optional*.
199206
200207
callback : callable, *Optional*
201208
If set this function will be called after every sample.
@@ -204,23 +211,41 @@ class RegularizedLinearRTO(LinearRTO):
204211
An example is shown in demos/demo31_callback.py.
205212
206213
"""
207-
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, **kwargs):
214+
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):
208215

209216
super().__init__(target=target, initial_point=initial_point, **kwargs)
210217

211218
# Other parameters
212219
self.stepsize = stepsize
213-
self.abstol = abstol
220+
self.abstol = abstol
221+
self.inner_abstol = inner_abstol
214222
self.adaptive = adaptive
215223
self.maxit = maxit
216224
self.inner_max_it = inner_max_it
217225
self.penalty_parameter = penalty_parameter
226+
self.solver = solver
218227

219228
def _initialize(self):
220229
super()._initialize()
221-
if self._inner_solver == "FISTA":
230+
if self.solver is None:
231+
self.solver = "FISTA" if callable(self.proximal) else "ADMM"
232+
if self.solver == "FISTA":
222233
self._stepsize = self._choose_stepsize()
223234

235+
@property
236+
def solver(self):
237+
return self._solver
238+
239+
@solver.setter
240+
def solver(self, value):
241+
if value == "ScipyLinearLSQ":
242+
if (self.target.prior._preset == "nonnegativity" or self.target.prior._preset == "box"):
243+
self._solver = value
244+
else:
245+
raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")
246+
else:
247+
self._solver = value
248+
224249
@property
225250
def proximal(self):
226251
return self.target.prior.proximal
@@ -229,7 +254,6 @@ def validate_target(self):
229254
super().validate_target()
230255
if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
231256
raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
232-
self._inner_solver = "FISTA" if callable(self.proximal) else "ADMM"
233257

234258
def _choose_stepsize(self):
235259
if isinstance(self.stepsize, str):
@@ -254,15 +278,25 @@ def prior(self):
254278
def step(self):
255279
y = self.b_tild + np.random.randn(len(self.b_tild))
256280

257-
if self._inner_solver == "FISTA":
281+
if self.solver == "FISTA":
258282
sim = FISTA(self.M, y, self.proximal,
259283
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
260-
elif self._inner_solver == "ADMM":
284+
elif self.solver == "ADMM":
261285
sim = ADMM(self.M, y, self.proximal,
262-
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
286+
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
287+
elif self.solver == "ScipyLinearLSQ":
288+
A_op = sp.sparse.linalg.LinearOperator((sum([llh.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
289+
matvec=lambda x: self.M(x, 1),
290+
rmatvec=lambda x: self.M(x, 2)
291+
)
292+
sim = ScipyLinearLSQ(A_op, y, self.target.prior._box_bounds,
293+
max_iter = self.maxit,
294+
lsmr_maxiter = self.inner_max_it,
295+
tol = self.abstol,
296+
lsmr_tol = self.inner_abstol)
263297
else:
264298
raise ValueError("Choice of solver not supported.")
265299

266300
self.current_point, _ = sim.solve()
267301
acc = 1
268-
return acc
302+
return acc

cuqi/implicitprior/_regularizedGaussian.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
113113
elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
114114
self._proximal = lambda z, gamma: ProjectNonnegative(z)
115115
self._preset = "nonnegativity"
116+
self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
116117
elif (isinstance(constraint, str) and constraint.lower() == "box"):
117-
lower = optional_regularization_parameters["lower_bound"]
118-
upper = optional_regularization_parameters["upper_bound"]
119-
self._proximal = lambda z, _: ProjectBox(z, lower, upper)
118+
self._box_lower = optional_regularization_parameters["lower_bound"]
119+
self._box_upper = optional_regularization_parameters["upper_bound"]
120+
self._box_bounds = (np.ones(self.dim)*self._box_lower, np.ones(self.dim)*self._box_upper)
121+
self._proximal = lambda z, _: ProjectBox(z, self._box_lower, self._box_upper)
120122
self._preset = "box" # Not supported in Gibbs
121123
elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
122124
self._strength = optional_regularization_parameters["strength"]

cuqi/solver/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
ScipyLBFGSB,
33
ScipyMinimizer,
44
ScipyMaximizer,
5-
ScipyLeastSquares,
5+
ScipyLSQ,
6+
ScipyLinearLSQ,
67
CGLS,
78
LM,
89
PDHG,

cuqi/solver/_solver.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def ngradfunc(*args,**kwargs):
164164

165165

166166

167-
class ScipyLeastSquares(object):
167+
class ScipyLSQ(object):
168168
"""Wrapper for :meth:`scipy.optimize.least_squares`.
169169
170170
Solve nonlinear least-squares problems with bounds:
@@ -227,6 +227,44 @@ def solve(self):
227227
sol = solution['x']
228228
return sol, info
229229

230+
class ScipyLinearLSQ(object):
231+
"""Wrapper for :meth:`scipy.optimize.lsq_linear`.
232+
233+
Solve linear least-squares problems with bounds:
234+
235+
.. math::
236+
237+
\min \|A x - b\|_2^2
238+
239+
subject to :math:`lb <= x <= ub`.
240+
241+
Parameters
242+
----------
243+
A : ndarray, LinearOperator
244+
Design matrix (system matrix).
245+
b : ndarray
246+
The right-hand side of the linear system.
247+
bounds : 2-tuple of array_like or scipy.optimize Bounds
248+
Bounds for variables.
249+
kwargs : Other keyword arguments passed to Scipy's `lsq_linear`. See documentation of `scipy.optimize.lsq_linear` for details.
250+
"""
251+
def __init__(self, A, b, bounds=(-np.inf, np.inf), **kwargs):
252+
self.A = A
253+
self.b = b
254+
self.bounds = bounds
255+
self.kwargs = kwargs
256+
257+
def solve(self):
258+
"""Runs optimization algorithm and returns solution and optimization information.
259+
260+
Returns
261+
----------
262+
solution : Tuple
263+
Solution found (array_like) and optimization information (dictionary).
264+
"""
265+
res = opt.lsq_linear(self.A, self.b, bounds=self.bounds, **self.kwargs)
266+
x = res.pop('x')
267+
return x, res
230268

231269

232270
class CGLS(object):

tests/test_solver.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import scipy as sp
33

4-
from cuqi.solver import ScipyLBFGSB, ScipyMinimizer, ScipyLeastSquares, CGLS, LM, FISTA, ADMM, ProximalL1, ProjectNonnegative
4+
from cuqi.solver import ScipyLBFGSB, ScipyMinimizer, ScipyLSQ, ScipyLinearLSQ, CGLS, LM, FISTA, ADMM, ProximalL1, ProjectNonnegative, ProjectBox
55
from scipy.optimize import lsq_linear
66

77

@@ -54,28 +54,66 @@ def test_ScipyLBFGSB_with_gradient():
5454
sol_ref = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
5555
assert np.allclose(sol, sol_ref)
5656

57-
def test_ScipyLeastSquares_without_Jac():
57+
def test_ScipyLSQ_without_Jac():
5858
def fun_rosenbrock(x):
5959
return np.array([10 * (x[1] - x[0]**2), (1 - x[0])])
6060
x0 = np.array([2, 2])
61-
solver = ScipyLeastSquares(fun_rosenbrock, x0)
61+
solver = ScipyLSQ(fun_rosenbrock, x0)
6262
sol, _ = solver.solve()
6363
sol_ref = np.array([1, 1])
6464
assert np.allclose(sol, sol_ref)
6565

66-
def test_ScipyLeastSquares_with_Jac():
66+
def test_ScipyLSQ_with_Jac():
6767
def fun_rosenbrock(x):
6868
return np.array([10 * (x[1] - x[0]**2), (1 - x[0])])
6969
def jac_rosenbrock(x):
7070
return np.array([
7171
[-20 * x[0], 10],
7272
[-1, 0]])
7373
x0 = np.array([2, 2])
74-
solver = ScipyLeastSquares(fun_rosenbrock, x0, jacfun=jac_rosenbrock)
74+
solver = ScipyLSQ(fun_rosenbrock, x0, jacfun=jac_rosenbrock)
7575
sol, _ = solver.solve()
7676
sol_ref = np.array([1, 1])
7777
assert np.allclose(sol, sol_ref)
7878

79+
def test_ScipyLinearLSQ_with_matrix():
80+
rng = np.random.default_rng(seed = 1219)
81+
m, n = 10, 5
82+
A = rng.standard_normal((m, n))
83+
b = rng.standard_normal(m)
84+
res = lsq_linear(A, b, tol=1e-8)
85+
ref_sol = res.x
86+
sol, _ = ScipyLinearLSQ(A, b).solve()
87+
assert np.allclose(sol, ref_sol, rtol=1e-10)
88+
89+
def test_ScipyLinearLSQ_with_LinearOperator():
90+
rng = np.random.default_rng(seed = 1219)
91+
m, n = 10, 5
92+
A = rng.standard_normal((m, n))
93+
b = rng.standard_normal(m)
94+
A_op = sp.sparse.linalg.LinearOperator((m, n),
95+
matvec=lambda x: A @ x,
96+
rmatvec=lambda x: A.T @ x
97+
)
98+
res = lsq_linear(A, b, tol=1e-8)
99+
ref_sol = res.x
100+
sol, _ = ScipyLinearLSQ(A_op, b).solve()
101+
assert np.allclose(sol, ref_sol, rtol=1e-10)
102+
103+
def test_ScipyLinearLSQ_against_FISTA():
104+
A = np.array([[73,71,52],[87,74,46],[72,2,7],[80,89,71]])
105+
b = np.array([49,67,68,20])
106+
# solve with ScipyLinearLSQ
107+
lb = np.zeros(3)
108+
ub = lb + np.inf
109+
sol_lsq, _ = ScipyLinearLSQ(A, b, (lb,ub)).solve()
110+
# solve with FISTA
111+
rng = np.random.default_rng(seed = 1219)
112+
x0 = rng.standard_normal(3)
113+
sol_fista, _ = FISTA(A, b, lambda x, _: ProjectNonnegative(x), x0, stepsize=1e-7, maxit=100000, abstol=1e-16, adaptive=True).solve()
114+
115+
assert np.allclose(sol_lsq, sol_fista, rtol=1e-8)
116+
79117
def test_LM():
80118
# compare to MATLAB's original code solution
81119
t = np.arange(1, 10, 2)

0 commit comments

Comments
 (0)