@@ -469,8 +469,126 @@ def __len__(self):
469469
470470 def __repr__ (self ) -> str :
471471 return "CUQI {}: {} -> {}.\n Forward parameters: {}." .format (self .__class__ .__name__ ,self .domain_geometry ,self .range_geometry ,cuqi .utilities .get_non_default_args (self ))
472-
473- class LinearModel (Model ):
472+
473+
474+ class AffineModel (Model ):
475+ """ Model class representing an affine model, i.e. a linear operator with a fixed shift. For linear models, represented by a linear operator only, see :class:`~cuqi.model.LinearModel`.
476+
477+ The affine model is defined as:
478+
479+ .. math::
480+
481+ x \\ mapsto Ax + shift
482+
483+ where :math:`A` is the linear operator and :math:`shift` is the shift.
484+
485+ Parameters
486+ ----------
487+
488+ linear_operator : 2d ndarray, callable function or cuqi.model.LinearModel
489+ The linear operator. If ndarray is given, the operator is assumed to be a matrix.
490+
491+ shift : scalar or array_like
492+ The shift to be added to the forward operator.
493+
494+ linear_operator_adjoint : callable function, optional
495+ The adjoint of the linear operator. Also used for computing gradients.
496+
497+ range_geometry : cuqi.geometry.Geometry
498+ The geometry representing the range.
499+
500+ domain_geometry : cuqi.geometry.Geometry
501+ The geometry representing the domain.
502+
503+ """
504+
505+ def __init__ (self , linear_operator , shift , linear_operator_adjoint = None , range_geometry = None , domain_geometry = None ):
506+
507+ # If input represents a matrix, extract needed properties from it
508+ if hasattr (linear_operator , '__matmul__' ) and hasattr (linear_operator , 'T' ):
509+ if linear_operator_adjoint is not None :
510+ raise ValueError ("Adjoint of linear operator should not be provided when linear operator is a matrix. If you want to provide an adjoint, use a callable function for the linear operator." )
511+
512+ matrix = linear_operator
513+
514+ linear_operator = lambda x : matrix @x
515+ linear_operator_adjoint = lambda y : matrix .T @y
516+
517+ if range_geometry is None :
518+ if hasattr (matrix , 'shape' ):
519+ range_geometry = _DefaultGeometry1D (grid = matrix .shape [0 ])
520+ elif isinstance (matrix , LinearModel ):
521+ range_geometry = matrix .range_geometry
522+
523+ if domain_geometry is None :
524+ if hasattr (matrix , 'shape' ):
525+ domain_geometry = _DefaultGeometry1D (grid = matrix .shape [1 ])
526+ elif isinstance (matrix , LinearModel ):
527+ domain_geometry = matrix .domain_geometry
528+ else :
529+ matrix = None
530+
531+ # Ensure that the operators are a callable functions (either provided or created from matrix)
532+ if not callable (linear_operator ):
533+ raise TypeError ("Linear operator must be defined as a matrix or a callable function of some kind" )
534+ if linear_operator_adjoint is not None and not callable (linear_operator_adjoint ):
535+ raise TypeError ("Linear operator adjoint must be defined as a callable function of some kind" )
536+
537+ # Check size of shift and match against range_geometry
538+ if not np .isscalar (shift ):
539+ if len (shift ) != range_geometry .par_dim :
540+ raise ValueError ("The shift should have the same dimension as the range geometry." )
541+
542+ # Initialize Model class
543+ super ().__init__ (linear_operator , range_geometry , domain_geometry )
544+
545+ # Store matrix privately
546+ self ._matrix = matrix
547+
548+ # Store shift as private attribute
549+ self ._shift = shift
550+
551+ # Store linear operator privately
552+ self ._linear_operator = linear_operator
553+
554+ # Store adjoint function
555+ self ._linear_operator_adjoint = linear_operator_adjoint
556+
557+ # Define gradient
558+ self ._gradient_func = lambda direction , wrt : linear_operator_adjoint (direction )
559+
560+ # Update forward function to include shift (overwriting the one from Model class)
561+ self ._forward_func = lambda * args , ** kwargs : linear_operator (* args , ** kwargs ) + shift
562+
563+ # Use arguments from user's callable linear operator (overwriting those found by Model class)
564+ self ._non_default_args = cuqi .utilities .get_non_default_args (linear_operator )
565+
566+ @property
567+ def shift (self ):
568+ """ The shift of the affine model. """
569+ return self ._shift
570+
571+ @shift .setter
572+ def shift (self , value ):
573+ """ Update the shift of the affine model. Updates both the shift value and the underlying forward function. """
574+ self ._shift = value
575+ self ._forward_func = lambda * args , ** kwargs : self ._linear_operator (* args , ** kwargs ) + value
576+
577+ def _forward_func_no_shift (self , x , is_par = True ):
578+ """ Helper function for computing the forward operator without the shift. """
579+ return self ._apply_func (self ._linear_operator ,
580+ self .range_geometry ,
581+ self .domain_geometry ,
582+ x , is_par )
583+
584+ def _adjoint_func_no_shift (self , y , is_par = True ):
585+ """ Helper function for computing the adjoint operator without the shift. """
586+ return self ._apply_func (self ._linear_operator_adjoint ,
587+ self .domain_geometry ,
588+ self .range_geometry ,
589+ y , is_par )
590+
591+ class LinearModel (AffineModel ):
474592 """Model based on a Linear forward operator.
475593
476594 Parameters
@@ -534,45 +652,11 @@ def adjoint(y):
534652 Note that you would need to specify the range and domain geometries in this
535653 case as they cannot be inferred from the forward and adjoint functions.
536654 """
537- # Linear forward model with forward and adjoint (transpose).
538655
539- def __init__ (self ,forward ,adjoint = None ,range_geometry = None ,domain_geometry = None ):
540- #Assume forward is matrix if not callable (TODO: add more checks)
541- if not callable (forward ):
542- forward_func = lambda x : self ._matrix @x
543- adjoint_func = lambda y : self ._matrix .T @y
544- matrix = forward
545- else :
546- forward_func = forward
547- adjoint_func = adjoint
548- matrix = None
549-
550- #Check if input is callable
551- if callable (adjoint_func ) is not True :
552- raise TypeError ("Adjoint needs to be callable function of some kind" )
553-
554- # Use matrix to derive range_geometry and domain_geometry
555- if matrix is not None :
556- if range_geometry is None :
557- range_geometry = _DefaultGeometry1D (grid = matrix .shape [0 ])
558- if domain_geometry is None :
559- domain_geometry = _DefaultGeometry1D (grid = matrix .shape [1 ])
560-
561- #Initialize Model class
562- super ().__init__ (forward_func ,range_geometry ,domain_geometry )
563-
564- #Add adjoint
565- self ._adjoint_func = adjoint_func
566-
567- #Store matrix privately
568- self ._matrix = matrix
569-
570- #Add gradient
571- self ._gradient_func = lambda direction , wrt : self ._adjoint_func (direction )
656+ def __init__ (self , forward , adjoint = None , range_geometry = None , domain_geometry = None ):
572657
573- # if matrix is not None:
574- # assert(self.range_dim == matrix.shape[0]), "The parameter 'forward' dimensions are inconsistent with the parameter 'range_geometry'"
575- # assert(self.domain_dim == matrix.shape[1]), "The parameter 'forward' dimensions are inconsistent with parameter 'domain_geometry'"
658+ #Initialize as AffineModel with shift=0
659+ super ().__init__ (forward , 0 , adjoint , range_geometry , domain_geometry )
576660
577661 def adjoint (self , y , is_par = True ):
578662 """ Adjoint of the model.
@@ -590,16 +674,21 @@ def adjoint(self, y, is_par=True):
590674 ndarray or cuqi.array.CUQIarray
591675 The adjoint model output. Always returned as parameters.
592676 """
593- return self ._apply_func (self ._adjoint_func ,
677+ if self ._linear_operator_adjoint is None :
678+ raise ValueError ("No adjoint operator was provided for this model." )
679+ return self ._apply_func (self ._linear_operator_adjoint ,
594680 self .domain_geometry ,
595681 self .range_geometry ,
596682 y , is_par )
597683
598-
684+ def __matmul__ (self , x ):
685+ return self .forward (x )
686+
599687 def get_matrix (self ):
600688 """
601689 Returns an ndarray with the matrix representing the forward operator.
602690 """
691+
603692 if self ._matrix is not None : #Matrix exists so return it
604693 return self ._matrix
605694 else :
@@ -617,15 +706,12 @@ def get_matrix(self):
617706 #Store matrix for future use
618707 self ._matrix = mat
619708
620- return self ._matrix
621-
622- def __matmul__ (self , x ):
623- return self .forward (x )
709+ return self ._matrix
624710
625711 @property
626712 def T (self ):
627713 """Transpose of linear model. Returns a new linear model acting as the transpose."""
628- transpose = LinearModel (self .adjoint ,self .forward ,self .domain_geometry ,self .range_geometry )
714+ transpose = LinearModel (self .adjoint , self .forward , self .domain_geometry , self .range_geometry )
629715 if self ._matrix is not None :
630716 transpose ._matrix = self ._matrix .T
631717 return transpose
0 commit comments