Skip to content

Commit 0303b8f

Browse files
authored
Merge pull request #612 from jeverink/Regularized-Gaussian-TVwNonnegativity
Regularized Gaussian: Combined regularized and constraint presets
2 parents ead079a + d58450e commit 0303b8f

File tree

8 files changed

+220
-72
lines changed

8 files changed

+220
-72
lines changed

cuqi/experimental/mcmc/_conjugate.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def validate_target(self):
147147
if self.target.prior.dim != 1:
148148
raise ValueError("RegularizedGaussian-Gamma conjugacy only works with univariate ModifiedHalfNormal prior")
149149

150-
if self.target.likelihood.distribution.preset not in ["nonnegativity"]:
150+
if self.target.likelihood.distribution.preset["constraint"] not in ["nonnegativity"]:
151151
raise ValueError("RegularizedGaussian-Gamma conjugacy only works with implicit regularized Gaussian likelihood with nonnegativity constraints")
152152

153153
key_value_pairs = _get_conjugate_parameter(self.target)
@@ -183,7 +183,7 @@ def validate_target(self):
183183
if self.target.prior.dim != 1:
184184
raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy only works with univariate Gamma prior")
185185

186-
if self.target.likelihood.distribution.preset not in ["l1", "tv"]:
186+
if self.target.likelihood.distribution.preset["regularization"] not in ["l1", "tv"]:
187187
raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization")
188188

189189
key_value_pairs = _get_conjugate_parameter(self.target)
@@ -203,12 +203,7 @@ def conjugate_distribution(self):
203203

204204
# Compute likelihood quantities
205205
x = self.target.likelihood.data
206-
if self.target.likelihood.distribution.preset == "l1":
207-
m = count_nonzero(x)
208-
elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, Continuous1D):
209-
m = count_constant_components_1D(x)
210-
elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
211-
m = count_constant_components_2D(self.target.likelihood.distribution.geometry.par2fun(x))
206+
m = _compute_sparsity_level(self.target)
212207

213208
reg_op = self.target.likelihood.distribution._regularization_oper
214209
reg_strength = self.target.likelihood.distribution(np.array([1])).strength
@@ -224,7 +219,7 @@ def validate_target(self):
224219
if self.target.prior.dim != 1:
225220
raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with univariate ModifiedHalfNormal prior")
226221

227-
if self.target.likelihood.distribution.preset not in ["l1", "tv"]:
222+
if self.target.likelihood.distribution.preset["regularization"] not in ["l1", "tv"]:
228223
raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization")
229224

230225
key_value_pairs = _get_conjugate_parameter(self.target)
@@ -254,13 +249,8 @@ def conjugate_distribution(self):
254249
x = self.target.likelihood.data
255250
mu = self.target.likelihood.distribution.mean
256251
L = self.target.likelihood.distribution(np.array([1])).sqrtprec
257-
258-
if self.target.likelihood.distribution.preset == "l1":
259-
m = count_nonzero(x)
260-
elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, Continuous1D):
261-
m = count_constant_components_1D(x)
262-
elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
263-
m = count_constant_components_2D(self.target.likelihood.distribution.geometry.par2fun(x))
252+
253+
m = _compute_sparsity_level(self.target)
264254

265255
reg_op = self.target.likelihood.distribution._regularization_oper
266256
reg_strength = self.target.likelihood.distribution(np.array([1])).strength
@@ -275,6 +265,26 @@ def conjugate_distribution(self):
275265
return ModifiedHalfNormal(conj_alpha, conj_beta, conj_gamma)
276266

277267

