11from cuqi .utilities import get_non_default_args
22from cuqi .distribution import Distribution , Gaussian
33from cuqi .solver import ProjectNonnegative , ProjectBox , ProximalL1
4+ from cuqi .geometry import Continuous1D , Continuous2D , Image2D
5+ from cuqi .operator import FirstOrderFiniteDifference
46
57import numpy as np
68
@@ -39,28 +41,33 @@ class RegularizedGaussian(Distribution):
3941 sqrtprec
4042 See :class:`~cuqi.distribution.Gaussian` for details.
4143
42- proximal : callable f(x, scale) or None
43- Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
44- min_z 0.5||x-z||_2^2+scale*g(x).
45-
44+ proximal : callable f(x, scale), list of tuples (callable proximal operator of f_i, linear operator L_i) or None
45+ If callable:
46+ Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
47+ min_z 0.5||x-z||_2^2+scale*g(x).
48+ If list of tuples (callable proximal operator of f_i, linear operator L_i):
49+ Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
50+ The corresponding regularization takes the form
51+ sum_i f_i(L_i x),
52+ where the sum ranges from 1 to an arbitrary n.
4653
4754 projector : callable f(x) or None
4855 Euclidean projection onto the constraint C, that is, a solver for the optimization problem
4956 min_(z in C) 0.5||x-z||_2^2.
5057
5158 constraint : string or None
52- Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
59+ Preset constraints that generate the corresponding proximal parameter . Can be set to "nonnegativity" and "box". Required for use in Gibbs.
5360 For "box", the following additional parameters can be passed:
5461 lower_bound : array_like or None
5562 Lower bound of box, defaults to zero
5663 upper_bound : array_like
5764 Upper bound of box, defaults to one
5865
5966 regularization : string or None
60- Preset regularization. Can be set to "l1". Required for use in Gibbs in future update.
61- For "l1", the following additional parameters can be passed:
67+ Preset regularization that generate the corresponding proximal parameter . Can be set to "l1" or 'tv' . Required for use in Gibbs in future update.
68+ For "l1" or "tv" , the following additional parameters can be passed:
6269 strength : scalar
63- Regularization parameter, i.e., strength*||x ||_1 , defaults to one
70+ Regularization parameter, i.e., strength*||Lx ||_1, defaults to one
6471
6572 """
6673
@@ -75,6 +82,7 @@ def __init__(self, mean=None, cov=None, prec=None, sqrtcov=None, sqrtprec=None,
7582
7683 # We init the underlying Gaussian first for geometry and dimensionality handling
7784 self ._gaussian = Gaussian (mean = mean , cov = cov , prec = prec , sqrtcov = sqrtcov , sqrtprec = sqrtprec , ** kwargs )
85+ kwargs .pop ("geometry" , None )
7886
7987 # Init from abstract distribution class
8088 super ().__init__ (** kwargs )
@@ -88,12 +96,6 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
8896 if (proximal is not None ) + (projector is not None ) + (constraint is not None ) + (regularization is not None ) != 1 :
8997 raise ValueError ("Precisely one of proximal, projector, constraint or regularization needs to be provided." )
9098
91- if proximal is not None :
92- if not callable (proximal ):
93- raise ValueError ("Proximal needs to be callable." )
94- if len (get_non_default_args (proximal )) != 2 :
95- raise ValueError ("Proximal should take 2 arguments." )
96-
9799 if projector is not None :
98100 if not callable (projector ):
99101 raise ValueError ("Projector needs to be callable." )
@@ -104,7 +106,8 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
104106 self ._preset = None
105107
106108 if proximal is not None :
107- self ._proximal = proximal
109+ # No need to generate the proximal and associated information
110+ self .proximal = proximal
108111 elif projector is not None :
109112 self ._proximal = lambda z , gamma : projector (z )
110113 elif (isinstance (constraint , str ) and constraint .lower () == "nonnegativity" ):
@@ -113,15 +116,48 @@ def _parse_regularization_input_arguments(self, proximal, projector, constraint,
113116 elif (isinstance (constraint , str ) and constraint .lower () == "box" ):
114117 lower = optional_regularization_parameters ["lower_bound" ]
115118 upper = optional_regularization_parameters ["upper_bound" ]
116- self ._proximal = lambda z , gamma : ProjectBox (z , lower , upper )
119+ self ._proximal = lambda z , _ : ProjectBox (z , lower , upper )
117120 self ._preset = "box" # Not supported in Gibbs
118121 elif (isinstance (regularization , str ) and regularization .lower () in ["l1" ]):
119- strength = optional_regularization_parameters ["strength" ]
120- self ._proximal = lambda z , gamma : ProximalL1 (z , gamma * strength )
122+ self . _strength = optional_regularization_parameters ["strength" ]
123+ self ._proximal = lambda z , gamma : ProximalL1 (z , gamma * self . _strength )
121124 self ._preset = "l1"
125+ elif (isinstance (regularization , str ) and regularization .lower () in ["tv" ]):
126+ self ._strength = optional_regularization_parameters ["strength" ]
127+ if isinstance (self .geometry , (Continuous1D , Continuous2D , Image2D )):
128+ self ._transformation = FirstOrderFiniteDifference (self .geometry .fun_shape , bc_type = 'neumann' )
129+ else :
130+ raise ValueError ("Geometry not supported for total variation" )
131+
132+ self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
133+ self ._regularization_oper = self ._transformation
134+
135+ self ._proximal = [(self ._regularization_prox , self ._regularization_oper )]
136+ self ._preset = "tv"
122137 else :
123138 raise ValueError ("Regularization not supported" )
124139
140+
141+ @property
142+ def transformation (self ):
143+ return self ._transformation
144+
145+ @property
146+ def strength (self ):
147+ return self ._strength
148+
149+ @strength .setter
150+ def strength (self , value ):
151+ if self ._preset not in self .regularization_options ():
152+ raise TypeError ("Strength is only used when the regularization is set to l1 or TV." )
153+
154+ self ._strength = value
155+ if self ._preset == "tv" :
156+ self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
157+ self ._proximal = [(self ._regularization_prox , self ._regularization_oper )]
158+ elif self ._preset == "l1" :
159+ self ._proximal = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
160+
125161 # This is a getter only attribute for the underlying Gaussian
126162 # It also ensures that the name of the underlying Gaussian
127163 # matches the name of the implicit regularized Gaussian
@@ -135,6 +171,25 @@ def gaussian(self):
135171 def proximal (self ):
136172 return self ._proximal
137173
174+ @proximal .setter
175+ def proximal (self , value ):
176+ if callable (value ):
177+ if len (get_non_default_args (value )) != 2 :
178+ raise ValueError ("Proximal should take 2 arguments." )
179+ elif isinstance (value , list ):
180+ for (prox , op ) in value :
181+ if len (get_non_default_args (prox )) != 2 :
182+ raise ValueError ("Proximal should take 2 arguments." )
183+ if op .shape [1 ] != self .geometry .par_dim :
184+ raise ValueError ("Incorrect shape of linear operator in proximal list." )
185+ else :
186+ raise ValueError ("Proximal needs to be callable or a list. See documentation." )
187+
188+ self ._proximal = value
189+
190+ # For all the presets, self._proximal is set directly,
191+ self ._preset = None
192+
138193 @property
139194 def preset (self ):
140195 return self ._preset
@@ -154,7 +209,7 @@ def constraint_options():
154209
155210 @staticmethod
156211 def regularization_options ():
157- return ["l1" ]
212+ return ["l1" , "tv" ]
158213
159214
160215 # --- Defer behavior of the underlying Gaussian --- #
@@ -206,16 +261,18 @@ def sqrtcov(self):
206261 def sqrtcov (self , value ):
207262 self .gaussian .sqrtcov = value
208263
209- def get_conditioning_variables (self ):
210- return self .gaussian .get_conditioning_variables ()
211-
212264 def get_mutable_variables (self ):
213- return self .gaussian .get_mutable_variables ()
265+ mutable_vars = self .gaussian .get_mutable_variables ().copy ()
266+ if self .preset in self .regularization_options ():
267+ mutable_vars += ["strength" ]
268+ return mutable_vars
214269
215270 # Overwrite the condition method such that the underlying Gaussian is conditioned in general, except when conditioning on self.name
216271 # which means we convert Distribution to Likelihood or EvaluatedDensity.
217272 def _condition (self , * args , ** kwargs ):
218-
273+ if self .preset in self .regularization_options ():
274+ return super ()._condition (* args , ** kwargs )
275+
219276 # Handle positional arguments (similar code as in Distribution._condition)
220277 cond_vars = self .get_conditioning_variables ()
221278 kwargs = self ._parse_args_add_to_kwargs (cond_vars , * args , ** kwargs )
@@ -275,7 +332,7 @@ class ConstrainedGaussian(RegularizedGaussian):
275332 min_(z in C) 0.5||x-z||_2^2.
276333
277334 constraint : string or None
278- Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
335+ Preset constraints that generate the corresponding proximal parameter . Can be set to "nonnegativity" and "box". Required for use in Gibbs.
279336 For "box", the following additional parameters can be passed:
280337 lower_bound : array_like or None
281338 Lower bound of box, defaults to zero
0 commit comments