
[FEATURE] Gradient checkpointing in forward_intermediates() #2435

@collinmccarthy

Description
Is your feature request related to a problem? Please describe.
I rely on the forward_intermediates() API for object detection models, and I'm experimenting with ViT-g and would like to try gradient checkpointing.

Describe the solution you'd like
In VisionTransformer.forward_features() we have:

```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)
```

I'm thinking something like this could work in VisionTransformer.forward_intermediates():

```python
for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_module(blk, x)
    else:
        x = blk(x)
```

I called this checkpoint_module(), but based on the code I think we could just use checkpoint_seq() directly on a single block. Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing much, so I'm not entirely sure.

I'm happy to submit a PR for a few models if it really is as simple as calling checkpoint_seq() in forward_intermediates() as outlined above. I'm not sure how many models implement this API and/or self.grad_checkpointing, or whether you want this supported in all of them.
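For reference, here is a minimal standalone sketch of the per-block checkpointing loop I have in mind, using plain `torch.utils.checkpoint.checkpoint` in place of the hypothetical `checkpoint_module()`. The `Block` stand-in, `take_indices`, and `forward_intermediates` function here are toy assumptions for illustration, not timm's actual implementation:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for ViT blocks; the real model would use timm's transformer blocks
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
take_indices = {1, 3}  # indices of blocks whose outputs we keep as intermediates


def forward_intermediates(x, grad_checkpointing=True):
    intermediates = []
    for i, blk in enumerate(blocks):
        if grad_checkpointing and not torch.jit.is_scripting():
            # Don't store this block's activations; recompute them during backward
            x = checkpoint(blk, x, use_reentrant=False)
        else:
            x = blk(x)
        if i in take_indices:
            intermediates.append(x)
    return x, intermediates


x = torch.randn(2, 8, requires_grad=True)
out, feats = forward_intermediates(x)
out.sum().backward()  # gradients flow through the checkpointed blocks
```

The key point is that collecting intermediates composes naturally with per-block checkpointing: each checkpointed block still returns a normal tensor, so appending it to `intermediates` works unchanged.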