268+
def _compute_sparsity_level(target):
269+
"""Computes the sparsity level in accordance with Section 4 from [2],"""
270+
x = target.likelihood.data
271+
if target.likelihood.distribution.preset["constraint"] == "nonnegativity":
272+
if target.likelihood.distribution.preset["regularization"] == "l1":
273+
m = count_nonzero(x)
274+
elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
275+
m = count_constant_components_1D(x, lower = 0.0)
276+
elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
277+
m = count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x), lower = 0.0)
278+
else: # No constraints, only regularization
279+
if target.likelihood.distribution.preset["regularization"] == "l1":
280+
m = count_nonzero(x)
281+
elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
282+
m = count_constant_components_1D(x)
283+
elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
284+
m = count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x))
285+
return m
286+
287+
278288
def _get_conjugate_parameter(target):
279289
"""Extract the conjugate parameter name (e.g. d), and returns the mutable variable that is defined by the conjugate parameter, e.g. cov and its value e.g. lambda d:1/d"""
280290
par_name = target.prior.name

cuqi/experimental/mcmc/_rto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def solver(self):
239239
@solver.setter
240240
def solver(self, value):
241241
if value == "ScipyLinearLSQ":
242-
if (self.target.prior._preset == "nonnegativity" or self.target.prior._preset == "box"):
242+
if (self.target.prior.preset["constraint"] == "nonnegativity" or self.target.prior.preset["constraint"] == "box"):
243243
self._solver = value
244244
else:
245245
raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")

cuqi/implicitprior/_regularizedGMRF.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(self, mean=None, prec=None, bc_type='zero', order=1, proximal = Non
6868
# Init from abstract distribution class
6969
super(Distribution, self).__init__(**kwargs)
7070

71+
self._force_list = False
7172
self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, args)
7273

7374

cuqi/implicitprior/_regularizedGaussian.py

Lines changed: 128 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from cuqi.distribution import Distribution, Gaussian
33
from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
44
from cuqi.geometry import Continuous1D, Continuous2D, Image2D
5-
from cuqi.operator import FirstOrderFiniteDifference
5+
from cuqi.operator import FirstOrderFiniteDifference, Operator
66

77
import numpy as np
8+
import scipy.sparse as sparse
89
from copy import copy
910

1011

@@ -48,6 +49,8 @@ class RegularizedGaussian(Distribution):
4849
min_z 0.5||x-z||_2^2+scale*g(x).
4950
If list of tuples (callable proximal operator of f_i, linear operator L_i):
5051
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
52+
Each linear operator needs to have the '__matmul__', 'T' and 'shape' attributes;
53+
this includes numpy.ndarray, scipy.sparse.sparray, scipy.sparse.linalg.LinearOperator and cuqi.operator.Operator.
5154
The corresponding regularization takes the form
5255
sum_i f_i(L_i x),
5356
where the sum ranges from 1 to an arbitrary n.
@@ -88,59 +91,137 @@ def __init__(self, mean=None, cov=None, prec=None, sqrtcov=None, sqrtprec=None,
8891
# Init from abstract distribution class
8992
super().__init__(**kwargs)
9093

94+
self._force_list = False
9195
self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, optional_regularization_parameters)
9296

9397
def _parse_regularization_input_arguments(self, proximal, projector, constraint, regularization, optional_regularization_parameters):
9498
""" Parse regularization input arguments with guarding statements and store internal states """
9599

96-
# Check that only one of proximal, projector, constraint or regularization is provided
97-
if (proximal is not None) + (projector is not None) + (constraint is not None) + (regularization is not None) != 1:
98-
raise ValueError("Precisely one of proximal, projector, constraint or regularization needs to be provided.")
100+
# Guards checking whether the regularization inputs are valid
101+
if (proximal is not None) + (projector is not None) + max((constraint is not None), (regularization is not None)) == 0:
102+
raise ValueError("At least some constraint or regularization has to be specified for RegularizedGaussian")
103+
104+
if (proximal is not None) + (projector is not None) == 2:
105+
raise ValueError("Only one of proximal or projector can be used.")
106+
107+
if (proximal is not None) + (projector is not None) + max((constraint is not None), (regularization is not None)) > 1:
108+
raise ValueError("User-defined proximals and projectors cannot be combined with pre-defined constraints and regularization.")
109+
110+
# Branch between user-defined and preset
111+
if (proximal is not None) + (projector is not None) >= 1:
112+
self._parse_user_specified_input(proximal, projector)
113+
else:
114+
# Set constraint and regularization presets for use with Gibbs
115+
self._preset = {"constraint": None,
116+
"regularization": None}
99117

