
Commit 9de8d55

Ig-dolci and dham authored
Docs for checkpointing in adjoint simulations (#4094)
* Docs for checkpointing in adjoint simulations --------- Co-authored-by: David A. Ham <[email protected]>
1 parent 04e543e commit 9de8d55


3 files changed: +152 -21 lines changed

docs/source/checkpointing.rst

Lines changed: 125 additions & 12 deletions
@@ -230,24 +230,137 @@ with `idx` parameter always unset, and the same :class:`~.Function` can only be
 loaded using the same mode.


-Using disk checkpointing in adjoint simulations
+Using checkpointing in adjoint simulations
 ===============================================

 When adjoint annotation is active, the result of every Firedrake operation is
-stored in memory. For some simulations, this can result in a very large memory
-footprint. As an alternative, it is possible to specify that those intermediate
-results in forward evaluations of the tape which have type
-:class:`~firedrake.function.Function` be written to disk. This is usually the
-bulk of the data stored on the tape so this largely alleviates the memory
-problem, at the cost of the time taken to read to and write from disk.
+stored in memory. For some time-dependent simulations, this can lead to a
+large memory footprint. To alleviate this, we can use checkpointing strategies to store only some intermediate forward data in memory or on disk.

-Having imported `firedrake.adjoint`, there are two steps required to enable
-disk checkpointing of the forward tape state.
+Checkpointing for time-dependent adjoint simulations in Firedrake employs
+schedules, which determine how forward data is stored. These schedules are
+implemented in the `checkpoint_schedules package
+<https://www.firedrakeproject.org/checkpoint_schedules/>`_.

-1. Call :func:`~firedrake.adjoint_utils.checkpointing.enable_disk_checkpointing`.
-2. Wrap all mesh constructors in :func:`~firedrake.adjoint_utils.checkpointing.checkpointable_mesh`.

-See the documentation of those functions for more detail.
+To store every time step of the forward data required for adjoint-based gradient
+computation **in memory**, first import the schedule from the
+``checkpoint_schedules`` package, start adjoint annotation with ``continue_annotation()``,
+and get the working tape with ``get_working_tape()``:
+
+.. code-block:: python3
+
+    from firedrake import *
+    from firedrake.adjoint import *
+    from checkpoint_schedules import SingleMemoryStorageSchedule
+    continue_annotation()
+    tape = get_working_tape()
+
+Define the schedule:
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 4]
+    :end-before: [test_disk_checkpointing 5]
+
+and enable checkpointing:
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 6]
+    :end-before: [test_disk_checkpointing 7]
+
+**For any checkpointing approach, it is essential to call the time loop as
+follows when advancing the solver in time:**
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 10]
+    :end-before: [test_disk_checkpointing 11]
+
+
+``SingleMemoryStorageSchedule`` stores only the adjoint variables from the last adjoint
+time step, which corresponds to the zero forward time step due to the time-reversed nature
+of the adjoint solver.
+
+
+To store every time step of the forward data required for adjoint-based gradient
+computation **on disk**, write the necessary imports and start adjoint annotation:
+
+.. code-block:: python3
+
+    from firedrake import *
+    from firedrake.adjoint import *
+    from checkpoint_schedules import SingleDiskStorageSchedule
+
+    continue_annotation()
+    tape = get_working_tape()
+
+Then, enable disk checkpointing following the code below:
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_disk_checkpointing.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 1]
+    :end-before: [test_disk_checkpointing 2]
+
+For disk checkpointing, all mesh constructors must be wrapped using
+:func:`~.checkpointing.checkpointable_mesh`. For example:
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_disk_checkpointing.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 2]
+    :end-before: [test_disk_checkpointing 3]
+
+``SingleDiskStorageSchedule`` stores only the adjoint variables from the last adjoint
+time step (equivalent to the zero forward time step) in memory. This checkpointing strategy
+reduces memory usage at the cost of writing data to and reading it from disk.
+
+The ``checkpoint_schedules`` package provides other checkpointing
+strategies, such as Revolve, Mixed Schedule, and HRevolve. These methods
+store only a limited set of time steps, reducing memory usage at the cost of
+increased computational effort due to repeated forward calculations.
+
+For example, to use the **Revolve** schedule:
+
+.. code-block:: python3
+
+    from firedrake import *
+    from firedrake.adjoint import *
+    from checkpoint_schedules import Revolve
+
+    continue_annotation()
+    tape = get_working_tape()
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 8]
+    :end-before: [test_disk_checkpointing 9]
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 6]
+    :end-before: [test_disk_checkpointing 7]
+
+Then, advance the solver in time as follows:
+
+.. literalinclude:: ../../tests/firedrake/adjoint/test_burgers_newton.py
+    :language: python3
+    :dedent:
+    :start-after: [test_disk_checkpointing 10]
+    :end-before: [test_disk_checkpointing 11]
+
+``steps_to_store`` is the number of time steps stored in memory.
+
+For more details on available checkpointing strategies, refer to the
+`checkpoint_schedules package documentation
+<https://www.firedrakeproject.org/checkpoint_schedules/>`_.


 Checkpointing with DumbCheckpoint
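
The new documentation says that schedules such as Revolve store only a limited set of time steps and pay for it with repeated forward calculations. The sketch below illustrates that trade-off using only the standard library. It is not Revolve and does not use Firedrake or ``checkpoint_schedules``. It assumes a simple uniform-interval strategy, and ``forward_step``, ``run_forward`` and ``reverse_sweep`` are hypothetical names introduced for illustration.

```python
def forward_step(u, step):
    """Toy forward model: one step of a scalar recurrence (stands in for a PDE solve)."""
    return 0.5 * u + step


def run_forward(u0, total_steps, store_every):
    """Run the forward model, storing the state at the start of every
    `store_every`-th step."""
    checkpoints = {}
    u = u0
    for step in range(total_steps):
        if step % store_every == 0:
            checkpoints[step] = u
        u = forward_step(u, step)
    return u, checkpoints


def reverse_sweep(checkpoints, total_steps, store_every):
    """Visit the steps in reverse, as an adjoint solve would, rebuilding each
    step's input state from the nearest earlier checkpoint.

    Returns the input states in reverse order and the number of forward steps
    that had to be recomputed.
    """
    states = []
    recomputed = 0
    for step in reversed(range(total_steps)):
        base = (step // store_every) * store_every
        u = checkpoints[base]
        for s in range(base, step):
            u = forward_step(u, s)
            recomputed += 1
        states.append(u)
    return states, recomputed


total_steps = 9
_, all_states = run_forward(1.0, total_steps, store_every=1)
_, sparse = run_forward(1.0, total_steps, store_every=3)

states_full, extra_full = reverse_sweep(all_states, total_steps, 1)
states_sparse, extra_sparse = reverse_sweep(sparse, total_steps, 3)

print(len(all_states), extra_full)    # stored states, recomputed forward steps
print(len(sparse), extra_sparse)
```

With nine steps, storing every step keeps nine states and recomputes nothing. Storing every third step keeps three states and recomputes nine forward steps, yet the reverse sweep sees identical input states. Revolve chooses its checkpoints more cleverly than this fixed interval, so it should need fewer recomputations for the same storage.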

tests/firedrake/adjoint/test_burgers_newton.py

Lines changed: 22 additions & 9 deletions
@@ -99,7 +99,7 @@ def _check_reverse(tape):
                     assert out._checkpoint is None


-def J(ic, nu, solve_type, timestep, steps, V, nu_time_dependent=False):
+def J(ic, nu, solve_type, timestep, total_steps, V, nu_time_dependent=False):
     """Burgers equation solver."""
     u_ = Function(V, name="u_")
     u = Function(V, name="u")
@@ -115,9 +115,15 @@ def J(ic, nu, solve_type, timestep, steps, V, nu_time_dependent=False):

     tape = get_working_tape()
     J = 0.0
-    for j in tape.timestepper(range(steps)):
-        if nu_time_dependent and j > 4:
-            nu.assign(nu*(1.0 + j/1000))
+
+    # The comment below and the others like it are used to generate the
+    # documentation for the firedrake/docs/source/checkpointing.rst file.
+    # [test_disk_checkpointing 10]
+    for step in tape.timestepper(range(total_steps)):
+        # Advance the forward model
+        # [test_disk_checkpointing 11]
+        if nu_time_dependent and step > 4:
+            nu.assign(nu*(1.0 + step/1000))
         if solve_type == "NLVS":
             solver.solve()
         else:
@@ -135,26 +141,33 @@ def test_burgers_newton(solve_type, checkpointing, basics):
     """
     tape = get_working_tape()
     tape.progress_bar = ProgressBar
-    mesh, timestep, steps = basics
+    mesh, timestep, total_steps = basics
     if checkpointing:
+        steps_to_store = total_steps//3
         if checkpointing == "Revolve":
-            schedule = Revolve(steps, steps//3)
+            # [test_disk_checkpointing 8]
+            schedule = Revolve(total_steps, steps_to_store)
+            # [test_disk_checkpointing 9]
         if checkpointing == "SingleMemory":
+            # [test_disk_checkpointing 4]
             schedule = SingleMemoryStorageSchedule()
+            # [test_disk_checkpointing 5]
         if checkpointing == "Mixed":
             enable_disk_checkpointing()
-            schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.DISK)
+            schedule = MixedCheckpointSchedule(total_steps, steps_to_store, storage=StorageType.DISK)
         if checkpointing == "NoneAdjoint":
             schedule = NoneCheckpointSchedule()
+        # [test_disk_checkpointing 6]
         tape.enable_checkpointing(schedule)
+        # [test_disk_checkpointing 7]

     if checkpointing and schedule.uses_storage_type(StorageType.DISK):
         mesh = checkpointable_mesh(mesh)

     V, ic, nu = setup_test(mesh)
-    val = J(ic, nu, solve_type, timestep, steps, V)
+    val = J(ic, nu, solve_type, timestep, total_steps, V)
     if checkpointing:
-        assert len(tape.timesteps) == steps
+        assert len(tape.timesteps) == total_steps
         if checkpointing == "Revolve" or checkpointing == "Mixed":
             _check_forward(tape)
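
The test above drives the time loop through ``tape.timestepper(...)`` and then checks that ``len(tape.timesteps) == total_steps``. The following is a minimal sketch of why wrapping the loop iterable lets a tape learn where each time step ends. ``ToyTape`` is a hypothetical stand-in written for this illustration. It is not the pyadjoint ``Tape`` class, and the real implementation may differ.

```python
class ToyTape:
    """Hypothetical stand-in for an adjoint tape (not the pyadjoint class)."""

    def __init__(self):
        self.timesteps = []   # one list of recorded operations per time step
        self._current = []    # operations recorded in the step in progress

    def record(self, operation):
        """Record an operation in the current time step."""
        self._current.append(operation)

    def end_timestep(self):
        """Close the current time step and start a new, empty one."""
        self.timesteps.append(self._current)
        self._current = []

    def timestepper(self, iterable):
        """Yield items from `iterable`, closing a time step after each loop body."""
        for item in iterable:
            yield item
            # Execution resumes here once the loop body for `item` has finished.
            self.end_timestep()


tape = ToyTape()
total_steps = 4
for step in tape.timestepper(range(total_steps)):
    tape.record(f"solve {step}")
    tape.record(f"update {step}")

print(len(tape.timesteps))
```

Because the generator resumes only after the loop body has run, every step boundary is marked without any explicit call inside the loop. A schedule can only decide what to store per time step if those boundaries are known.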

tests/firedrake/adjoint/test_disk_checkpointing.py

Lines changed: 5 additions & 0 deletions
@@ -108,9 +108,14 @@ def test_disk_checkpointing_parallel():
     tape = get_working_tape()
     tape.clear_tape()
     continue_annotation()
+    # The comment below and the others like it are used to generate the
+    # documentation for the firedrake/docs/source/checkpointing.rst file.
+    # [test_disk_checkpointing 1]
     enable_disk_checkpointing()
     tape.enable_checkpointing(SingleDiskStorageSchedule())
+    # [test_disk_checkpointing 2]
     mesh = checkpointable_mesh(UnitSquareMesh(10, 10))
+    # [test_disk_checkpointing 3]
     J_disk, grad_J_disk = adjoint_example(mesh)

     assert disk_checkpointing() is False
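
The ``# [test_disk_checkpointing N]`` comments added in these tests are the markers that the ``literalinclude`` directives in ``checkpointing.rst`` select on with ``:start-after:`` and ``:end-before:``. The sketch below roughly mirrors that selection with the standard library, so the resulting snippets can be previewed without building the docs. ``extract_between`` and ``snippet`` are hypothetical helpers, and Sphinx's own implementation may differ in detail.

```python
import textwrap


def extract_between(lines, start_after, end_before):
    """Return the lines strictly between the first line containing
    `start_after` and the next line containing `end_before`."""
    selected = []
    started = False
    for line in lines:
        if not started:
            if start_after in line:
                started = True
            continue
        if end_before in line:
            break
        selected.append(line)
    return selected


# Excerpt of the test body shown in the diff above.
source = '''\
    continue_annotation()
    # [test_disk_checkpointing 1]
    enable_disk_checkpointing()
    tape.enable_checkpointing(SingleDiskStorageSchedule())
    # [test_disk_checkpointing 2]
    mesh = checkpointable_mesh(UnitSquareMesh(10, 10))
    # [test_disk_checkpointing 3]
'''.splitlines()


def snippet(n):
    """Extract the text between markers n and n + 1, then strip the common indent."""
    body = extract_between(
        source,
        f"[test_disk_checkpointing {n}]",
        f"[test_disk_checkpointing {n + 1}]",
    )
    return textwrap.dedent("\n".join(body))


enable_snippet = snippet(1)
mesh_snippet = snippet(2)
print(enable_snippet)
print(mesh_snippet)
```

The closing bracket in each marker matters: matching is by substring, and without it ``[test_disk_checkpointing 1`` would also match the marker numbered 10.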
