Where are we today?
We currently provide two different APIs for activation checkpointing.
- `set_activation_checkpointing` is the default for most of our recipes and its contract is similar to that of FSDP wrapping: the user either provides a set of `nn.Module` types to wrap or a `Callable[[nn.Module, bool, int], bool]` (see a more detailed description here). In practice we only use the first approach in our recipes, see here. (Both contracts are sketched after this list.)
- `apply_selective_activation_checkpointing` was added as a prototype feature. It has only been integrated into two of our recipes (distributed full finetune and distributed QAT) and is only exposed in a single dev config. It is not currently tested.
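For reference, a rough sketch of the two `set_activation_checkpointing` contracts mentioned above; the keyword name `auto_wrap_policy` and the exact policy signature are assumptions on my part and may not match the actual API exactly:

```python
import torch.nn as nn

# (a) Wrap every module whose type appears in a given set
# (note: `auto_wrap_policy` as the argument name is an assumption).
set_activation_checkpointing(model, auto_wrap_policy={TransformerSelfAttentionLayer})


# (b) FSDP-style policy, i.e. a Callable[[nn.Module, bool, int], bool]
def ac_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    if recurse:
        # Keep traversing into children while deciding what to wrap.
        return True
    return isinstance(module, TransformerSelfAttentionLayer)


set_activation_checkpointing(model, auto_wrap_policy=ac_policy)
```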
Currently neither of these APIs provides a superset of the other's functionality.
What needs to change?
Having both these APIs is redundant and potentially confusing. See e.g. here. We should consolidate behind a single, clean, well-tested API.
What are the requirements?
Imo our AC API should definitely support:
- Wrapping a set of `nn.Module` types (i.e. the first case of `set_activation_checkpointing` above)
- Selective activation checkpointing (SAC) of every $k^{th}$ layer
However, I claim we do not need to support the second case from the `set_activation_checkpointing` API (I think the Callable contract is a bit confusing, and it doesn't actually cover the most common SAC case of checkpointing every $k^{th}$ layer).
Proposal
Assuming this, I propose we take a similar approach to our current `shard_model` utility.
```python
from typing import Callable, List
import torch.nn as nn


def apply_activation_checkpointing(
    model: nn.Module,
    ac_conditions: List[Callable[[str, nn.Module], bool]],
) -> None:
    for n, m in reversed(list(model.named_modules())):
        if any(ac_condition(n, m) for ac_condition in ac_conditions):
            # apply AC wrapping
            ...
```
Then we can address the first case with e.g. `ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer)` and the second with `ac_condition = lambda n, m: get_layer_num(n) % k == 0`, where `get_layer_num` is a utility to infer the layer number from the full parameter name.
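For concreteness, a minimal sketch of what `get_layer_num` could look like, assuming layer indices appear in module names as `layers.<idx>` (both the helper and that naming assumption are hypothetical, not something that exists today):

```python
import re
from typing import Optional


def get_layer_num(module_name: str) -> Optional[int]:
    # Hypothetical helper: extract the layer index from a module name such as
    # "layers.3.attn" or "model.layers.12.mlp"; returns None if no index is found.
    match = re.search(r"layers\.(\d+)", module_name)
    return int(match.group(1)) if match else None


# Example usage with the proposed API: checkpoint every TransformerSelfAttentionLayer
# plus every k-th layer.
k = 2
ac_conditions = [
    lambda n, m: isinstance(m, TransformerSelfAttentionLayer),
    lambda n, m: (idx := get_layer_num(n)) is not None and idx % k == 0,
]
apply_activation_checkpointing(model, ac_conditions)
```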
Potential drawbacks of this approach: (1) we maybe (?) need to do some `setattr` magic to handle e.g. this, and (2) string parsing to infer layer numbers may feel a bit hacky compared to what we currently do in `apply_selective_activation_checkpointing`. But imo this is worth it for the increased generality (e.g. that utility assumes we are applying it to a model with a list of layers as a top-level attribute).
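On drawback (1): the "apply AC wrapping" step in the loop above would roughly amount to swapping a wrapped module back onto its parent. A rough sketch of that `setattr` step, using PyTorch's `checkpoint_wrapper` as one possible wrapper (the helper itself is hypothetical, not a proposal for the final implementation):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def _wrap_with_ac(model: nn.Module, name: str, module: nn.Module) -> None:
    # Replace the module found at dotted path `name` inside `model` with a
    # checkpoint-wrapped version by setting it back on its parent module.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, checkpoint_wrapper(module))
```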