Skip to content

Commit e495208

Browse files
authored
Merge pull request #564 from jeverink/egularizedGaussian_extension-TV
Regularized Gaussian extension: TV
2 parents 42e096c + 116be06 commit e495208

File tree

5 files changed

+151
-35
lines changed

5 files changed

+151
-35
lines changed

cuqi/experimental/mcmc/_rto.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
44
import numpy as np
55
import cuqi
6-
from cuqi.solver import CGLS, FISTA
6+
from cuqi.solver import CGLS, FISTA, ADMM
77
from cuqi.experimental.mcmc import Sampler
88

99

@@ -161,6 +161,13 @@ class RegularizedLinearRTO(LinearRTO):
161161
Regularized Linear RTO (Randomize-Then-Optimize) sampler.
162162
163163
Samples posterior related to the inverse problem with Gaussian likelihood and implicit Gaussian prior, and where the forward model is Linear.
164+
The sampler works by repeatedly solving regularized linear least squares problems for perturbed data.
165+
The solver for these optimization problems is chosen based on how the regularized is provided in the implicit Gaussian prior.
166+
Currently we use the following solvers:
167+
FISTA: [1] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding algorithm for linear inverse problems." SIAM journal on imaging sciences 2.1 (2009): 183-202.
168+
Used when prior.proximal is callable.
169+
ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers."Foundations and Trends® in Machine learning, 2011.
170+
Used when prior.proximal is a list of penalty terms.
164171
165172
Parameters
166173
------------
@@ -171,12 +178,19 @@ class RegularizedLinearRTO(LinearRTO):
171178
Initial point for the sampler. *Optional*.
172179
173180
maxit : int
174-
Maximum number of iterations of the inner FISTA solver. *Optional*.
181+
Maximum number of iterations of the inner FISTA/ADMM solver. *Optional*.
182+
183+
inner_max_it : int
184+
Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
175185
176186
stepsize : string or float
177187
If stepsize is a string and equals either "automatic", then the stepsize is automatically estimated based on the spectral norm.
178188
If stepsize is a float, then this stepsize is used.
179189
190+
penalty_parameter : int
191+
Penalty parameter of the inner ADMM solver. *Optional*.
192+
See [2] or `cuqi.solver.ADMM`
193+
180194
abstol : float
181195
Absolute tolerance of the inner FISTA solver. *Optional*.
182196
@@ -190,7 +204,7 @@ class RegularizedLinearRTO(LinearRTO):
190204
An example is shown in demos/demo31_callback.py.
191205
192206
"""
193-
def __init__(self, target=None, initial_point=None, maxit=100, stepsize="automatic", abstol=1e-10, adaptive=True, **kwargs):
207+
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, **kwargs):
194208

195209
super().__init__(target=target, initial_point=initial_point, **kwargs)
196210

@@ -199,10 +213,13 @@ def __init__(self, target=None, initial_point=None, maxit=100, stepsize="automat
199213
self.abstol = abstol
200214
self.adaptive = adaptive
201215
self.maxit = maxit
216+
self.inner_max_it = inner_max_it
217+
self.penalty_parameter = penalty_parameter
202218

203219
def _initialize(self):
204220
super()._initialize()
205-
self._stepsize = self._choose_stepsize()
221+
if self._inner_solver == "FISTA":
222+
self._stepsize = self._choose_stepsize()
206223

207224
@property
208225
def proximal(self):
@@ -212,8 +229,7 @@ def validate_target(self):
212229
super().validate_target()
213230
if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
214231
raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
215-
if not callable(self.proximal):
216-
raise TypeError("Proximal needs to be callable")
232+
self._inner_solver = "FISTA" if callable(self.proximal) else "ADMM"
217233

218234
def _choose_stepsize(self):
219235
if isinstance(self.stepsize, str):
@@ -237,8 +253,16 @@ def prior(self):
237253

238254
def step(self):
239255
y = self.b_tild + np.random.randn(len(self.b_tild))
240-
sim = FISTA(self.M, y, self.proximal,
241-
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
256+
257+
if self._inner_solver == "FISTA":
258+
sim = FISTA(self.M, y, self.proximal,
259+
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
260+
elif self._inner_solver == "ADMM":
261+
sim = ADMM(self.M, y, self.proximal,
262+
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
263+
else:
264+
raise ValueError("Choice of solver not supported.")
265+
242266
self.current_point, _ = sim.solve()
243267
acc = 1
244268
return acc

cuqi/implicitprior/_regularizedGMRF.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self, mean=None, prec=None, bc_type='zero', order=1, proximal = Non
6363

6464
# Underlying explicit GMRF
6565
self._gaussian = GMRF(mean, prec, bc_type=bc_type, order=order, **kwargs)
66+
kwargs.pop("geometry", None)
6667

6768
# Init from abstract distribution class
6869
super(Distribution, self).__init__(**kwargs)

cuqi/implicitprior/_regularizedGaussian.py

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from cuqi.utilities import get_non_default_args
22
from cuqi.distribution import Distribution, Gaussian
33
from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
4+
from cuqi.geometry import Continuous1D, Continuous2D, Image2D
5+
from cuqi.operator import FirstOrderFiniteDifference
46

57
import numpy as np
68

@@ -39,28 +41,33 @@ class RegularizedGaussian(Distribution):
3941
sqrtprec
4042
See :class:`~cuqi.distribution.Gaussian` for details.
4143
42-
proximal : callable f(x, scale) or None
43-
Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
44-
min_z 0.5||x-z||_2^2+scale*g(x).
45-
44+
proximal : callable f(x, scale), list of tuples (callable proximal operator of f_i, linear operator L_i) or None
45+
If callable:
46+
Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
47+
min_z 0.5||x-z||_2^2+scale*g(x).
48+
If list of tuples (callable proximal operator of f_i, linear operator L_i):
49+
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
50+
The corresponding regularization takes the form
51+
sum_i f_i(L_i x),
52+
where the sum ranges from 1 to an arbitrary n.
4653
4754
projector : callable f(x) or None
4855
Euclidean projection onto the constraint C, that is, a solver for the optimization problem
4956
min_(z in C) 0.5||x-z||_2^2.
5057
5158
constraint : string or None
52-
Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
59+
Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
5360
For "box", the following additional parameters can be passed:
5461
lower_bound : array_like or None
5562
Lower bound of box, defaults to zero
5663
upper_bound : array_like
5764
Upper bound of box, defaults to one
5865
5966
regularization : string or None
60-
Preset regularization. Can be set to "l1". Required for use in Gibbs in future update.
61-
For "l1", the following additional parameters can be passed:
67+
Preset regularization that generate the corresponding proximal parameter. Can be set to "l1" or 'tv'. Required for use in Gibbs in future update.
68+
For "l1" or "tv", the following additional parameters can be passed:
6269
strength : scalar
63-
Regularization parameter, i.e., strength*||x||_1 , defaults to one
70+
Regularization parameter, i.e., strength*||Lx||_1, defaults to one
6471
6572
"""
6673

