44from cuqi .experimental .mcmc import Sampler
55from cuqi .distribution import Posterior , Gaussian , Gamma , GMRF , ModifiedHalfNormal
66from cuqi .implicitprior import RegularizedGaussian , RegularizedGMRF , RegularizedUnboundedUniform
7- from cuqi .utilities import get_non_default_args , count_nonzero , count_constant_components_1D , count_constant_components_2D
7+ from cuqi .utilities import get_non_default_args , count_nonzero , count_within_bounds , count_constant_components_1D , count_constant_components_2D , piecewise_linear_1D_DoF
88from cuqi .geometry import Continuous1D , Continuous2D , Image2D
99
1010class Conjugate (Sampler ):
@@ -17,8 +17,8 @@ class Conjugate(Sampler):
1717 - (GMRF, Gamma) where Gamma is defined on the precision parameter of the GMRF
1818 - (RegularizedGaussian, Gamma) with preset constraints only and Gamma is defined on the precision parameter of the RegularizedGaussian
1919 - (RegularizedGMRF, Gamma) with preset constraints only and Gamma is defined on the precision parameter of the RegularizedGMRF
20- - (RegularizedGaussian, ModifiedHalfNormal) with preset constraints and regularization only
21- - (RegularizedGMRF, ModifiedHalfNormal) with preset constraints and regularization only
20+ - (RegularizedGaussian, ModifiedHalfNormal) with most of the preset constraints and regularization
21+ - (RegularizedGMRF, ModifiedHalfNormal) with most of the preset constraints and regularization
2222
2323 Currently the Gamma and ModifiedHalfNormal distribution must be univariate.
2424
@@ -147,8 +147,8 @@ def validate_target(self):
147147 if self .target .prior .dim != 1 :
148148 raise ValueError ("RegularizedGaussian-Gamma conjugacy only works with univariate ModifiedHalfNormal prior" )
149149
150- if self . target . likelihood . distribution . preset [ "constraint" ] not in [ "nonnegativity" ]:
151- raise ValueError ( "RegularizedGaussian-Gamma conjugacy only works with implicit regularized Gaussian likelihood with nonnegativity constraints" )
150+ # Raises error if preset is not supported
151+ _compute_sparsity_level ( self . target )
152152
153153 key_value_pairs = _get_conjugate_parameter (self .target )
154154 if len (key_value_pairs ) != 1 :
@@ -166,7 +166,7 @@ def validate_target(self):
166166 def conjugate_distribution (self ):
167167 # Extract variables
168168 b = self .target .likelihood .data # mu
169- m = np . count_nonzero ( b ) # n
169+ m = _compute_sparsity_level ( self . target )
170170 Ax = self .target .likelihood .distribution .mean # x_i
171171 L = self .target .likelihood .distribution (np .array ([1 ])).sqrtprec # L
172172 alpha = self .target .prior .shape # alpha
@@ -183,9 +183,9 @@ def validate_target(self):
183183 if self .target .prior .dim != 1 :
184184 raise ValueError ("RegularizedUnboundedUniform-Gamma conjugacy only works with univariate Gamma prior" )
185185
186- if self . target . likelihood . distribution . preset [ "regularization" ] not in [ "l1" , "tv" ]:
187- raise ValueError ( "RegularizedUnboundedUniform-Gamma conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization" )
188-
186+ # Raises error if preset is not supported
187+ _compute_sparsity_level ( self . target )
188+
189189 key_value_pairs = _get_conjugate_parameter (self .target )
190190 if len (key_value_pairs ) != 1 :
191191 raise ValueError (f"Multiple references to conjugate parameter { self .target .prior .name } found in likelihood. Only one occurance is supported." )
@@ -219,8 +219,8 @@ def validate_target(self):
219219 if self .target .prior .dim != 1 :
220220 raise ValueError ("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with univariate ModifiedHalfNormal prior" )
221221
222- if self . target . likelihood . distribution . preset [ "regularization" ] not in [ "l1" , "tv" ]:
223- raise ValueError ( "RegularizedGaussian-ModifiedHalfNormal conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization" )
222+ # Raises error if preset is not supported
223+ _compute_sparsity_level ( self . target )
224224
225225 key_value_pairs = _get_conjugate_parameter (self .target )
226226 if len (key_value_pairs ) != 2 :
@@ -266,23 +266,74 @@ def conjugate_distribution(self):
266266
267267
268268def _compute_sparsity_level (target ):
269- """Computes the sparsity level in accordance with Section 4 from [2],"""
269+ """Computes the sparsity level in accordance with Section 4 from [2],
270+ this can be interpreted as the number of degrees of freedom, that is,
271+ the number of components n minus the dimension the of the subdifferential of the regularized.
272+ """
270273 x = target .likelihood .data
271- if target .likelihood .distribution .preset ["constraint" ] == "nonnegativity" :
272- if target .likelihood .distribution .preset ["regularization" ] == "l1" :
273- m = count_nonzero (x )
274- elif target .likelihood .distribution .preset ["regularization" ] == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
275- m = count_constant_components_1D (x , lower = 0.0 )
276- elif target .likelihood .distribution .preset ["regularization" ] == "tv" and isinstance (target .likelihood .distribution .geometry , (Continuous2D , Image2D )):
277- m = count_constant_components_2D (target .likelihood .distribution .geometry .par2fun (x ), lower = 0.0 )
278- else : # No constraints, only regularization
279- if target .likelihood .distribution .preset ["regularization" ] == "l1" :
280- m = count_nonzero (x )
281- elif target .likelihood .distribution .preset ["regularization" ] == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
282- m = count_constant_components_1D (x )
283- elif target .likelihood .distribution .preset ["regularization" ] == "tv" and isinstance (target .likelihood .distribution .geometry , (Continuous2D , Image2D )):
284- m = count_constant_components_2D (target .likelihood .distribution .geometry .par2fun (x ))
285- return m
274+
275+ constraint = target .likelihood .distribution .preset ["constraint" ]
276+ regularization = target .likelihood .distribution .preset ["regularization" ]
277+
278+ # There is no reference for some of these conjugacy rules
279+ if constraint == "nonnegativity" :
280+ if regularization in [None , "l1" ]:
281+ # Number of non-zero components in x
282+ return count_nonzero (x )
283+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
284+ # Number of non-zero constant components in x
285+ return count_constant_components_1D (x , lower = 0.0 )
286+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , (Continuous2D , Image2D )):
287+ # Number of non-zero constant components in x
288+ return count_constant_components_2D (target .likelihood .distribution .geometry .par2fun (x ), lower = 0.0 )
289+ elif constraint == "box" :
290+ bounds = target .likelihood .distribution ._box_bounds
291+ if regularization is None :
292+ # Number of components in x that are strictly between the lower and upper bound
293+ return count_within_bounds (x , bounds [0 ], bounds [1 ])
294+ elif regularization == "l1" :
295+ # Number of components in x that are strictly between the lower and upper bound and are not zero
296+ return count_within_bounds (x , bounds [0 ], bounds [1 ], exception = 0.0 )
297+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
298+ # Number of constant components in x between are strictly between the lower and upper bound
299+ return count_constant_components_1D (x , lower = bounds [0 ], upper = bounds [1 ])
300+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , (Continuous2D , Image2D )):
301+ # Number of constant components in x between are strictly between the lower and upper bound
302+ return count_constant_components_2D (target .likelihood .distribution .geometry .par2fun (x ), lower = bounds [0 ], upper = bounds [1 ])
303+ elif constraint in ["increasing" , "decreasing" ]:
304+ if regularization is None :
305+ # Number of constant components in x
306+ return count_constant_components_1D (x )
307+ elif regularization == "l1" :
308+ # Number of constant components in x that are not zero
309+ return count_constant_components_1D (x , exception = 0.0 )
310+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
311+ # Number of constant components in x
312+ return count_constant_components_1D (x )
313+ # Increasing and decreasing cannot be done in 2D
314+ elif constraint in ["convex" , "concave" ]:
315+ if regularization is None :
316+ # Number of piecewise linear components in x
317+ return piecewise_linear_1D_DoF (x )
318+ elif regularization == "l1" :
319+ # Number of piecewise linear components in x that are not zero
320+ return piecewise_linear_1D_DoF (x , exception_zero = True )
321+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
322+ # Number of piecewise linear components in x that are not flat
323+ return piecewise_linear_1D_DoF (x , exception_flat = True )
324+ # convex and concave has only been implemented in 1D
325+ elif constraint == None :
326+ if regularization == "l1" :
327+ # Number of non-zero components in x
328+ return count_nonzero (x )
329+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , Continuous1D ):
330+ # Number of non-zero constant components in x
331+ return count_constant_components_1D (x )
332+ elif regularization == "tv" and isinstance (target .likelihood .distribution .geometry , (Continuous2D , Image2D )):
333+ # Number of non-zero constant components in x
334+ return count_constant_components_2D (target .likelihood .distribution .geometry .par2fun (x ))
335+
336+ raise ValueError ("RegularizedGaussian preset constraint and regularization choice is currently not supported with conjugacy." )
286337
287338
288339def _get_conjugate_parameter (target ):
0 commit comments