
[RFC] Unify activation checkpointing APIs #2114

Open
@ebsmothers

Description

Where are we today?

We currently provide two different APIs for activation checkpointing.

  1. set_activation_checkpointing is the default for most of our recipes and its contract is similar to that of FSDP wrapping: the user either provides a set of nn.Module types to wrap or a Callable[[nn.Module, bool, int], bool] (see a more detailed description here). In practice we only use the first approach in our recipes, see here.

  2. apply_selective_activation_checkpointing was added as a prototype feature. It has only been integrated into two of our recipes (distributed full finetune and distributed QAT) and is only exposed in a single dev config. It is not currently tested.

Currently neither of these APIs provides a superset of functionality of the other.
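For concreteness, the set-of-types usage from (1) looks roughly like the following in today's recipes (module and function paths sketched from memory, so treat this as illustrative rather than exact):

from torchtune import modules, training

# Wrap every TransformerSelfAttentionLayer in the model with activation checkpointing
training.set_activation_checkpointing(
    model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)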

What needs to change?

Maintaining both of these APIs is redundant and potentially confusing (see e.g. here). We should consolidate behind a single, clean, well-tested API.

What are the requirements?

Imo our AC API should definitely support:

  • Wrapping a set of nn.Module types (i.e. the first case from (1))
  • Selective activation checkpointing (SAC) of every $k^{th}$ layer

However, I claim we do not need to support the second case from the set_activation_checkpointing API (I think the Callable contract is a bit confusing and it doesn't actually cover the most common SAC case of checkpointing every $k^{th}$ layer). Separately, there is op-level SAC as demonstrated in torchtitan here. Imo this is nice-to-have but not must-have as it requires some custom handling.
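For reference, op-level SAC as demonstrated in torchtitan boils down to passing a per-op policy into torch.utils.checkpoint. A minimal sketch of that pattern (API names as of recent PyTorch; this is the custom handling referred to above, not a proposed torchtune API):

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Example policy: always save matmul outputs, recompute everything else
def _save_mm_policy(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def run_layer_with_op_sac(layer, x):
    # Op-level SAC: the policy decides per-op what to save vs. recompute
    return checkpoint(
        layer,
        x,
        use_reentrant=False,
        context_fn=partial(create_selective_checkpoint_contexts, _save_mm_policy),
    )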

Proposal

Assuming these requirements, I propose we take an approach similar to our current shard_model utility:

from typing import Callable, List

from torch import nn

def apply_activation_checkpointing(
    model: nn.Module,
    ac_conditions: List[Callable[[str, nn.Module], bool]],
) -> None:
    for n, m in reversed(list(model.named_modules())):
        if any(ac_condition(n, m) for ac_condition in ac_conditions):
            # apply AC wrapping (wrap `m` and replace it on its parent module)
            ...

Then we can address the first case with e.g. ac_condition = lambda n, m: isinstance(m, TransformerSelfAttentionLayer) and the second with ac_condition = lambda n, m: get_layer_num(n) % k == 0, where get_layer_num is a utility to infer the layer number from the full module name.
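Putting that together, usage could look like the following. get_layer_num is hypothetical and sketched here with a simple regex over the module name; the exact naming scheme depends on the model:

import re

from torchtune.modules import TransformerSelfAttentionLayer

# Hypothetical helper: extract the layer index from a module name like "layers.3.attn";
# returns -1 when the name contains no layer index
def get_layer_num(module_name: str) -> int:
    match = re.search(r"layers\.(\d+)", module_name)
    return int(match.group(1)) if match is not None else -1

# Case 1: wrap every TransformerSelfAttentionLayer
apply_activation_checkpointing(
    model,
    ac_conditions=[lambda n, m: isinstance(m, TransformerSelfAttentionLayer)],
)

# Case 2 (as an alternative): SAC of every 2nd layer; in practice we would likely
# also restrict this to layer modules rather than all of their submodules
apply_activation_checkpointing(
    model,
    ac_conditions=[lambda n, m: get_layer_num(n) % 2 == 0],
)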

Potential drawbacks of this approach: (1) we may need some setattr magic to handle cases like this, and (2) parsing strings to infer layer numbers may feel a bit hacky compared to what we currently do in apply_selective_activation_checkpointing. But imo this is worth it for the increased generality (e.g. that utility assumes the model has a list of layers as a top-level attribute).
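On drawback (1), a minimal sketch of the setattr-based wrapping, assuming we reuse PyTorch's checkpoint_wrapper (the root module, whose name is the empty string, would need special-casing):

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def _wrap_in_place(model: nn.Module, name: str, module: nn.Module) -> None:
    # Wrap `module` with activation checkpointing and replace it on its parent module
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" resolves to `model` itself
    setattr(parent, child_name, checkpoint_wrapper(module))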