@@ -75,6 +82,7 @@ def __init__(self, mean=None, cov=None, prec=None, sqrtcov=None, sqrtprec=None,
7582

7683
# We init the underlying Gaussian first for geometry and dimensionality handling
7784
self._gaussian = Gaussian(mean=mean, cov=cov, prec=prec, sqrtcov=sqrtcov, sqrtprec=sqrtprec, **kwargs)
85+
kwargs.pop("geometry", None)
7886

7987
# Init from abstract distribution class
8088
super().__init__(**kwargs)
@@ -88,12 +96,6 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
8896
if (proximal is not None) + (projector is not None) + (constraint is not None) + (regularization is not None) != 1:
8997
raise ValueError("Precisely one of proximal, projector, constraint or regularization needs to be provided.")
9098

91-
if proximal is not None:
92-
if not callable(proximal):
93-
raise ValueError("Proximal needs to be callable.")
94-
if len(get_non_default_args(proximal)) != 2:
95-
raise ValueError("Proximal should take 2 arguments.")
96-
9799
if projector is not None:
98100
if not callable(projector):
99101
raise ValueError("Projector needs to be callable.")
@@ -104,7 +106,8 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
104106
self._preset = None
105107

106108
if proximal is not None:
107-
self._proximal = proximal
109+
# No need to generate the proximal and associated information
110+
self.proximal = proximal
108111
elif projector is not None:
109112
self._proximal = lambda z, gamma: projector(z)
110113
elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
@@ -113,15 +116,48 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
113116
elif (isinstance(constraint, str) and constraint.lower() == "box"):
114117
lower = optional_regularization_parameters["lower_bound"]
115118
upper = optional_regularization_parameters["upper_bound"]
116-
self._proximal = lambda z, gamma: ProjectBox(z, lower, upper)
119+
self._proximal = lambda z, _: ProjectBox(z, lower, upper)
117120
self._preset = "box" # Not supported in Gibbs
118121
elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
119-
strength = optional_regularization_parameters["strength"]
120-
self._proximal = lambda z, gamma: ProximalL1(z, gamma*strength)
122+
self._strength = optional_regularization_parameters["strength"]
123+
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
121124
self._preset = "l1"
125+
elif (isinstance(regularization, str) and regularization.lower() in ["tv"]):
126+
self._strength = optional_regularization_parameters["strength"]
127+
if isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
128+
self._transformation = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
129+
else:
130+
raise ValueError("Geometry not supported for total variation")
131+
132+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
133+
self._regularization_oper = self._transformation
134+
135+
self._proximal = [(self._regularization_prox, self._regularization_oper)]
136+
self._preset = "tv"
122137
else:
123138
raise ValueError("Regularization not supported")
124139

140+
141+
@property
142+
def transformation(self):
143+
return self._transformation
144+
145+
@property
146+
def strength(self):
147+
return self._strength
148+
149+
@strength.setter
150+
def strength(self, value):
151+
if self._preset not in self.regularization_options():
152+
raise TypeError("Strength is only used when the regularization is set to l1 or TV.")
153+
154+
self._strength = value
155+
if self._preset == "tv":
156+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
157+
self._proximal = [(self._regularization_prox, self._regularization_oper)]
158+
elif self._preset == "l1":
159+
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
160+
125161
# This is a getter only attribute for the underlying Gaussian
126162
# It also ensures that the name of the underlying Gaussian
127163
# matches the name of the implicit regularized Gaussian
@@ -135,6 +171,25 @@ def gaussian(self):
135171
def proximal(self):
136172
return self._proximal
137173

174+
@proximal.setter
175+
def proximal(self, value):
176+
if callable(value):
177+
if len(get_non_default_args(value)) != 2:
178+
raise ValueError("Proximal should take 2 arguments.")
179+
elif isinstance(value, list):
180+
for (prox, op) in value:
181+
if len(get_non_default_args(prox)) != 2:
182+
raise ValueError("Proximal should take 2 arguments.")
183+
if op.shape[1] != self.geometry.par_dim:
184+
raise ValueError("Incorrect shape of linear operator in proximal list.")
185+
else:
186+
raise ValueError("Proximal needs to be callable or a list. See documentation.")
187+
188+
self._proximal = value
189+
190+
# For all the presets, self._proximal is set directly,
191+
self._preset = None
192+
138193
@property
139194
def preset(self):
140195
return self._preset
@@ -154,7 +209,7 @@ def constraint_options():
154209

155210
@staticmethod
156211
def regularization_options():
157-
return ["l1"]
212+
return ["l1", "tv"]
158213

159214

160215
# --- Defer behavior of the underlying Gaussian --- #
@@ -206,16 +261,18 @@ def sqrtcov(self):
206261
def sqrtcov(self, value):
207262
self.gaussian.sqrtcov = value
208263

209-
def get_conditioning_variables(self):
210-
return self.gaussian.get_conditioning_variables()
211-
212264
def get_mutable_variables(self):
213-
return self.gaussian.get_mutable_variables()
265+
mutable_vars = self.gaussian.get_mutable_variables().copy()
266+
if self.preset in self.regularization_options():
267+
mutable_vars += ["strength"]
268+
return mutable_vars
214269

215270
# Overwrite the condition method such that the underlying Gaussian is conditioned in general, except when conditioning on self.name
216271
# which means we convert Distribution to Likelihood or EvaluatedDensity.
217272
def _condition(self, *args, **kwargs):
218-
273+
if self.preset in self.regularization_options():
274+
return super()._condition(*args, **kwargs)
275+
219276
# Handle positional arguments (similar code as in Distribution._condition)
220277
cond_vars = self.get_conditioning_variables()
221278
kwargs = self._parse_args_add_to_kwargs(cond_vars, *args, **kwargs)
@@ -275,7 +332,7 @@ class ConstrainedGaussian(RegularizedGaussian):
275332
min_(z in C) 0.5||x-z||_2^2.
276333
277334
constraint : string or None
278-
Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
335+
Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
279336
For "box", the following additional parameters can be passed:
280337
lower_bound : array_like or None
281338
Lower bound of box, defaults to zero

cuqi/solver/_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ class ADMM(object):
669669
- flag=2 indicates multiplication of the transpose of A with vector x, that is A.T @ x.
670670
b : ndarray.
671671
penalty_terms : List of tuples (callable proximal operator of f_i, linear operator L_i)
672-
Each callable proximal operator f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
672+
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
673673
x0 : ndarray. Initial guess.
674674
penalty_parameter : Trade-off between linear least squares and regularization term in the solver iterates. Denoted as "rho" in [1].
675675
maxit : The maximum number of iterations.

tests/test_implicit_priors.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_RegularizedGaussian_guarding_statements():
1616
cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, proximal=lambda s,z: s, constraint="nonnegativity")
1717

