Skip to content

checkpointing bcs #4284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 21, 2025
8 changes: 5 additions & 3 deletions firedrake/adjoint_utils/dirichletbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@

return deps[0]

def _ad_restore_at_checkpoint(self, checkpoint):
if checkpoint is not None:
self.set_value(checkpoint.saved_output)
def _ad_restore_at_checkpoint(self, bv):
if bv is not None:
bc = bc.reconstruct(g=bv.saved_output)

Check failure on line 47 in firedrake/adjoint_utils/dirichletbc.py

View workflow job for this annotation

GitHub Actions / test / Lint codebase

F821

firedrake/adjoint_utils/dirichletbc.py:47:18: F821 undefined name 'bc'
bc.block = self.block
return bc
return self
32 changes: 32 additions & 0 deletions tests/firedrake/adjoint/test_disk_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,35 @@ def delta_expr(x0, x, y, sigma_x=2000.0):

J_hat = ReducedFunctional(J, Control(c))
assert taylor_test(J_hat, c, Function(V).interpolate(0.1)) > 1.9


@pytest.mark.skipcomplex
def test_bcs():

enable_disk_checkpointing()

tape = get_working_tape()
tape.enable_checkpointing(SingleDiskStorageSchedule())

mesh = checkpointable_mesh(UnitSquareMesh(5, 5))
V = FunctionSpace(mesh, "CG", 2)
T = Function(V)
u = TrialFunction(V)
v = TestFunction(V)
a = inner(grad(u), grad(v)) * dx
x = SpatialCoordinate(mesh)
F = Function(V)
control = Control(F)
F.interpolate(sin(x[0] * pi) * sin(2 * x[1] * pi))
L = F * v * dx
uu = Function(V)
bcs = [DirichletBC(V, T, (1,))]
problem = LinearVariationalProblem(a, L, uu, bcs=bcs)
solver = LinearVariationalSolver(problem)

for i in tape.timestepper(iter(range(3))):
T.assign(T + 1.0)
solver.solve()
obj = assemble(uu * uu * dx)
rf = ReducedFunctional(obj, control)
assert np.allclose(rf(F), obj)
22 changes: 19 additions & 3 deletions tests/firedrake/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,7 @@ def test_cofunction_assign_functional():

@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_bdy_control():
from firedrake.adjoint_utils.dirichletbc import DirichletBCBlock
# Test for the case the boundary condition is a control for a
# domain with length different from 1.
mesh = IntervalMesh(10, 0, 2)
Expand All @@ -1024,13 +1025,28 @@ def test_bdy_control():
problem = LinearVariationalProblem(lhs(F), rhs(F), sol, bcs=bc)
solver = LinearVariationalSolver(problem)
solver.solve()

# Analytical solution of the analytical Laplace equation is:
# u(x) = a + (b - a)/2 * x
u_analytical = a + (b - a)/2 * X[0]
der_analytical0 = assemble(derivative((u_analytical**2) * dx, a))
der_analytical1 = assemble(derivative((u_analytical**2) * dx, b))
def u_analytical(x, a, b):
return a + (b - a)/2 * x
der_analytical0 = assemble(derivative(
(u_analytical(X[0], a, b)**2) * dx, a))
der_analytical1 = assemble(derivative(
(u_analytical(X[0], a, b)**2) * dx, b))
J = assemble(sol * sol * dx)
J_hat = ReducedFunctional(J, [Control(a), Control(b)])
adj_derivatives = J_hat.derivative(options={"riesz_representation": "l2"})
assert np.allclose(adj_derivatives[0].dat.data_ro, der_analytical0.dat.data_ro)
assert np.allclose(adj_derivatives[1].dat.data_ro, der_analytical1.dat.data_ro)
a = Function(R, val=1.5)
b = Function(R, val=2.5)
J_hat([a, b])
tape = get_working_tape()
# Check the checkpointed boundary conditions are not updating the
# user-defined boundary conditions ``bc_left`` and ``bc_right``.
assert isinstance(tape._blocks[0], DirichletBCBlock) and \
tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg
# tape._blocks[1] is the DirichletBC block for the right boundary
assert isinstance(tape._blocks[1], DirichletBCBlock) and \
tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg
Loading