From 9373385d02f1991fc4fbc8945b82395d37f787d9 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sun, 18 May 2025 17:19:39 +0100 Subject: [PATCH 1/4] Re-enable adjoint interpolation --- animate/interpolation.py | 4 +--- test/test_interpolation.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/animate/interpolation.py b/animate/interpolation.py index d2e1102..7d23356 100644 --- a/animate/interpolation.py +++ b/animate/interpolation.py @@ -254,9 +254,7 @@ def _transfer_adjoint(target_b, source_b, transfer_method, **kwargs): # Apply adjoint transfer operator to each component for i, (t_b, s_b) in enumerate(zip(target_b_split, source_b_split)): if transfer_method == "interpolate": - raise NotImplementedError( - "Adjoint of interpolation operator not implemented." - ) # TODO (#113) + s_b.interpolate(t_b, adjoint=True, **kwargs) elif transfer_method == "project": ksp = petsc4py.KSP().create() ksp.setOperators(assemble_mass_matrix(t_b.function_space(), lumped=bounded)) diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 1cdfd65..485937c 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -255,7 +255,7 @@ def test_transfer_same_space(self, transfer_method): expected = source self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand(["project"]) # TODO: interpolate (#113) + @parameterized.expand(["interpolate", "project"]) def test_transfer_same_space_adjoint(self, transfer_method): pytest.skip() # TODO: (#114) Vs = FunctionSpace(self.source_mesh, "CG", 1) @@ -279,7 +279,7 @@ def test_transfer_same_space_mixed(self, transfer_method): expected = source self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand(["project"]) # TODO: interpolate (#113) + @parameterized.expand(["interpolate", "project"]) def test_transfer_same_space_mixed_adjoint(self, transfer_method): pytest.skip() # TODO: (#114) P1 = FunctionSpace(self.source_mesh, "CG", 1) @@ -307,7 +307,7 @@ def test_transfer_same_mesh(self, transfer_method): expected = Function(Vt).project(source) self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand(["project"]) # TODO: interpolate (#113) + @parameterized.expand(["interpolate", "project"]) def test_transfer_same_mesh_adjoint(self, transfer_method): pytest.skip() # TODO: (#114) Vs = FunctionSpace(self.source_mesh, "CG", 1) @@ -343,7 +343,7 @@ def test_transfer_same_mesh_mixed(self, transfer_method): e2.project(s2) self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand(["project"]) # TODO: interpolate (#113) + @parameterized.expand(["interpolate", "project"]) def test_transfer_same_mesh_mixed_adjoint(self, transfer_method): pytest.skip() # TODO: (#114) P1 = FunctionSpace(self.source_mesh, "CG", 1) From db7734f20e69d415e02fa707aad6656d4a7c8ae3 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sun, 18 May 2025 17:49:22 +0100 Subject: [PATCH 2/4] Accept rvalue for (co)function2(co)function --- animate/utility.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/animate/utility.py b/animate/utility.py index 16d450a..1ac0c62 100644 --- a/animate/utility.py +++ b/animate/utility.py @@ -279,14 +279,17 @@ def assemble_mass_matrix(space, norm_type="L2", lumped=False): return mass_matrix.createDiagonal(x) -def cofunction2function(cofunc): +def cofunction2function(cofunc, func=None): """ :arg cofunc: a cofunction :type cofunc: :class:`firedrake.cofunction.Cofunction` + :kwarg func: a function for the return value + :type func: :class:`firedrake.function.Function` :returns: a function with the same underyling data :rtype: :class:`firedrake.function.Function` """ - func = ffunc.Function(cofunc.function_space().dual()) + if func is None: + func = ffunc.Function(cofunc.function_space().dual()) if isinstance(func.dat.data_with_halos, tuple): for i, arr in enumerate(func.dat.data_with_halos): arr[:] = cofunc.dat.data_with_halos[i] @@ -295,14 +298,17 @@ def cofunction2function(cofunc): return func -def function2cofunction(func): +def function2cofunction(func, cofunc=None): """ :arg func: a function :type func: :class:`firedrake.function.Function` + :kwarg cofunc: a cofunction for the return value + :type cofunc: :class:`firedrake.cofunction.Cofunction` :returns: a cofunction with the same underlying data :rtype: :class:`firedrake.cofunction.Cofunction` """ - cofunc = firedrake.Cofunction(func.function_space().dual()) + if cofunc is None: + cofunc = firedrake.Cofunction(func.function_space().dual()) if isinstance(cofunc.dat.data_with_halos, tuple): for i, arr in enumerate(cofunc.dat.data_with_halos): arr[:] = func.dat.data_with_halos[i] From 34856839eff0457298ec4129652072b3dbb2331c Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sun, 18 May 2025 17:49:43 +0100 Subject: [PATCH 3/4] Fix adjoint interpolation for same space case --- animate/interpolation.py | 26 ++++++++++++-------------- test/test_interpolation.py | 4 +--- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/animate/interpolation.py b/animate/interpolation.py index 7d23356..24e1f89 100644 --- a/animate/interpolation.py +++ b/animate/interpolation.py @@ -232,27 +232,25 @@ def _transfer_adjoint(target_b, source_b, transfer_method, **kwargs): bounded = is_project and kwargs.pop("bounded", False) # Map to Functions to apply the adjoint transfer - if not isinstance(target_b, firedrake.Function): - target_b = cofunction2function(target_b) - if not isinstance(source_b, firedrake.Function): - source_b = cofunction2function(source_b) + target_b_func = cofunction2function(target_b) + source_b_func = cofunction2function(source_b) - Vt = target_b.function_space() - Vs = source_b.function_space() + Vt = target_b_func.function_space() + Vs = source_b_func.function_space() if Vs == Vt: - source_b.assign(target_b) - return function2cofunction(source_b) + source_b_func.assign(target_b_func) + return function2cofunction(source_b_func, source_b) _validate_matching_spaces(Vs, Vt) if hasattr(Vs, "num_sub_spaces"): - target_b_split = target_b.subfunctions - source_b_split = source_b.subfunctions + target_b_func_split = target_b_func.subfunctions + source_b_func_split = source_b_func.subfunctions else: - target_b_split = [target_b] - source_b_split = [source_b] + target_b_func_split = [target_b_func] + source_b_func_split = [source_b_func] # Apply adjoint transfer operator to each component - for i, (t_b, s_b) in enumerate(zip(target_b_split, source_b_split)): + for i, (t_b, s_b) in enumerate(zip(target_b_func_split, source_b_func_split)): if transfer_method == "interpolate": s_b.interpolate(t_b, adjoint=True, **kwargs) elif transfer_method == "project": @@ -270,7 +268,7 @@ def _transfer_adjoint(target_b, source_b, transfer_method, **kwargs): ) # Map back to a Cofunction - return function2cofunction(source_b) + return function2cofunction(source_b_func, source_b) def _validate_matching_spaces(Vs, Vt): diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 485937c..eac3991 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -16,7 +16,7 @@ TensorFunctionSpace, VectorFunctionSpace, ) -from firedrake.norms import errornorm +from animate.utility import errornorm from firedrake.utility_meshes import UnitSquareMesh from parameterized import parameterized @@ -257,7 +257,6 @@ def test_transfer_same_space(self, transfer_method): @parameterized.expand(["interpolate", "project"]) def test_transfer_same_space_adjoint(self, transfer_method): - pytest.skip() # TODO: (#114) Vs = FunctionSpace(self.source_mesh, "CG", 1) source = Function(Vs).interpolate(self.sinusoid()) source = function2cofunction(source) @@ -281,7 +280,6 @@ def test_transfer_same_space_mixed(self, transfer_method): @parameterized.expand(["interpolate", "project"]) def test_transfer_same_space_mixed_adjoint(self, transfer_method): - pytest.skip() # TODO: (#114) P1 = FunctionSpace(self.source_mesh, "CG", 1) Vs = P1 * P1 source = Function(Vs) From 370eda68dd638a20b64c5462ee62b956273ce705 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sun, 18 May 2025 17:56:40 +0100 Subject: [PATCH 4/4] Lint --- test/test_interpolation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_interpolation.py b/test/test_interpolation.py index eac3991..012af8d 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -16,7 +16,6 @@ TensorFunctionSpace, VectorFunctionSpace, ) -from animate.utility import errornorm from firedrake.utility_meshes import UnitSquareMesh from parameterized import parameterized @@ -28,7 +27,7 @@ project, transfer, ) -from animate.utility import function2cofunction +from animate.utility import errornorm, function2cofunction class TestClement(unittest.TestCase):