
[RFC] Unify activation checkpointing APIs #2114

@ebsmothers

Where are we today?

We currently provide two different APIs for activation checkpointing.

  1. set_activation_checkpointing is the default for most of our recipes and its contract is similar to that of FSDP wrapping: the user either provides a set of nn.Module types to wrap or a Callable[[nn.Module, bool, int], bool] (see a more detailed description here). In practice we only use the first approach in our recipes, see here.

  2. apply_selective_activation_checkpointing was added as a prototype feature. It has only been integrated into two of our recipes (distributed full finetune and distributed QAT) and is only exposed in a single dev config. It is not currently tested.

Currently, neither of these APIs provides a superset of the other's functionality.
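
For reference, here is roughly how (1) is invoked in our recipes today (a sketch; the exact import paths and call site may differ slightly):

from torchtune import modules, training

# Sketch of the current recipe usage of set_activation_checkpointing:
# wrap every transformer layer in an activation checkpoint
training.set_activation_checkpointing(
    model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)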

What needs to change?

Having both of these APIs is redundant and potentially confusing (see e.g. here). We should consolidate behind a single, clean, well-tested API.

What are the requirements?

Imo our AC API should definitely support:

  • Wrapping a set of nn.Module types (i.e. the first case from (1))
  • Selective activation checkpointing (SAC) of every $k^{th}$ layer

However, I claim we do not need to support the second case from the set_activation_checkpointing API (I think the Callable contract is a bit confusing and it doesn't actually cover the most common SAC case of checkpointing every $k^{th}$ layer). Separately, there is op-level SAC as demonstrated in torchtitan here. Imo this is nice-to-have but not must-have as it requires some custom handling.
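
For concreteness, the second case from (1) mirrors FSDP's auto_wrap_policy contract, something like the sketch below (the function and argument names are illustrative; only the signature comes from the API). Note there is no natural way to express "checkpoint every $k^{th}$ layer" with it, since the callable never sees the module's name or index.

import torch.nn as nn
from torchtune.modules import TransformerSelfAttentionLayer

# Sketch of the Callable[[nn.Module, bool, int], bool] contract from (1)
def ac_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    if recurse:
        # Keep traversing into children
        return True
    # Leaf-level decision based purely on module type
    return isinstance(module, TransformerSelfAttentionLayer)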

Proposal

Assuming this, I propose we take a similar approach to our current shard_model utility.

def apply_activation_checkpointing(
    model: nn.Module,
    ac_conditions: List[Callable[[str, nn.Module], bool]],
) -> None:
    # Reverse order so that child modules are handled before their parents
    for n, m in reversed(list(model.named_modules())):
        if any(ac_condition(n, m) for ac_condition in ac_conditions):
            # apply AC wrapping
            ...

Then we can address the first case with e.g. ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer) and the second with ac_condition = lambda n, m: get_layer_num(n) % k == 0, where get_layer_num is a utility to infer the layer number from the fully qualified module name.
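
A minimal sketch of such a get_layer_num utility and its usage (the helper name, the regex, and k are illustrative, assuming layers live under a top-level layers ModuleList):

import re
from typing import Optional

def get_layer_num(module_name: str) -> Optional[int]:
    # e.g. "layers.3.attn" -> 3; returns None if the name has no layer index
    match = re.search(r"layers\.(\d+)", module_name)
    return int(match.group(1)) if match is not None else None

# Example: SAC of every k-th transformer layer. In practice the type and
# layer-number checks would likely be combined (as here) so that only whole
# layers, not their submodules, get wrapped.
ac_conditions = [
    lambda n, m: isinstance(m, TransformerSelfAttentionLayer)
    and get_layer_num(n) is not None
    and get_layer_num(n) % k == 0
]
apply_activation_checkpointing(model, ac_conditions)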

Potential drawbacks of this approach: (1) we may (?) need to do some setattr magic to handle e.g. this, and (2) string parsing to infer layer numbers may feel a bit hacky compared to what we currently do in apply_selective_activation_checkpointing. But imo this is worth it for the increased generality (e.g. that utility assumes we are applying it to a model that has a list of layers as a top-level attribute).
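
For illustration, the "apply AC wrapping" step together with the setattr handling from drawback (1) could look roughly like the sketch below (checkpoint_wrapper is the existing torch.distributed utility; the helper name is made up):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def _wrap_submodule(model: nn.Module, name: str) -> None:
    # Replace the named submodule with a checkpoint-wrapped version by
    # setting the attribute on its parent (the "setattr magic")
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, checkpoint_wrapper(getattr(parent, child_name)))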
