22from cuqi .distribution import Distribution , Gaussian
33from cuqi .solver import ProjectNonnegative , ProjectBox , ProximalL1
44from cuqi .geometry import Continuous1D , Continuous2D , Image2D
5- from cuqi .operator import FirstOrderFiniteDifference
5+ from cuqi .operator import FirstOrderFiniteDifference , Operator
66
77import numpy as np
8+ import scipy .sparse as sparse
89from copy import copy
910
1011
@@ -48,6 +49,8 @@ class RegularizedGaussian(Distribution):
4849 min_z 0.5||x-z||_2^2+scale*g(x).
4950 If list of tuples (callable proximal operator of f_i, linear operator L_i):
5051 Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
52+ Each linear operator needs to have the '__matmul__', 'T' and 'shape' attributes;
53+ this includes numpy.ndarray, scipy.sparse.sparray, scipy.sparse.linalg.LinearOperator and cuqi.operator.Operator.
5154 The corresponding regularization takes the form
5255 sum_i f_i(L_i x),
5356 where the sum ranges from 1 to an arbitrary n.
@@ -88,59 +91,137 @@ def __init__(self, mean=None, cov=None, prec=None, sqrtcov=None, sqrtprec=None,
8891 # Init from abstract distribution class
8992 super ().__init__ (** kwargs )
9093
94+ self ._force_list = False
9195 self ._parse_regularization_input_arguments (proximal , projector , constraint , regularization , optional_regularization_parameters )
9296
9397 def _parse_regularization_input_arguments (self , proximal , projector , constraint , regularization , optional_regularization_parameters ):
9498 """ Parse regularization input arguments with guarding statements and store internal states """
9599
96- # Check that only one of proximal, projector, constraint or regularization is provided
97- if (proximal is not None ) + (projector is not None ) + (constraint is not None ) + (regularization is not None ) != 1 :
98- raise ValueError ("Precisely one of proximal, projector, constraint or regularization needs to be provided." )
100+ # Guards checking whether the regularization inputs are valid
101+ if (proximal is not None ) + (projector is not None ) + max ((constraint is not None ), (regularization is not None )) == 0 :
102+ raise ValueError ("At least some constraint or regularization has to be specified for RegularizedGaussian" )
103+
104+ if (proximal is not None ) + (projector is not None ) == 2 :
105+ raise ValueError ("Only one of proximal or projector can be used." )
106+
107+ if (proximal is not None ) + (projector is not None ) + max ((constraint is not None ), (regularization is not None )) > 1 :
108+ raise ValueError ("User-defined proximals and projectors cannot be combined with pre-defined constraints and regularization." )
109+
110+ # Branch between user-defined and preset
111+ if (proximal is not None ) + (projector is not None ) >= 1 :
112+ self ._parse_user_specified_input (proximal , projector )
113+ else :
114+ # Set constraint and regularization presets for use with Gibbs
115+ self ._preset = {"constraint" : None ,
116+ "regularization" : None }
99117
118+ self ._parse_preset_constraint_input (constraint , optional_regularization_parameters )
119+ self ._parse_preset_regularization_input (regularization , optional_regularization_parameters )
120+
121+ # Merge
122+ self ._merge_predefined_option ()
123+
124+ def _parse_user_specified_input (self , proximal , projector ):
125+ # Guard for checking partial validy of proximals or projectors
126+ if proximal is not None :
127+ if callable (proximal ):
128+ if len (get_non_default_args (proximal )) != 2 :
129+ raise ValueError ("Proximal should take 2 arguments." )
130+ elif isinstance (proximal , list ):
131+ for val in proximal :
132+ if len (val ) != 2 :
133+ raise ValueError ("Each value in the proximal list needs to consistent of two elements: a proximal operator and a linear operator." )
134+ if callable (val [0 ]):
135+ if len (get_non_default_args (val [0 ])) != 2 :
136+ raise ValueError ("Proximal should take 2 arguments." )
137+ else :
138+ raise ValueError ("Proximal operators need to be callable." )
139+ if not (hasattr (val [1 ], '__matmul__' ) and hasattr (val [1 ], 'T' ) and hasattr (val [1 ], 'shape' )):
140+ raise ValueError ("Linear operator not supported, must have '__matmul__', 'T' and 'shape' attributes." )
141+ else :
142+ raise ValueError ("Proximal needs to be callable or a list. See documentation." )
143+
100144 if projector is not None :
101- if not callable (projector ):
102- raise ValueError ("Projector needs to be callable." )
103- if len (get_non_default_args (projector )) != 1 :
104- raise ValueError ("Projector should take 1 argument." )
145+ if callable (projector ):
146+ if len (get_non_default_args (projector )) != 1 :
147+ raise ValueError ("Projector should take 1 argument." )
148+ else :
149+ raise ValueError ("Projector needs to be callable" )
105150
106- # Preset information, for use in Gibbs
107- self ._preset = None
108-
151+ # Set user-defined proximals or projectors
109152 if proximal is not None :
110- # No need to generate the proximal and associated information
111- self .proximal = proximal
112- elif projector is not None :
153+ self ._preset = None
154+ self ._proximal = proximal
155+ return
156+
157+ if projector is not None :
158+ self ._preset = None
113159 self ._proximal = lambda z , gamma : projector (z )
114- elif (isinstance (constraint , str ) and constraint .lower () == "nonnegativity" ):
115- self ._proximal = lambda z , gamma : ProjectNonnegative (z )
116- self ._preset = "nonnegativity"
117- self ._box_bounds = (np .ones (self .dim )* 0 , np .ones (self .dim )* np .inf )
118- elif (isinstance (constraint , str ) and constraint .lower () == "box" ):
119- self ._box_lower = optional_regularization_parameters ["lower_bound" ]
120- self ._box_upper = optional_regularization_parameters ["upper_bound" ]
121- self ._box_bounds = (np .ones (self .dim )* self ._box_lower , np .ones (self .dim )* self ._box_upper )
122- self ._proximal = lambda z , _ : ProjectBox (z , self ._box_lower , self ._box_upper )
123- self ._preset = "box" # Not supported in Gibbs
124- elif (isinstance (regularization , str ) and regularization .lower () in ["l1" ]):
125- self ._strength = optional_regularization_parameters ["strength" ]
126- self ._proximal = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
127- self ._preset = "l1"
128- elif (isinstance (regularization , str ) and regularization .lower () in ["tv" ]):
160+ return
161+
162+ def _parse_preset_constraint_input (self , constraint , optional_regularization_parameters ):
163+ # Create data for constraints
164+ self ._constraint_prox = None
165+ self ._constraint_oper = None
166+ if constraint is not None :
167+ if not isinstance (constraint , str ):
168+ raise ValueError ("Constraint needs to be specified as a string." )
169+
170+ c_lower = constraint .lower ()
171+ if c_lower == "nonnegativity" :
172+ self ._constraint_prox = lambda z , gamma : ProjectNonnegative (z )
173+ self ._box_bounds = (np .ones (self .dim )* 0 , np .ones (self .dim )* np .inf )
174+ self ._preset ["constraint" ] = "nonnegativity"
175+ elif c_lower == "box" :
176+ _box_lower = optional_regularization_parameters ["lower_bound" ]
177+ _box_upper = optional_regularization_parameters ["upper_bound" ]
178+ self ._proximal = lambda z , _ : ProjectBox (z , _box_lower , _box_upper )
179+ self ._box_bounds = (np .ones (self .dim )* _box_lower , np .ones (self .dim )* _box_upper )
180+ self ._preset ["constraint" ] = "box"
181+ else :
182+ raise ValueError ("Constraint not supported." )
183+
184+ def _parse_preset_regularization_input (self , regularization , optional_regularization_parameters ):
185+ # Create data for regularization
186+ self ._regularization_prox = None
187+ self ._regularization_oper = None
188+ if regularization is not None :
189+ if not isinstance (regularization , str ):
190+ raise ValueError ("Regularization needs to be specified as a string." )
191+
129192 self ._strength = optional_regularization_parameters ["strength" ]
130- if isinstance (self .geometry , (Continuous1D , Continuous2D , Image2D )):
131- self ._transformation = FirstOrderFiniteDifference (self .geometry .fun_shape , bc_type = 'neumann' )
193+ r_lower = regularization .lower ()
194+ if r_lower == "l1" :
195+ self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
196+ self ._preset ["regularization" ] = "l1"
197+ elif r_lower == "tv" :
198+ # Store the transformation to reuse when modifying the strength
199+ if not isinstance (self .geometry , (Continuous1D , Continuous2D , Image2D )):
200+ raise ValueError ("Geometry not supported for total variation" )
201+ self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
202+ self ._regularization_oper = FirstOrderFiniteDifference (self .geometry .fun_shape , bc_type = 'neumann' )
203+ self ._preset ["regularization" ] = "tv"
132204 else :
133- raise ValueError ("Geometry not supported for total variation" )
134-
135- self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
136- self ._regularization_oper = self ._transformation
137-
138- self ._proximal = [(self ._regularization_prox , self ._regularization_oper )]
139- self ._preset = "tv"
140- else :
141- raise ValueError ("Regularization not supported" )
142-
205+ raise ValueError ("Regularization not supported." )
143206
207+ def _merge_predefined_option (self ):
208+ # Check whether it is a single proximal and hence FISTA could be used in RegularizedLinearRTO
209+ if ((not self ._force_list ) and
210+ ((self ._constraint_prox is not None ) + (self ._regularization_prox is not None ) == 1 ) and
211+ ((self ._constraint_oper is not None ) + (self ._regularization_oper is not None ) == 0 )):
212+ if self ._constraint_prox is not None :
213+ self ._proximal = self ._constraint_prox
214+ else :
215+ self ._proximal = self ._regularization_prox
216+ return
217+
218+ # Merge regularization choices in list for use in ADMM by RegularizedLinearRTO
219+ self ._proximal = []
220+ if self ._constraint_prox is not None :
221+ self ._proximal += [(self ._constraint_prox , self ._constraint_oper if self ._constraint_oper is not None else sparse .eye (self .geometry .par_dim ))]
222+ if self ._regularization_prox is not None :
223+ self ._proximal += [(self ._regularization_prox , self ._regularization_oper if self ._regularization_oper is not None else sparse .eye (self .geometry .par_dim ))]
224+
144225 @property
145226 def transformation (self ):
146227 return self ._transformation
@@ -151,15 +232,15 @@ def strength(self):
151232
152233 @strength .setter
153234 def strength (self , value ):
154- if self ._preset not in self .regularization_options () :
235+ if self ._preset is None or self ._preset [ "regularization" ] is None :
155236 raise TypeError ("Strength is only used when the regularization is set to l1 or TV." )
156237
157238 self ._strength = value
158- if self ._preset == " tv":
239+ if self ._preset [ "regularization" ] in [ "l1" , " tv"]:
159240 self ._regularization_prox = lambda z , gamma : ProximalL1 (z , gamma * self ._strength )
160- self . _proximal = [( self . _regularization_prox , self . _regularization_oper )]
161- elif self . _preset == "l1" :
162- self ._proximal = lambda z , gamma : ProximalL1 ( z , gamma * self . _strength )
241+
242+ # Create new list of proximals based on updated regularization
243+ self ._merge_predefined_option ( )
163244
164245 # This is a getter only attribute for the underlying Gaussian
165246 # It also ensures that the name of the underlying Gaussian
@@ -266,7 +347,7 @@ def sqrtcov(self, value):
266347
267348 def get_mutable_variables (self ):
268349 mutable_vars = self .gaussian .get_mutable_variables ().copy ()
269- if self .preset in self .regularization_options () :
350+ if self .preset is not None and self .preset [ 'regularization' ] in [ "l1" , "tv" ] :
270351 mutable_vars += ["strength" ]
271352 return mutable_vars
272353
0 commit comments