1818
# Proximal
19-
with pytest.raises(ValueError, match="Proximal needs to be callable"):
19+
with pytest.raises(ValueError, match="Proximal needs to be callable or a list. See documentation."):
2020
cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, proximal=1)
2121

2222
with pytest.raises(ValueError, match="Proximal should take 2 arguments"):
@@ -104,3 +104,37 @@ def test_RegularizedUnboundedUniform_is_RegularizedGaussian():
104104
x = cuqi.implicitprior.RegularizedUnboundedUniform(cuqi.geometry.Continuous1D(5), regularization="l1", strength = 5.0)
105105

106106
assert np.allclose(x.gaussian.sqrtprec, 0.0)
107+
108+
def test_RegularizedGaussian_conditioning_constrained():
109+
""" Test that conditioning the implicit regularized Gaussian works as expected """
110+
111+
x = cuqi.implicitprior.RegularizedGMRF(lambda a:a*np.ones(2**2),
112+
prec = lambda b:5*b,
113+
constraint = "nonnegativity",
114+
geometry = cuqi.geometry.Image2D((2,2)))
115+
116+
assert x.get_mutable_variables() == ['mean', 'prec']
117+
assert x.get_conditioning_variables() == ['a', 'b']
118+
119+
x = x(a=1, b=2)
120+
121+
assert np.allclose(x.mean, [1, 1, 1, 1])
122+
assert np.allclose(x.prec, 10)
123+
124+
def test_RegularizedGaussian_conditioning_strength():
125+
""" Test that conditioning the implicit regularized Gaussian works as expected """
126+
127+
x = cuqi.implicitprior.RegularizedGMRF(lambda a:a*np.ones(2**2),
128+
prec = lambda b:5*b,
129+
regularization = "tv",
130+
strength = lambda c:c*2,
131+
geometry = cuqi.geometry.Image2D((2,2)))
132+
133+
assert x.get_mutable_variables() == ['mean', 'prec', 'strength']
134+
assert x.get_conditioning_variables() == ['a', 'b', 'c']
135+
136+
x = x(a=1, b=2, c=3)
137+
138+
assert np.allclose(x.mean, [1, 1, 1, 1])
139+
assert np.allclose(x.prec, 10)
140+
assert np.allclose(x.strength, 6)

0 commit comments

Comments
 (0)