Open
Description
When using an `ml_operator` with pointwise operations on the mesh coordinates, the updated mesh coordinates revert to their original values after calling `compute_gradient`, even though the assignment appears to work correctly before gradient computation.
Steps to reproduce the behaviour:
# Firedrake core API plus its adjoint interface.
from firedrake import *
from firedrake.adjoint import *
# Called before any of the operations below so that they are annotated for
# the later compute_gradient call (presumably this is what turns taping on;
# confirm against the pyadjoint documentation).
continue_annotation()
# JAX backend for ml_operator.
from firedrake.ml.jax import *
from jax import config
# Enable 64-bit floating point in JAX (its default is 32-bit); presumably
# needed to match the precision of the Firedrake data — TODO confirm.
config.update('jax_enable_x64', True)
class Model:
    """Callable handed to ``ml_operator`` as the model in this reproduction."""

    def __call__(self, u):
        """Return ``u`` with 0.1 added to its entry at index 0.

        ``u`` must support the JAX-style functional update
        ``u.at[idx].set(value)``, which produces a new array rather than
        modifying the input in place.
        """
        # Using pointwise operation which causes the mesh coordinates to revert after compute_gradient
        u = u.at[0].set(u[0] + 0.1)
        return u
# Create a unit square mesh
mesh = UnitSquareMesh(2, 2)
# Set up a VectorFunctionSpace and assign a dummy function to mesh coordinates for sensitivity computation.
# coordinatePerturbation is the control used in compute_gradient below; it is a
# freshly created Function (presumably zero-initialised, so the coordinates keep
# their values here — confirm).
V = VectorFunctionSpace(mesh, "CG", 1)
coordinatePerturbation = Function(V)
mesh.coordinates.assign(mesh.coordinates + coordinatePerturbation)
# Create the ml_operator and assemble the updated coordinates
model = Model()
N = ml_operator(model, function_space=V, inputs_format=1)
u1 = assemble(N(mesh.coordinates))
# Overwrite the mesh coordinates with the operator output.
mesh.coordinates.assign(u1)
# Before the gradient computation: u1 and the mesh coordinates are printed so
# they can be compared with the values printed at the end.
print('u1:')
print(u1.dat.data)
print('mesh (after assign):')
print(mesh.coordinates.dat.data)
# Functional: the integral of 1 over the (updated) mesh.
J = assemble(1 * dx(mesh))
# Gradient of J with respect to coordinatePerturbation, requested in the l2 Riesz representation.
dJdu0 = compute_gradient(J, Control(coordinatePerturbation), options={'riesz_representation': 'l2'})
# Reported bug: the coordinates printed here have reverted to their original
# values instead of matching the 'mesh (after assign)' output above.
print('mesh (after compute_gradient):')
print(mesh.coordinates.dat.data)
Expected behaviour
The updated mesh coordinates (after assigning u1) should persist even after `compute_gradient` is called.