Sequential backprop impl sketch #72

@vadimkantorov

Description

Should something like the code below work for sequential (mini-batched) backprop through a wrapped module, e.g. ResNet's last layer (the neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)

import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    """Runs the wrapped module normally in forward, then recomputes it
    mini-batch by mini-batch in backward to bound peak activation memory."""

    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # no_grad here: the graph built by this call would never be walked,
        # since backward goes through Function, which recomputes the module
        with torch.no_grad():
            y = self.module(x)
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                # Function.backward runs with grad disabled, so re-enable it
                # to rebuild the graph for this mini-batch only
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()  # a leaf, so .grad is populated directly
                    y_mini = ctx.module(x_mini)
                # accumulates into ctx.module's parameter .grad as a side effect
                torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)

    # the neck's backward is recomputed in mini-batches of 16 samples
    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)

    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)
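
A minimal sanity check, not part of the original sketch: since the demo's layers are deterministic, the gradient produced through SequentialBackprop should match a plain full-batch backward on an identical copy of the neck. Something like the following, where neck_ref and model_ref are names introduced here for illustration:

import copy

neck_ref = copy.deepcopy(neck)   # same weights, independent .grad buffers
neck_ref.zero_grad()             # drop the grad that deepcopy carried over
model_ref = nn.Sequential(backbone, neck_ref, head)
model_ref(x).sum().backward()

# should print True (up to floating-point accumulation order)
print('match:', torch.allclose(neck.weight.grad, neck_ref.weight.grad))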