118+
self._parse_preset_constraint_input(constraint, optional_regularization_parameters)
119+
self._parse_preset_regularization_input(regularization, optional_regularization_parameters)
120+
121+
# Merge
122+
self._merge_predefined_option()
123+
124+
def _parse_user_specified_input(self, proximal, projector):
125+
# Guard for checking partial validy of proximals or projectors
126+
if proximal is not None:
127+
if callable(proximal):
128+
if len(get_non_default_args(proximal)) != 2:
129+
raise ValueError("Proximal should take 2 arguments.")
130+
elif isinstance(proximal, list):
131+
for val in proximal:
132+
if len(val) != 2:
133+
raise ValueError("Each value in the proximal list needs to consistent of two elements: a proximal operator and a linear operator.")
134+
if callable(val[0]):
135+
if len(get_non_default_args(val[0])) != 2:
136+
raise ValueError("Proximal should take 2 arguments.")
137+
else:
138+
raise ValueError("Proximal operators need to be callable.")
139+
if not (hasattr(val[1], '__matmul__') and hasattr(val[1], 'T') and hasattr(val[1], 'shape')):
140+
raise ValueError("Linear operator not supported, must have '__matmul__', 'T' and 'shape' attributes.")
141+
else:
142+
raise ValueError("Proximal needs to be callable or a list. See documentation.")
143+
100144
if projector is not None:
101-
if not callable(projector):
102-
raise ValueError("Projector needs to be callable.")
103-
if len(get_non_default_args(projector)) != 1:
104-
raise ValueError("Projector should take 1 argument.")
145+
if callable(projector):
146+
if len(get_non_default_args(projector)) != 1:
147+
raise ValueError("Projector should take 1 argument.")
148+
else:
149+
raise ValueError("Projector needs to be callable")
105150

106-
# Preset information, for use in Gibbs
107-
self._preset = None
108-
151+
# Set user-defined proximals or projectors
109152
if proximal is not None:
110-
# No need to generate the proximal and associated information
111-
self.proximal = proximal
112-
elif projector is not None:
153+
self._preset = None
154+
self._proximal = proximal
155+
return
156+
157+
if projector is not None:
158+
self._preset = None
113159
self._proximal = lambda z, gamma: projector(z)
114-
elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
115-
self._proximal = lambda z, gamma: ProjectNonnegative(z)
116-
self._preset = "nonnegativity"
117-
self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
118-
elif (isinstance(constraint, str) and constraint.lower() == "box"):
119-
self._box_lower = optional_regularization_parameters["lower_bound"]
120-
self._box_upper = optional_regularization_parameters["upper_bound"]
121-
self._box_bounds = (np.ones(self.dim)*self._box_lower, np.ones(self.dim)*self._box_upper)
122-
self._proximal = lambda z, _: ProjectBox(z, self._box_lower, self._box_upper)
123-
self._preset = "box" # Not supported in Gibbs
124-
elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
125-
self._strength = optional_regularization_parameters["strength"]
126-
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
127-
self._preset = "l1"
128-
elif (isinstance(regularization, str) and regularization.lower() in ["tv"]):
160+
return
161+
162+
def _parse_preset_constraint_input(self, constraint, optional_regularization_parameters):
163+
# Create data for constraints
164+
self._constraint_prox = None
165+
self._constraint_oper = None
166+
if constraint is not None:
167+
if not isinstance(constraint, str):
168+
raise ValueError("Constraint needs to be specified as a string.")
169+
170+
c_lower = constraint.lower()
171+
if c_lower == "nonnegativity":
172+
self._constraint_prox = lambda z, gamma: ProjectNonnegative(z)
173+
self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
174+
self._preset["constraint"] = "nonnegativity"
175+
elif c_lower == "box":
176+
_box_lower = optional_regularization_parameters["lower_bound"]
177+
_box_upper = optional_regularization_parameters["upper_bound"]
178+
self._proximal = lambda z, _: ProjectBox(z, _box_lower, _box_upper)
179+
self._box_bounds = (np.ones(self.dim)*_box_lower, np.ones(self.dim)*_box_upper)
180+
self._preset["constraint"] = "box"
181+
else:
182+
raise ValueError("Constraint not supported.")
183+
184+
def _parse_preset_regularization_input(self, regularization, optional_regularization_parameters):
185+
# Create data for regularization
186+
self._regularization_prox = None
187+
self._regularization_oper = None
188+
if regularization is not None:
189+
if not isinstance(regularization, str):
190+
raise ValueError("Regularization needs to be specified as a string.")
191+
129192
self._strength = optional_regularization_parameters["strength"]
130-
if isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
131-
self._transformation = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
193+
r_lower = regularization.lower()
194+
if r_lower == "l1":
195+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
196+
self._preset["regularization"] = "l1"
197+
elif r_lower == "tv":
198+
# Store the transformation to reuse when modifying the strength
199+
if not isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
200+
raise ValueError("Geometry not supported for total variation")
201+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
202+
self._regularization_oper = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
203+
self._preset["regularization"] = "tv"
132204
else:
133-
raise ValueError("Geometry not supported for total variation")
134-
135-
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
136-
self._regularization_oper = self._transformation
137-
138-
self._proximal = [(self._regularization_prox, self._regularization_oper)]
139-
self._preset = "tv"
140-
else:
141-
raise ValueError("Regularization not supported")
142-
205+
raise ValueError("Regularization not supported.")
143206

