Should something like the snippet below work for wrapping ResNet's last layer (the neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)
import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # run the wrapped module on a detached input; the gradient path
        # is re-established chunk by chunk in Function.backward
        y = self.module(x.detach())
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # recompute the forward pass in micro-batches and backprop each chunk
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()
                    x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                    torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)
    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)
    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)
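
For the ResNet case from the question, a minimal usage sketch might look like the following. It assumes the SequentialBackprop class above is in scope and uses a torchvision resnet50, treating layer4 as the neck; the choice of layer4, the micro-batch size of 8, and the input shape are placeholders for illustration, not anything prescribed by the gist.

import torch
from torchvision.models import resnet50

model = resnet50()
# wrap only the neck: its forward pass is recomputed in micro-batches of 8
# samples during backward, while the rest of the network is left untouched
model.layer4 = SequentialBackprop(model.layer4, batch_size = 8)

x = torch.rand(32, 3, 224, 224)
model(x).sum().backward()
# gradients should reach the wrapped module's parameters as usual
print(model.layer4.module[0].conv1.weight.grad.norm())

Since x.split(ctx.batch_size) splits along dim 0, the wrapper slices the NCHW activation batch along the sample dimension, so the same code should apply to convolutional blocks just as in the linear toy example above.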