BUG: Mesh coordinates reset after compute_gradient with pointwise JAX operations #4089

Open
@ryan-david-murphy

Description

When an ml_operator that applies a pointwise JAX operation is used to update the mesh coordinates, the updated coordinates revert to their original values after compute_gradient is called, even though the assignment is visibly correct before the gradient computation.

Steps to reproduce the behaviour:

from firedrake import *
from firedrake.adjoint import *
continue_annotation()
from firedrake.ml.jax import *
from jax import config
config.update('jax_enable_x64', True)

class Model():
    def __call__(self, u):
        # Using pointwise operation which causes the mesh coordinates to revert after compute_gradient
        u = u.at[0].set(u[0] + 0.1)
        return u

# Create a unit square mesh
mesh = UnitSquareMesh(2, 2)

# Set up a VectorFunctionSpace and assign a dummy function to mesh coordinates for sensitivity computation
V = VectorFunctionSpace(mesh, "CG", 1)
coordinatePerturbation = Function(V)
mesh.coordinates.assign(mesh.coordinates + coordinatePerturbation)

# Create the ml_operator and assemble the updated coordinates
model = Model()
N = ml_operator(model, function_space=V, inputs_format=1)
u1 = assemble(N(mesh.coordinates))
mesh.coordinates.assign(u1)

print('u1:')
print(u1.dat.data)
print('mesh (after assign):')
print(mesh.coordinates.dat.data)

J = assemble(1 * dx(mesh))
dJdu0 = compute_gradient(J, Control(coordinatePerturbation), options={'riesz_representation': 'l2'})

print('mesh (after compute_gradient):')
print(mesh.coordinates.dat.data)

Expected behaviour
The updated mesh coordinates (after assigning u1) should persist even after compute_gradient is called.
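
For reference, the pointwise operation in Model can be checked in isolation with plain JAX, outside Firedrake. This is only a sketch and does not reproduce the bug: the `coords` array is a hypothetical stand-in for the coordinate data, and using `jax.vmap` to mimic per-point application is an assumption, not a statement about how ml_operator evaluates the model internally. It suggests that the operation itself is a pure shift of the first component with an identity Jacobian, and that it leaves its input untouched.

```python
import jax
import jax.numpy as jnp


def pointwise_model(u):
    # Same body as Model.__call__ in the reproducer: shift the first
    # component by 0.1. `.at[].set()` is functional, so it returns a new
    # array and leaves `u` unchanged.
    return u.at[0].set(u[0] + 0.1)


# Hypothetical stand-in for coordinate data: three 2D points, shape (3, 2).
coords = jnp.array([[0.0, 0.0], [0.5, 0.0], [1.0, 1.0]])

# Assumption: per-point application is emulated here with jax.vmap.
shifted = jax.vmap(pointwise_model)(coords)

# Jacobian of the pointwise map at one point; a constant shift of one
# component should give the 2x2 identity.
jac = jax.jacobian(pointwise_model)(coords[0])

print(shifted)
print(jac)
```

If this holds, the reverted coordinates would not be explained by the JAX operation alone, which points towards how the coordinate assignment is handled during compute_gradient.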