207+
def _merge_predefined_option(self):
208+
# Check whether it is a single proximal and hence FISTA could be used in RegularizedLinearRTO
209+
if ((not self._force_list) and
210+
((self._constraint_prox is not None) + (self._regularization_prox is not None) == 1) and
211+
((self._constraint_oper is not None) + (self._regularization_oper is not None) == 0)):
212+
if self._constraint_prox is not None:
213+
self._proximal = self._constraint_prox
214+
else:
215+
self._proximal = self._regularization_prox
216+
return
217+
218+
# Merge regularization choices in list for use in ADMM by RegularizedLinearRTO
219+
self._proximal = []
220+
if self._constraint_prox is not None:
221+
self._proximal += [(self._constraint_prox, self._constraint_oper if self._constraint_oper is not None else sparse.eye(self.geometry.par_dim))]
222+
if self._regularization_prox is not None:
223+
self._proximal += [(self._regularization_prox, self._regularization_oper if self._regularization_oper is not None else sparse.eye(self.geometry.par_dim))]
224+
144225
@property
145226
def transformation(self):
146227
return self._transformation
@@ -151,15 +232,15 @@ def strength(self):
151232

152233
@strength.setter
153234
def strength(self, value):
154-
if self._preset not in self.regularization_options():
235+
if self._preset is None or self._preset["regularization"] is None:
155236
raise TypeError("Strength is only used when the regularization is set to l1 or TV.")
156237

157238
self._strength = value
158-
if self._preset == "tv":
239+
if self._preset["regularization"] in ["l1", "tv"]:
159240
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
160-
self._proximal = [(self._regularization_prox, self._regularization_oper)]
161-
elif self._preset == "l1":
162-
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
241+
242+
# Create new list of proximals based on updated regularization
243+
self._merge_predefined_option()
163244

164245
# This is a getter only attribute for the underlying Gaussian
165246
# It also ensures that the name of the underlying Gaussian
@@ -266,7 +347,7 @@ def sqrtcov(self, value):
266347

267348
def get_mutable_variables(self):
268349
mutable_vars = self.gaussian.get_mutable_variables().copy()
269-
if self.preset in self.regularization_options():
350+
if self.preset is not None and self.preset['regularization'] in ["l1", "tv"]:
270351
mutable_vars += ["strength"]
271352
return mutable_vars
272353

cuqi/implicitprior/_regularizedUnboundedUniform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,5 @@ def __init__(self, geometry, proximal = None, projector = None, constraint = Non
6363
# Init from abstract distribution class
6464
super(Distribution, self).__init__(**kwargs)
6565

66+
self._force_list = False
6667
self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, args)

0 commit comments

Comments
 (0)