Description
**Is your feature request related to a problem? Please describe.**
I rely on the `forward_intermediates()` API for object detection models, and I'm experimenting with ViT-g and would like to try gradient checkpointing.
**Describe the solution you'd like**
In `VisionTransformer.forward_features()` we have:

```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)
```
I'm thinking something like this could work in `VisionTransformer.forward_intermediates()`:

```python
for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_module(blk, x)
    else:
        x = blk(x)
```
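To sanity-check the idea, here is a minimal, self-contained sketch of per-block checkpointing using `torch.utils.checkpoint.checkpoint` directly (rather than timm's `checkpoint_seq`). The `blocks`, `take_indices`, and `nn.Linear` stand-ins are hypothetical simplifications — the real `forward_intermediates()` also handles prefix tokens, norms, and output formats — but it shows that checkpointing each block individually still lets you collect intermediates between blocks, and that gradients flow through the recomputed segments:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a ViT block stack; real blocks are
# transformer layers, but any nn.Module works for this sketch.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
take_indices = {1, 3}  # blocks whose outputs we want as intermediates

inp = torch.randn(2, 8, requires_grad=True)
x = inp
intermediates = []
for i, blk in enumerate(blocks):
    # Checkpoint each block individually: activations inside blk are
    # discarded and recomputed during backward, but the block's output
    # is still available here for collection.
    x = checkpoint(blk, x, use_reentrant=False)
    if i in take_indices:
        intermediates.append(x)

x.sum().backward()  # gradients flow back through the recomputed blocks
```

If this is right, the main difference from `forward_features()` is just that checkpointing happens per block inside the loop instead of over the whole sequence at once.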
I called this `checkpoint_module()`, but based on the code I think we could just use `checkpoint_seq()` directly? Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing much, so I'm not entirely sure.

I'm happy to submit a PR for a few models if it's as simple as calling `checkpoint_seq()` in `forward_intermediates()` as outlined above. I'm not sure how many models use this API and/or `self.grad_checkpointing`, and whether you want this supported in all of them.