diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 05d3438318..d11081022e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1490,13 +1490,15 @@ def __init__(self, V, source_vom, target_vom, expr, arguments): ) self.V = V self.source_vom = source_vom + self.target_vom = target_vom self.expr = expr self.arguments = arguments self.reduce = reduce # note that interpolation doesn't include halo cells - self.handle = VomOntoVomDummyMat( - original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, expr, arguments + self.dummy_mat = VomOntoVomDummyMat( + original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, target_vom, expr, arguments ) + self.handle = self.dummy_mat._create_petsc_mat() @property def mpi_type(self): @@ -1505,14 +1507,14 @@ def mpi_type(self): Should correspond to the underlying data type of the PETSc Vec. """ - return self.handle.mpi_type + return self.dummy_mat.mpi_type @mpi_type.setter def mpi_type(self, val): - self.handle.mpi_type = val + self.dummy_mat.mpi_type = val def forward_operation(self, target_dat): - coeff = self.handle.expr_as_coeff() + coeff = self.dummy_mat.expr_as_coeff() with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec: self.handle.mult(coeff_vec, target_vec) @@ -1543,11 +1545,12 @@ class VomOntoVomDummyMat(object): The arguments in the expression. """ - def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments): + def __init__(self, sf, forward_reduce, V, source_vom, target_vom, expr, arguments): self.sf = sf self.forward_reduce = forward_reduce self.V = V self.source_vom = source_vom + self.target_vom = target_vom self.expr = expr self.arguments = arguments @@ -1614,7 +1617,7 @@ def reduce(self, source_vec, target_vec): ) def broadcast(self, source_vec, target_vec): - source_arr = source_vec.getArray() + source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.bcastBegin( self.mpi_type, @@ -1629,7 +1632,7 @@ def broadcast(self, source_vec, target_vec): MPI.REPLACE, ) - def mult(self, source_vec, target_vec): + def mult(self, mat, source_vec, target_vec): # need to evaluate expression before doing mult coeff = self.expr_as_coeff(source_vec) with coeff.dat.vec_ro as coeff_vec: @@ -1638,7 +1641,10 @@ def mult(self, source_vec, target_vec): else: self.broadcast(coeff_vec, target_vec) - def multHermitian(self, source_vec, target_vec): + def multHermitian(self, mat, source_vec, target_vec): + self.multTranspose(mat, source_vec, target_vec) + + def multTranspose(self, mat, source_vec, target_vec): # can only do adjoint if our expression exclusively contains a # single argument, making the application of the adjoint operator # straightforward (haven't worked out how to do this otherwise!) @@ -1664,3 +1670,19 @@ def multHermitian(self, source_vec, target_vec): # matrix will then have rows of zeros for those points. target_vec.zeroEntries() self.reduce(source_vec, target_vec) + + def _create_petsc_mat(self): + mat = PETSc.Mat().create(comm=self.V.comm) + element = self.V.ufl_element() # Could be vector/tensor valued + P0DG_source = firedrake.FunctionSpace(self.source_vom, element) + P0DG_target = P0DG_source.reconstruct(mesh=self.target_vom) + source_size = P0DG_source.dof_dset.layout_vec.getSizes() + target_size = P0DG_target.dof_dset.layout_vec.getSizes() + mat.setSizes([target_size, source_size]) + mat.setType(PETSc.Mat().Type.PYTHON) + mat.setPythonContext(self) + mat.setUp() + return mat + + def duplicate(self, A=None, op=None): + return self._create_petsc_mat()