Does NATTEN support double backward? #191

@Luciennnnnnn

Hi. Gradient-based losses are common in deep learning optimization, so I would like to confirm that NATTEN supports double backward and that such losses do not fail silently.
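For context, a minimal sketch of what such a check could look like without gradcheck. It builds a gradient penalty with `create_graph=True` and then differentiates it a second time. `reference_attention` is a plain softmax attention used only as a stand-in so the snippet runs on CPU without NATTEN, and `probe_double_backward` is a hypothetical helper name, not part of NATTEN or PyTorch.

```python
import torch


def reference_attention(q, k, v):
    # Plain softmax attention: a stand-in for natten.functional.na2d (assumption).
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
    return attn @ v


def probe_double_backward(fn, *inputs):
    """Return True if a gradient penalty on fn yields non-zero second-order gradients."""
    out = fn(*inputs)
    # First backward, keeping the graph so the gradients can be differentiated again.
    grads = torch.autograd.grad(out.sum(), inputs, create_graph=True)
    # Gradient-based loss: squared norm of the first-order gradients.
    penalty = sum(g.pow(2).sum() for g in grads)
    if not penalty.requires_grad:
        # The first-order gradients are detached from the graph.
        return False
    try:
        second = torch.autograd.grad(penalty, inputs, allow_unused=True)
    except RuntimeError:
        # The backward pass refused to be differentiated.
        return False
    return any(s is not None and bool(s.abs().sum() > 0) for s in second)


torch.manual_seed(0)
q, k, v = (torch.rand(1, 4, 3, requires_grad=True) for _ in range(3))
print(probe_double_backward(reference_attention, q, k, v))
```

Passing `lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0)` with CUDA tensors instead of the stand-in would probe NATTEN itself. A `False` result only says that no second-order signal reached the inputs; it does not say why, and a function with no second-order dependence on its inputs would also report `False`.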

However, I currently cannot use torch.autograd.gradcheck or torch.autograd.gradgradcheck to check whether NATTEN supports double backward, because those APIs expect double-precision inputs while NATTEN only supports FP32/FP16/BF16.

The simple test below already fails at torch.autograd.gradcheck:

import torch
import natten

q = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
k = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
v = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)

torch.autograd.gradcheck(lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0), (q, k, v))  # fails
torch.autograd.gradgradcheck(lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0), (q, k, v))

torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-0.0596, -0.0745, -0.0298,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0447,  0.0447,  0.0298,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0298, -0.0298, -0.0298,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0596,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0298, -0.0149, -0.0298],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0596,  0.0000,  0.0894]],
       device='cuda:0')
analytical:tensor([[-0.0416, -0.0353, -0.0043,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0080,  0.0035,  0.0059,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0239, -0.0478, -0.0411,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0379,  0.0065,  0.0214],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0069,  0.0138,  0.0097],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0212, -0.0316,  0.0403]],
       device='cuda:0')
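The numerical entries in this report are all multiples of roughly 0.0149, which is what float32 rounding would produce with gradcheck's default eps=1e-6, so the mismatch may come from the finite-difference step rather than from NATTEN's backward. One possible workaround, sketched below, is to run both checks in float32 with a larger eps and looser tolerances. The values eps=1e-3, atol=1e-2 and rtol=1e-2 are assumptions chosen for illustration, `check_fp32` is a hypothetical helper name, and `reference_attention` is again a plain softmax attention standing in for `natten.functional.na2d`.

```python
import warnings

import torch


def reference_attention(q, k, v):
    # Plain softmax attention: a stand-in for natten.functional.na2d (assumption).
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
    return attn @ v


def check_fp32(fn, inputs, eps=1e-3, atol=1e-2, rtol=1e-2):
    """Run gradcheck and gradgradcheck with float32-friendly settings (assumed values)."""
    with warnings.catch_warnings():
        # gradcheck warns when inputs are not double precision; silence it here.
        warnings.simplefilter("ignore")
        first = torch.autograd.gradcheck(
            fn, inputs, eps=eps, atol=atol, rtol=rtol, raise_exception=False
        )
        second = torch.autograd.gradgradcheck(
            fn, inputs, eps=eps, atol=atol, rtol=rtol, raise_exception=False
        )
    return first, second


torch.manual_seed(0)
q, k, v = (
    torch.rand(1, 4, 3, dtype=torch.float32, requires_grad=True) for _ in range(3)
)
print(check_fp32(reference_attention, (q, k, v)))
```

With `raise_exception=False` both calls return a boolean instead of raising, which makes it easier to see whether the first-order check, the second-order check, or both fail. Tolerances this loose only catch gross errors, so a pass here is weaker evidence than a double-precision gradcheck would be.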
