@@ -230,24 +230,137 @@ with `idx` parameter always unset, and the same :class:`~.Function` can only be
230
230
loaded using the same mode.
231
231
232
232
233
- Using disk checkpointing in adjoint simulations
233
+ Using checkpointing in adjoint simulations
234
234
===============================================
235
235
236
236
When adjoint annotation is active, the result of every Firedrake operation is
237
- stored in memory. For some simulations, this can result in a very large memory
238
- footprint. As an alternative, it is possible to specify that those intermediate
239
- results in forward evaluations of the tape which have type
240
- :class: `~firedrake.function.Function ` be written to disk. This is usually the
241
- bulk of the data stored on the tape so this largely alleviates the memory
242
- problem, at the cost of the time taken to read to and write from disk.
237
+ stored in memory. For some time-dependent simulations, this can lead to a
238
+ large memory footprint. To alleviate this, we can use checkpointing strategies to store only some intermediate forward data in memory or on disk.
243
239
244
- Having imported `firedrake.adjoint `, there are two steps required to enable
245
- disk checkpointing of the forward tape state.
240
+ Checkpointing for time-dependent adjoint simulations in Firedrake employs
241
+ schedules, which determine how forward data is stored. These schedules are
242
+ implemented in the `checkpoint_schedules package
243
+ <https://www.firedrakeproject.org/checkpoint_schedules/> `_.
246
244
247
- 1. Call :func: `~firedrake.adjoint_utils.checkpointing.enable_disk_checkpointing `.
248
- 2. Wrap all mesh constructors in :func: `~firedrake.adjoint_utils.checkpointing.checkpointable_mesh `.
249
245
250
- See the documentation of those functions for more detail.
246
+ To store every time step of the forward data required for adjoint-based gradient
247
+ computation **in memory **, first import the schedule from the
248
+ ``checkpoint_schedules `` package, start adjoint annotation with ``continue_annotation() ``,
249
+ get the working tape with ``get_working_tape() ``:
250
+
251
+ .. code-block :: python3
252
+
253
+ from firedrake import *
254
+ from firedrake.adjoint import *
255
+ from checkpoint_schedules import SingleMemoryStorageSchedule
256
+ continue_annotation()
257
+ tape = get_working_tape()
258
+
259
+ Define the schedule:
260
+
261
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
262
+ :language: python3
263
+ :dedent:
264
+ :start-after: [test_disk_checkpointing 4]
265
+ :end-before: [test_disk_checkpointing 5]
266
+
267
+ and enable checkpointing:
268
+
269
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
270
+ :language: python3
271
+ :dedent:
272
+ :start-after: [test_disk_checkpointing 6]
273
+ :end-before: [test_disk_checkpointing 7]
274
+
275
+ **For any checkpointing approach, it is essential to call the time loop as
276
+ follows when advancing the solver in time: **
277
+
278
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
279
+ :language: python3
280
+ :dedent:
281
+ :start-after: [test_disk_checkpointing 10]
282
+ :end-before: [test_disk_checkpointing 11]
283
+
284
+
285
+ ``SingleMemoryStorageSchedule `` stores only the adjoint variables from the last adjoint
286
+ time step, which corresponds to the zero forward time step due to the time-reversed nature
287
+ of the adjoint solver.
288
+
289
+
290
+ To store every time step of the forward data required for adjoint-based gradient
291
+ computation **on disk **, write the necessary imports and start adjoint annotation:
292
+
293
+ .. code-block :: python3
294
+
295
+ from firedrake import *
296
+ from firedrake.adjoint import *
297
+ from checkpoint_schedules import SingleDiskStorageSchedule
298
+
299
+ continue_annotation()
300
+ tape = get_working_tape()
301
+
302
+ Then, enable disk checkpointing following the code below:
303
+
304
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_disk_checkpointing.py
305
+ :language: python3
306
+ :dedent:
307
+ :start-after: [test_disk_checkpointing 1]
308
+ :end-before: [test_disk_checkpointing 2]
309
+
310
+ For disk checkpointing, all mesh constructors must be wrapped using
311
+ :func: `~.checkpointing.checkpointable_mesh `. For example:
312
+
313
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_disk_checkpointing.py
314
+ :language: python3
315
+ :dedent:
316
+ :start-after: [test_disk_checkpointing 2]
317
+ :end-before: [test_disk_checkpointing 3]
318
+
319
+ ``SingleDiskStorageSchedule `` stores only the adjoint variables from the last adjoint
320
+ time step (equivalent to the zero forward time step) in memory. This checkpointing strategy
321
+ reduces memory usage at the cost of reading and writing data from disk.
322
+
323
+ The ``checkpoint_schedules `` package provides other checkpointing
324
+ strategies, such as Revolve, Mixed Schedule, and HRevolve. These methods
325
+ store only a limited set of time steps, reducing memory usage at the cost of
326
+ increased computational effort due to repeated forward calculations.
327
+
328
+ For example, to use the **Revolve ** schedule:
329
+
330
+ .. code-block :: python3
331
+
332
+ from firedrake import *
333
+ from firedrake.adjoint import *
334
+ from checkpoint_schedules import Revolve
335
+
336
+ continue_annotation()
337
+ tape = get_working_tape()
338
+
339
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
340
+ :language: python3
341
+ :dedent:
342
+ :start-after: [test_disk_checkpointing 8]
343
+ :end-before: [test_disk_checkpointing 9]
344
+
345
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
346
+ :language: python3
347
+ :dedent:
348
+ :start-after: [test_disk_checkpointing 6]
349
+ :end-before: [test_disk_checkpointing 7]
350
+
351
+ Then, advance the solver in time as follows:
352
+
353
+ .. literalinclude :: ../../tests/firedrake/adjoint/test_burgers_newton.py
354
+ :language: python3
355
+ :dedent:
356
+ :start-after: [test_disk_checkpointing 10]
357
+ :end-before: [test_disk_checkpointing 11]
358
+
359
+ ``steps_to_store `` is the number of time steps stored in memory.
360
+
361
+ For more details on available checkpointing strategies, refer to the
362
+ `checkpoint_schedules package documentation
363
+ <https://www.firedrakeproject.org/checkpoint_schedules/> `_.
251
364
252
365
253
366
Checkpointing with DumbCheckpoint
0 commit comments