Skip to content

Commit 231aec2

Browse files
authored
Merge pull request #486 from CUQI-DTU/add_AffineModel
Add AffineModel
2 parents cd8c23a + 9682345 commit 231aec2

File tree

10 files changed: +586 −90 lines changed

cuqi/experimental/mcmc/_laplace_approximation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def model(self):
7474
return self.target.model
7575

7676
@property
77-
def data(self):
78-
return self.target.data
77+
def _data(self):
78+
return self.target.data - self.target.model._shift
7979

8080
def _precompute(self):
8181

@@ -89,7 +89,7 @@ def Lk_fun(x_k):
8989
return W.sqrt() @ D
9090
self.Lk_fun = Lk_fun
9191

92-
self._m = len(self.data)
92+
self._m = len(self._data)
9393
self._L1 = self.likelihood.distribution.sqrtprec
9494

9595
# If prior location is scalar, repeat it to match dimensions
@@ -101,17 +101,17 @@ def Lk_fun(x_k):
101101
# Initial Laplace approx
102102
self._L2 = Lk_fun(self.initial_point)
103103
self._L2mu = self._L2@self._priorloc
104-
self._b_tild = np.hstack([self._L1@self.data, self._L2mu])
104+
self._b_tild = np.hstack([self._L1@self._data, self._L2mu])
105105

106106
# Least squares form
107107
def M(x, flag):
108108
if flag == 1:
109-
out1 = self._L1 @ self.model.forward(x)
109+
out1 = self._L1 @ self.model._forward_func_no_shift(x) # Use forward function which excludes shift
110110
out2 = np.sqrt(1/self.prior.scale)*(self._L2 @ x)
111111
out = np.hstack([out1, out2])
112112
elif flag == 2:
113113
idx = int(self._m)
114-
out1 = self.model.adjoint(self._L1.T@x[:idx])
114+
out1 = self.model._adjoint_func_no_shift(self._L1.T@x[:idx])
115115
out2 = np.sqrt(1/self.prior.scale)*(self._L2.T @ x[idx:])
116116
out = out1 + out2
117117
return out
@@ -121,7 +121,7 @@ def step(self):
121121
# Update Laplace approximation
122122
self._L2 = self.Lk_fun(self.current_point)
123123
self._L2mu = self._L2@self._priorloc
124-
self._b_tild = np.hstack([self._L1@self.data, self._L2mu])
124+
self._b_tild = np.hstack([self._L1@self._data, self._L2mu])
125125

126126
# Sample from approximate posterior
127127
e = np.random.randn(len(self._b_tild))
@@ -139,9 +139,9 @@ def validate_target(self):
139139
if not isinstance(self.target, cuqi.distribution.Posterior):
140140
raise ValueError(f"To initialize an object of type {self.__class__}, 'target' need to be of type 'cuqi.distribution.Posterior'.")
141141

142-
# Check Linear model
143-
if not isinstance(self.likelihood.model, cuqi.model.LinearModel):
144-
raise TypeError("Model needs to be linear")
142+
# Check Affine model
143+
if not isinstance(self.likelihood.model, cuqi.model.AffineModel):
144+
raise TypeError("Model needs to be affine or linear")
145145

146146
# Check Gaussian likelihood
147147
if not hasattr(self.likelihood.distribution, "sqrtprec"):

