
Missing second order derivatives for operations like RoIAlign and DeformConv #1982

Open
@xieshuqin

Description


🐛 Bug

The RoIAlign and DeformConv2d operations do not currently support second-order derivatives. As a result, computing higher-order derivatives through them (a common case in meta-learning) raises an error.
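For context on why meta-learning runs into this: a MAML-style update differentiates through an inner gradient step, so the outer backward pass needs the double backward of every op in the model. The minimal sketch below uses a plain `F.conv2d` model, where this works. The names `loss_fn`, `inner_lr`, `x_support` and `x_query` are hypothetical. Replacing the conv with a DeformConv2d layer would make the outer `backward()` depend on DeformConv2d's double backward.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical one-layer model: a single 3x3 conv weight.
weight = torch.randn(5, 3, 3, 3, requires_grad=True)
x_support = torch.rand(2, 3, 10, 10)  # hypothetical inner-loop batch
x_query = torch.rand(2, 3, 10, 10)    # hypothetical outer-loop batch
inner_lr = 0.1

def loss_fn(w, x):
    # Squared output, so the gradient w.r.t. w still depends on w.
    return F.conv2d(x, w, padding=1).pow(2).mean()

# Inner step: create_graph=True keeps the gradient itself differentiable.
(g,) = torch.autograd.grad(loss_fn(weight, x_support), weight, create_graph=True)
adapted = weight - inner_lr * g

# Outer step: backpropagating through `adapted` differentiates `g` again,
# which requires the double backward of every op used inside loss_fn.
outer_loss = loss_fn(adapted, x_query)
outer_loss.backward()
print(weight.grad.shape)  # torch.Size([5, 3, 3, 3])
```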

A snippet that reproduces the error with DeformConv2d:

import torch
import torch.nn as nn
import torch.autograd as autograd
from torchvision.ops import DeformConv2d

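device = torch.device('cuda')

normal_conv = nn.Conv2d(3, 5, 3, 1, 1).to(device)
deform_conv = DeformConv2d(3, 5, 3, 1, 1).to(device)

# create the tensors directly on the GPU so they stay leaf tensors
input = torch.rand(1, 3, 10, 10, device=device, requires_grad=True)
# offset shape: (N, 2 * offset_groups * kH * kW, H_out, W_out)
offset = torch.rand(1, 2 * 1 * 3 * 3, 10, 10, device=device, requires_grad=True)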

# verify gradient of gradient for normal conv
out1 = normal_conv(input)
grad = autograd.grad(out1.sum(), input, create_graph=True)
loss = sum(g.sum() for g in grad)
loss.backward() # this works fine
print('Succeeded in computing second-order derivatives for normal convs')

# deform conv
out2 = deform_conv(input, offset)
grad = autograd.grad(out2.sum(), [input, offset], create_graph=True)
loss = sum(g.sum() for g in grad)
loss.backward()  # raises an error: DeformConv2d has no double backward

Expected behavior

These ops should support second-order derivatives, just as nn.Conv2d does.

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.0

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 430.50
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.17.2
[conda] blas 1.0 mkl
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] pytorch 1.4.0 py3.7_cuda10.0.130_cudnn7.6.3_0 pytorch
[conda] torchvision 0.5.0 py37_cu100 pytorch

Additional context

I would be happy to help if someone can point out how to modify the code to support second-order derivatives.
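One general route, sketched here under the assumption that the op is (or can be wrapped as) a `torch.autograd.Function`: write `backward` using operations that autograd can itself differentiate, instead of an opaque kernel or a `@once_differentiable` backward. The toy `Cube` op below is hypothetical and stands in for a real kernel. `gradgradcheck` compares the analytical second-order gradients against finite differences.

```python
import torch
from torch.autograd import Function, gradcheck, gradgradcheck

class Cube(Function):
    """Hypothetical toy op y = x**3, standing in for a custom kernel."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Built from ordinary differentiable torch ops, so autograd can
        # differentiate this backward again when create_graph=True.
        return grad_out * 3 * x ** 2

# gradcheck/gradgradcheck expect double-precision inputs
x = torch.randn(4, dtype=torch.double, requires_grad=True)
print(gradcheck(Cube.apply, (x,)))      # first-order check
print(gradgradcheck(Cube.apply, (x,)))  # second-order check
```

For a compiled kernel such as RoIAlign's, the analogous step would be to wrap the backward kernel in its own `Function` and give that a backward. Because RoIAlign is linear in its input, the backward of its backward with respect to the incoming gradient is the forward RoIAlign applied to the new gradient, so the existing forward kernel could be reused. DeformConv2d depends bilinearly on its offsets and would need additional second-order terms.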
