Description
🐛 Bug
The RoIAlign operation and the DeformConv2d operation do not currently support second-order derivatives, so computing higher-order derivatives through them (a common need in meta-learning) raises an error.
A minimal snippet that reproduces the error:
import torch
import torch.nn as nn
import torch.autograd as autograd
from torchvision.ops import DeformConv2d

device = torch.device('cuda')
normal_conv = nn.Conv2d(3, 5, 3, 1, 1).to(device)
deform_conv = DeformConv2d(3, 5, 3, 1, 1).to(device)
# Create the tensors directly on the GPU so they remain leaf tensors
x = torch.rand(1, 3, 10, 10, device=device, requires_grad=True)
# Offset channels: 2 * offset_groups * kernel_h * kernel_w = 2 * 1 * 3 * 3
offset = torch.rand(1, 2 * 1 * 3 * 3, 10, 10, device=device, requires_grad=True)

# Second-order derivative through a regular convolution
out1 = normal_conv(x)
grads = autograd.grad(out1.sum(), x, create_graph=True)
loss = sum(g.sum() for g in grads)
loss.backward()  # this works fine
print('Succeeded in computing second-order derivatives for a regular conv')

# Second-order derivative through a deformable convolution
out2 = deform_conv(x, offset)
grads = autograd.grad(out2.sum(), [x, offset], create_graph=True)
loss = sum(g.sum() for g in grads)
loss.backward()  # an error is raised here
Expected behavior
Second-order derivatives should be computable through these ops as well, just as they are for nn.Conv2d.
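RoIAlign and DeformConv2d both read input features at fractional positions using bilinear interpolation, so the gradient with respect to the sampling positions (the offsets, for DeformConv2d) depends on the input values, and the mixed second derivative is generally nonzero. The sketch below uses a hypothetical pure-PyTorch helper, `bilinear_sample`, which is not torchvision's kernel; because it is built only from differentiable tensor ops, autograd provides second-order derivatives automatically. It samples `img[i, j] = i * j`, for which bilinear interpolation is exact, so the output equals `y * x`, its derivative with respect to `x` is `y`, and the mixed second derivative is 1.

```python
import torch


def bilinear_sample(img: torch.Tensor, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: bilinearly sample a 2-D tensor at fractional (y, x).

    Uses only differentiable tensor ops, so autograd can differentiate it
    more than once with respect to img, y and x.
    """
    h, w = img.shape
    # Integer corner indices (no gradient flows through floor / integer cast)
    y0 = torch.floor(y).long().clamp(0, h - 2)
    x0 = torch.floor(x).long().clamp(0, w - 2)
    # Fractional parts carry the gradient with respect to the coordinates
    dy = y - y0.to(y.dtype)
    dx = x - x0.to(x.dtype)
    v00, v01 = img[y0, x0], img[y0, x0 + 1]
    v10, v11 = img[y0 + 1, x0], img[y0 + 1, x0 + 1]
    return ((1 - dy) * (1 - dx) * v00 + (1 - dy) * dx * v01
            + dy * (1 - dx) * v10 + dy * dx * v11)


idx = torch.arange(4, dtype=torch.float64)
img = (idx[:, None] * idx[None, :]).requires_grad_()  # img[i, j] = i * j
y = torch.tensor(1.3, dtype=torch.float64, requires_grad=True)
x = torch.tensor(2.6, dtype=torch.float64, requires_grad=True)

out = bilinear_sample(img, y, x)  # equals y * x for this img
(d_dx,) = torch.autograd.grad(out, x, create_graph=True)  # d out / dx = y
(d2_dxdy,) = torch.autograd.grad(d_dx, y)  # mixed second derivative = 1
print(out.item(), d_dx.item(), d2_dxdy.item())
```

`torch.autograd.gradgradcheck(bilinear_sample, (img, y, x))` can be used to compare these analytic second-order gradients against finite differences; the same check is a natural test for any op that claims double-backward support.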
Environment
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.0
OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 430.50
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] numpy==1.17.2
[conda] blas 1.0 mkl
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] pytorch 1.4.0 py3.7_cuda10.0.130_cudnn7.6.3_0 pytorch
[conda] torchvision 0.5.0 py37_cu100 pytorch
Additional context
I would be happy to help if someone can point out how to modify the code to support second-order derivatives.
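One general way for a custom autograd op to support second-order derivatives is to write its `backward` with ordinary differentiable tensor operations, so that autograd records them when `create_graph=True`. A backward that runs a raw kernel, or one decorated with `torch.autograd.function.once_differentiable`, returns gradients with no graph behind them, and differentiating them again fails. The toy `Cube` and `CubeOnce` functions and the `second_derivative` helper below are hypothetical illustrations of this pattern, not torchvision code; whether the torchvision ops fail for exactly this reason is an assumption here.

```python
import torch
from torch.autograd.function import once_differentiable


class Cube(torch.autograd.Function):
    """Toy op y = x**3 whose backward is built from differentiable ops."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Plain tensor ops: autograd records them when create_graph=True,
        # so this backward can itself be differentiated.
        return grad_out * 3 * x ** 2


class CubeOnce(torch.autograd.Function):
    """Same op, but the backward is marked as differentiable only once."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2


def second_derivative(fn, x):
    """Gradient of the summed first gradient, mirroring the repro above."""
    (g,) = torch.autograd.grad(fn(x).sum(), x, create_graph=True)
    (gg,) = torch.autograd.grad(g.sum(), x)
    return gg


x = torch.tensor([2.0], dtype=torch.float64, requires_grad=True)
print(second_derivative(Cube.apply, x))  # d^2(x^3)/dx^2 = 6x = 12 at x = 2
try:
    second_derivative(CubeOnce.apply, x)
except RuntimeError as err:
    print('once_differentiable backward cannot be differentiated again:', err)
```

For RoIAlign and DeformConv2d, the same idea would mean either expressing their backward passes with differentiable ops or giving the backward computation its own backward.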