cuqi/experimental/mcmc/_rto.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class LinearRTO(Sampler):
1111
"""
1212
Linear RTO (Randomize-Then-Optimize) sampler.
1313
14-
Samples posterior related to the inverse problem with Gaussian likelihood and prior, and where the forward model is Linear.
14+
Samples posterior related to the inverse problem with Gaussian likelihood and prior, and where the forward model is linear or more generally affine.
1515
1616
Parameters
1717
------------
@@ -22,7 +22,7 @@ class LinearRTO(Sampler):
2222
2323
Here:
2424
data: is a m-dimensional numpy array containing the measured data.
25-
model: is a m by n dimensional matrix or LinearModel representing the forward model.
25+
model: is a m by n dimensional matrix, AffineModel or LinearModel representing the forward model.
2626
L_sqrtprec: is the square root of the precision matrix of the Gaussian likelihood.
2727
P_mean: is the prior mean.
2828
P_sqrtprec: is the square root of the precision matrix of the Gaussian prior.
@@ -71,21 +71,23 @@ def likelihoods(self):
7171

7272
@property
7373
def model(self):
74-
return self.target.model
75-
74+
return self.target.model
75+
7676
@property
77-
def data(self):
78-
return self.target.data
79-
77+
def models(self):
78+
if isinstance(self.target, cuqi.distribution.Posterior):
79+
return [self.target.model]
80+
elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
81+
return self.target.models
82+
8083
def _precompute(self):
8184
L1 = [likelihood.distribution.sqrtprec for likelihood in self.likelihoods]
8285
L2 = self.prior.sqrtprec
8386
L2mu = self.prior.sqrtprecTimesMean
8487

8588
# pre-computations
8689
self.n = self.prior.dim
87-
self.b_tild = np.hstack([L@likelihood.data for (L, likelihood) in zip(L1, self.likelihoods)]+ [L2mu])
88-
90+
self.b_tild = np.hstack([L@(likelihood.data - model._shift) for (L, likelihood, model) in zip(L1, self.likelihoods, self.models)]+ [L2mu]) # With shift from AffineModel
8991
callability = [callable(likelihood.model) for likelihood in self.likelihoods]
9092
notcallability = [not c for c in callability]
9193
if all(notcallability):
@@ -94,7 +96,7 @@ def _precompute(self):
9496
# in this case, model is a function doing forward and backward operations
9597
def M(x, flag):
9698
if flag == 1:
97-
out1 = [L @ likelihood.model.forward(x) for (L, likelihood) in zip(L1, self.likelihoods)]
99+
out1 = [L @ likelihood.model._forward_func_no_shift(x) for (L, likelihood) in zip(L1, self.likelihoods)] # Use forward function which excludes shift
98100
out2 = L2 @ x
99101
out = np.hstack(out1 + [out2])
100102
elif flag == 2:
@@ -103,7 +105,7 @@ def M(x, flag):
103105
out1 = np.zeros(self.n)
104106
for likelihood in self.likelihoods:
105107
idx_end += len(likelihood.data)
106-
out1 += likelihood.model.adjoint(likelihood.distribution.sqrtprec.T@x[idx_start:idx_end])
108+
out1 += likelihood.model._adjoint_func_no_shift(likelihood.distribution.sqrtprec.T@x[idx_start:idx_end])
107109
idx_start = idx_end
108110
out2 = L2.T @ x[idx_end:]
109111
out = out1 + out2
@@ -129,16 +131,16 @@ def validate_target(self):
129131

130132
# Check Linear model and Gaussian likelihood(s)
131133
if isinstance(self.target, cuqi.distribution.Posterior):
132-
if not isinstance(self.model, cuqi.model.LinearModel):
133-
raise TypeError("Model needs to be linear")
134+
if not isinstance(self.model, cuqi.model.AffineModel):
135+
raise TypeError("Model needs to be linear or more generally affine")
134136

135137
if not hasattr(self.likelihood.distribution, "sqrtprec"):
136138
raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")
137139

138140
elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior): # Elif used for further alternatives, e.g., stacked posterior
139141
for likelihood in self.likelihoods:
140-
if not isinstance(likelihood.model, cuqi.model.LinearModel):
141-
raise TypeError("Model needs to be linear")
142+
if not isinstance(likelihood.model, cuqi.model.AffineModel):
143+
raise TypeError("Model needs to be linear or more generally affine")
142144

143145
if not hasattr(likelihood.distribution, "sqrtprec"):
144146
raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")

cuqi/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._model import Model, LinearModel, PDEModel
1+
from ._model import Model, LinearModel, PDEModel, AffineModel

cuqi/model/_model.py

Lines changed: 132 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,126 @@ def __len__(self):
469469

470470
def __repr__(self) -> str:
471471
return "CUQI {}: {} -> {}.\n Forward parameters: {}.".format(self.__class__.__name__,self.domain_geometry,self.range_geometry,cuqi.utilities.get_non_default_args(self))
472-
473-
class LinearModel(Model):
472+
473+
474+
class AffineModel(Model):
    """ Model class representing an affine model, i.e. a linear operator with a fixed shift. For linear models, represented by a linear operator only, see :class:`~cuqi.model.LinearModel`.

    The affine model is defined as:

    .. math::

        x \\mapsto Ax + shift

    where :math:`A` is the linear operator and :math:`shift` is the shift.

    Parameters
    ----------

    linear_operator : 2d ndarray, callable function or cuqi.model.LinearModel
        The linear operator. If ndarray is given, the operator is assumed to be a matrix.

    shift : scalar or array_like
        The shift to be added to the forward operator.

    linear_operator_adjoint : callable function, optional
        The adjoint of the linear operator. Also used for computing gradients.
        Must not be given when the linear operator is a matrix-like object
        (the adjoint is then derived from its transpose).

    range_geometry : cuqi.geometry.Geometry, optional
        The geometry representing the range. If omitted and the linear
        operator is matrix-like, it is taken from the matrix shape (default
        1D geometry) or, for a LinearModel, from that model.

    domain_geometry : cuqi.geometry.Geometry, optional
        The geometry representing the domain. Inferred like `range_geometry`
        when omitted and the linear operator is matrix-like.

    Raises
    ------
    ValueError
        If an adjoint is passed together with a matrix-like linear operator,
        or if a non-scalar shift does not match the range dimension.

    TypeError
        If the linear operator (or a provided adjoint) is not callable.

    """

    def __init__(self, linear_operator, shift, linear_operator_adjoint=None, range_geometry=None, domain_geometry=None):

        # If input represents a matrix, extract needed properties from it
        if hasattr(linear_operator, '__matmul__') and hasattr(linear_operator, 'T'):
            if linear_operator_adjoint is not None:
                raise ValueError("Adjoint of linear operator should not be provided when linear operator is a matrix. If you want to provide an adjoint, use a callable function for the linear operator.")

            matrix = linear_operator

            # Wrap the matrix-like object so that forward/adjoint are plain callables
            linear_operator = lambda x: matrix@x
            linear_operator_adjoint = lambda y: matrix.T@y

            if range_geometry is None:
                if hasattr(matrix, 'shape'):
                    range_geometry = _DefaultGeometry1D(grid=matrix.shape[0])
                elif isinstance(matrix, LinearModel):
                    range_geometry = matrix.range_geometry

            if domain_geometry is None:
                if hasattr(matrix, 'shape'):
                    domain_geometry = _DefaultGeometry1D(grid=matrix.shape[1])
                elif isinstance(matrix, LinearModel):
                    domain_geometry = matrix.domain_geometry
        else:
            matrix = None

        # Ensure that the operators are callable functions (either provided or created from matrix)
        if not callable(linear_operator):
            raise TypeError("Linear operator must be defined as a matrix or a callable function of some kind")
        if linear_operator_adjoint is not None and not callable(linear_operator_adjoint):
            raise TypeError("Linear operator adjoint must be defined as a callable function of some kind")

        # Initialize Model class
        super().__init__(linear_operator, range_geometry, domain_geometry)

        # Check size of shift against the geometry stored on the model. This is
        # done after base-class initialisation so that a missing (None)
        # range_geometry argument does not cause an AttributeError here.
        # NOTE(review): presumably Model stores/normalises the geometry as
        # self.range_geometry -- confirm. If no dimension is available the
        # check is skipped and a mismatch surfaces when the model is applied.
        if not np.isscalar(shift):
            range_dim = getattr(self.range_geometry, "par_dim", None)
            if range_dim is not None and len(shift) != range_dim:
                raise ValueError("The shift should have the same dimension as the range geometry.")

        # Store matrix privately
        self._matrix = matrix

        # Store shift as private attribute
        self._shift = shift

        # Store linear operator privately
        self._linear_operator = linear_operator

        # Store adjoint function
        self._linear_operator_adjoint = linear_operator_adjoint

        # Define gradient. The shift is constant, so the gradient is the adjoint
        # of the linear operator applied to the direction. Fail with a clear
        # message (matching LinearModel.adjoint) when no adjoint is available.
        def _gradient(direction, wrt):
            if linear_operator_adjoint is None:
                raise ValueError("No adjoint operator was provided for this model.")
            return linear_operator_adjoint(direction)

        self._gradient_func = _gradient

        # Update forward function to include shift (overwriting the one from Model class)
        self._forward_func = lambda *args, **kwargs: linear_operator(*args, **kwargs) + shift

        # Use arguments from user's callable linear operator (overwriting those found by Model class)
        self._non_default_args = cuqi.utilities.get_non_default_args(linear_operator)

    @property
    def shift(self):
        """ The shift of the affine model. """
        return self._shift

    @shift.setter
    def shift(self, value):
        """ Update the shift of the affine model. Updates both the shift value and the underlying forward function. """
        self._shift = value
        self._forward_func = lambda *args, **kwargs: self._linear_operator(*args, **kwargs) + value

    def _forward_func_no_shift(self, x, is_par=True):
        """ Helper function for computing the forward operator without the shift. """
        return self._apply_func(self._linear_operator,
                                self.range_geometry,
                                self.domain_geometry,
                                x, is_par)

    def _adjoint_func_no_shift(self, y, is_par=True):
        """ Helper function for computing the adjoint operator without the shift.

        Raises
        ------
        ValueError
            If the model was created without an adjoint operator.
        """
        if self._linear_operator_adjoint is None:
            raise ValueError("No adjoint operator was provided for this model.")
        return self._apply_func(self._linear_operator_adjoint,
                                self.domain_geometry,
                                self.range_geometry,
                                y, is_par)
590+
591+
class LinearModel(AffineModel):
474592
"""Model based on a Linear forward operator.
475593
476594
Parameters
@@ -534,45 +652,11 @@ def adjoint(y):
534652
Note that you would need to specify the range and domain geometries in this
535653
case as they cannot be inferred from the forward and adjoint functions.
536654
"""
537-
# Linear forward model with forward and adjoint (transpose).
538655

539-
def __init__(self,forward,adjoint=None,range_geometry=None,domain_geometry=None):
540-
#Assume forward is matrix if not callable (TODO: add more checks)
541-
if not callable(forward):
542-
forward_func = lambda x: self._matrix@x
543-
adjoint_func = lambda y: self._matrix.T@y
544-
matrix = forward
545-
else:
546-
forward_func = forward
547-
adjoint_func = adjoint
548-
matrix = None
549-
550-
#Check if input is callable
551-
if callable(adjoint_func) is not True:
552-
raise TypeError("Adjoint needs to be callable function of some kind")
553-
554-
# Use matrix to derive range_geometry and domain_geometry
555-
if matrix is not None:
556-
if range_geometry is None:
557-
range_geometry = _DefaultGeometry1D(grid=matrix.shape[0])
558-
if domain_geometry is None:
559-
domain_geometry = _DefaultGeometry1D(grid=matrix.shape[1])
560-
561-
#Initialize Model class
562-
super().__init__(forward_func,range_geometry,domain_geometry)
563-
564-
#Add adjoint
565-
self._adjoint_func = adjoint_func
566-
567-
#Store matrix privately
568-
self._matrix = matrix
569-
570-
#Add gradient
571-
self._gradient_func = lambda direction, wrt: self._adjoint_func(direction)
656+
def __init__(self, forward, adjoint=None, range_geometry=None, domain_geometry=None):
572657

573-
# if matrix is not None:
574-
# assert(self.range_dim == matrix.shape[0]), "The parameter 'forward' dimensions are inconsistent with the parameter 'range_geometry'"
575-
# assert(self.domain_dim == matrix.shape[1]), "The parameter 'forward' dimensions are inconsistent with parameter 'domain_geometry'"
658+
#Initialize as AffineModel with shift=0
659+
super().__init__(forward, 0, adjoint, range_geometry, domain_geometry)
576660

577661
def adjoint(self, y, is_par=True):
578662
""" Adjoint of the model.
@@ -590,16 +674,21 @@ def adjoint(self, y, is_par=True):
590674
ndarray or cuqi.array.CUQIarray
591675
The adjoint model output. Always returned as parameters.
592676
"""
593-
return self._apply_func(self._adjoint_func,
677+
if self._linear_operator_adjoint is None:
678+
raise ValueError("No adjoint operator was provided for this model.")
679+
return self._apply_func(self._linear_operator_adjoint,
594680
self.domain_geometry,
595681
self.range_geometry,
596682
y, is_par)
597683

598-
684+
def __matmul__(self, x):
685+
return self.forward(x)
686+
599687
def get_matrix(self):
600688
"""
601689
Returns an ndarray with the matrix representing the forward operator.
602690
"""
691+
603692
if self._matrix is not None: #Matrix exists so return it
604693
return self._matrix
605694
else:
@@ -617,15 +706,12 @@ def get_matrix(self):
617706
#Store matrix for future use
618707
self._matrix = mat
619708

620-
return self._matrix
621-
622-
def __matmul__(self, x):
623-
return self.forward(x)
709+
return self._matrix
624710

625711
@property
626712
def T(self):
627713
"""Transpose of linear model. Returns a new linear model acting as the transpose."""
628-
transpose = LinearModel(self.adjoint,self.forward,self.domain_geometry,self.range_geometry)
714+
transpose = LinearModel(self.adjoint, self.forward, self.domain_geometry, self.range_geometry)
629715
if self._matrix is not None:
630716
transpose._matrix = self._matrix.T
631717
return transpose

0 commit comments

Comments
 (0)