Description
When computing the adjoint, we want revolve-style checkpointing for the use cases where the whole simulation state can't be kept in memory, and recomputing is faster than saving to disk.
There are various algorithms for executing the revolving checkpoint. See, for example, https://gitlab.inria.fr/adjoint-computation/H-Revolve/-/tree/master .
The basic model these follow, however, is that the simulation is an ordered chain of operations, each of which depends only on the output of the previous one. We, by contrast, have a computation graph in which arbitrary forward dependencies are allowed. In normal usage, though, the graph encodes a sequence of timesteps, and most data is only used in the next timestep (or a bounded number of timesteps later).
In order to benefit from the previous work, we therefore need to transform our graph into a single chain of operations. The normal usage pattern would suggest that this should be possible and fairly efficient.
The first stage in this process will be to ask the user to tell us where the timesteps end, by making a function call. This will enable us to label every block variable with the number of the timestep in which it is created. Input block variables (i.e. those that first appear as a dependency rather than as an output) shouldn't have a timestep number (or should have an out-of-band value such as -1). We can also gather the blocks into timestep groups on the tape.
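To make the labelling concrete, here is a minimal sketch of how it might look. All of the names (`Tape`, `Block`, `BlockVariable`, `add_block`, `end_timestep`, `INPUT`) are hypothetical and chosen for illustration only; they are not claimed to match any existing API.

```python
from dataclasses import dataclass, field

INPUT = -1  # assumed out-of-band timestep for variables no block has created


@dataclass(eq=False)
class BlockVariable:
    name: str
    timestep: int = INPUT  # stays INPUT unless some block outputs this variable


@dataclass
class Block:
    dependencies: list
    outputs: list


@dataclass
class Tape:
    # One list of blocks per timestep; the last list is the open timestep.
    timesteps: list = field(default_factory=lambda: [[]])

    @property
    def current(self):
        return len(self.timesteps) - 1

    def add_block(self, block):
        # Outputs are created now, so they get the current timestep number.
        for out in block.outputs:
            out.timestep = self.current
        self.timesteps[-1].append(block)

    def end_timestep(self):
        # The user calls this to mark the end of a timestep.
        self.timesteps.append([])


tape = Tape()
u0, u1, u2 = BlockVariable("u0"), BlockVariable("u1"), BlockVariable("u2")
tape.add_block(Block([u0], [u1]))
tape.end_timestep()
tape.add_block(Block([u0, u1], [u2]))  # u0 is reused one timestep later
tape.end_timestep()
labels = {v.name: v.timestep for v in (u0, u1, u2)}
```

In this sketch `u0` keeps the out-of-band label because it only ever appears as a dependency, while `u1` and `u2` are labelled with the timesteps in which they were created.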
The next challenge is to determine what needs to be checkpointed. Because we can have arbitrary reuse of old values, this is not as simple as checkpointing all the variables in block
- If $n_{\textrm{last}} < n_c$, instruct $v$ to checkpoint itself. Checkpointing is idempotent, so this is safe even if $v$ was already checkpointed.
- For $n_{\textrm{last}} < m \le n_b$, add $v$ to the checkpointable state record of timestep $m$ (this ensures that if the tape is rerun with the checkpoints in different places, the correct data is checkpointed).
- Set $n_{\textrm{last}} = n_b$.
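The three steps above could be sketched as follows. The definitions of the symbols are not given in the steps themselves, so this sketch assumes that $n_{\textrm{last}}$ is the last timestep in which $v$ was created or used, $n_b$ is the timestep of the block currently using $v$, and $n_c$ is the timestep of the most recent checkpoint. The names `Variable`, `record_use` and `state_record` are hypothetical.

```python
from collections import defaultdict


class Variable:
    def __init__(self, name, created_in=-1):
        self.name = name
        self.last_used = created_in  # n_last (assumed meaning, see lead-in)
        self.checkpointed = False

    def checkpoint(self):
        # Idempotent: calling this repeatedly has no further effect.
        self.checkpointed = True


def record_use(v, n_b, n_c, state_record):
    """Apply the three steps for variable v used by a block in timestep n_b."""
    # Step 1: v was last touched before the checkpoint, so it must be saved.
    if v.last_used < n_c:
        v.checkpoint()
    # Step 2: v is live across every timestep m with n_last < m <= n_b.
    for m in range(v.last_used + 1, n_b + 1):
        state_record[m].add(v.name)
    # Step 3: remember this use.
    v.last_used = n_b


state_record = defaultdict(set)
v = Variable("v", created_in=0)
record_use(v, n_b=3, n_c=2, state_record=state_record)
```

Here `v` is created in timestep 0 and used again in timestep 3, so it is recorded as checkpointable state for timesteps 1, 2 and 3, and it is checkpointed because its previous use precedes the assumed checkpoint at timestep 2.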