Open
Labels: feature-request
Description
Hi, gradient-based losses (losses that depend on gradients, such as gradient penalties) are common in deep learning optimization. I would like to make sure NATTEN supports double backward, so that gradient-based losses do not fail silently.
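For context, here is a toy illustration of that silent failure mode in plain PyTorch (no NATTEN involved). OpaqueSquare is a hypothetical custom op whose backward returns a detached tensor. It stands in for a kernel that only implements first-order gradients. This is an assumption for illustration, not a claim about how NATTEN's backward behaves. With create_graph=True, the gradient-penalty term silently drops out and no error is raised.

```python
import torch


class OpaqueSquare(torch.autograd.Function):
    """Hypothetical op: y = x**2, with a backward that hides its math from autograd,
    mimicking a fused kernel that only supports first-order gradients."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # detach() stands in for an opaque kernel: the result is not recorded in
        # the graph, so create_graph=True cannot differentiate through it.
        return (2 * x * grad_out).detach()


def gradient_penalty_grad(square_fn):
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = square_fn(x).sum()
    # First-order gradient dy/dx, kept in the graph for a second backward pass.
    (g,) = torch.autograd.grad(y, x, create_graph=True)
    # Main loss plus a gradient-based penalty on the squared norm of dy/dx.
    loss = y + (g ** 2).sum()
    loss.backward()
    return x.grad
```

gradient_penalty_grad(lambda x: x * x) returns [10, 20, 30], the derivative of x² + (2x)² = 5x². gradient_penalty_grad(OpaqueSquare.apply) returns only [2, 4, 6], the gradient of the main loss alone, and raises no error. Decorating such a backward with torch.autograd.function.once_differentiable should turn this into an explicit error instead.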
However, I currently cannot use torch.autograd.gradcheck or torch.autograd.gradgradcheck to check whether NATTEN supports double backward. Those APIs expect double-precision (float64) inputs, while NATTEN only supports FP32/FP16/BF16.
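A possible workaround, sketched under assumptions: for the tiny shapes used in this issue's test (a 3x3 feature map with kernel_size=3), every query's neighborhood should cover the whole map. In that case na2d ought to reduce to ordinary self-attention over the 9 tokens. A pure-PyTorch reference for that special case can be checked in float64 with both gradcheck and gradgradcheck. reference_full_attention_2d is a hypothetical helper, and the heads-last layout (batch, height, width, heads, head_dim) is assumed to match na2d.

```python
import torch


def reference_full_attention_2d(q, k, v, scale=1.0):
    """Hypothetical stand-in for na2d when the kernel covers the entire feature map,
    so neighborhood attention equals plain self-attention.
    Assumed layout: (batch, height, width, heads, head_dim)."""
    b, h, w, heads, d = q.shape

    def to_tokens(t):
        # (B, H, W, heads, D) -> (B, heads, H*W, D)
        return t.reshape(b, h * w, heads, d).transpose(1, 2)

    qt, kt, vt = to_tokens(q), to_tokens(k), to_tokens(v)
    attn = torch.softmax(qt @ kt.transpose(-2, -1) * scale, dim=-1)
    out = attn @ vt  # (B, heads, H*W, D)
    # Back to (B, H, W, heads, D)
    return out.transpose(1, 2).reshape(b, h, w, heads, d)


torch.manual_seed(0)
# float64 on CPU, as gradcheck/gradgradcheck expect.
inputs = tuple(
    torch.rand(2, 3, 3, 1, 4, dtype=torch.float64, requires_grad=True)
    for _ in range(3)
)
fn = lambda q, k, v: reference_full_attention_2d(q, k, v, scale=1.0)

# Both calls raise on failure and return True on success.
first_order_ok = torch.autograd.gradcheck(fn, inputs)
second_order_ok = torch.autograd.gradgradcheck(fn, inputs)
```

NATTEN's float32 gradients could then be compared against this reference's gradients, for example with torch.testing.assert_close and tolerances suited to float32. This only covers the full-coverage case. Larger feature maps would need a reference that restricts each query to its own window.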
The simple test below fails at torch.autograd.gradcheck:
```python
import torch
import natten

# Shape (2, 3, 3, 1, 4): batch=2, a 3x3 feature map, 1 head, head_dim=4
q = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
k = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
v = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)

# 2D neighborhood attention with kernel_size=3
fn = lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0)

torch.autograd.gradcheck(fn, (q, k, v))      # fails with the error below
torch.autograd.gradgradcheck(fn, (q, k, v))  # not reached: gradcheck raises first
```

```
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-0.0596, -0.0745, -0.0298, ..., 0.0000, 0.0000, 0.0000],
        [ 0.0447, 0.0447, 0.0298, ..., 0.0000, 0.0000, 0.0000],
        [-0.0298, -0.0298, -0.0298, ..., 0.0000, 0.0000, 0.0000],
        ...,
        [ 0.0000, 0.0000, 0.0000, ..., -0.0596, 0.0000, 0.0000],
        [ 0.0000, 0.0000, 0.0000, ..., -0.0298, -0.0149, -0.0298],
        [ 0.0000, 0.0000, 0.0000, ..., 0.0596, 0.0000, 0.0894]],
       device='cuda:0')
analytical:tensor([[-0.0416, -0.0353, -0.0043, ..., 0.0000, 0.0000, 0.0000],
        [ 0.0080, 0.0035, 0.0059, ..., 0.0000, 0.0000, 0.0000],
        [-0.0239, -0.0478, -0.0411, ..., 0.0000, 0.0000, 0.0000],
        ...,
        [ 0.0000, 0.0000, 0.0000, ..., -0.0379, 0.0065, 0.0214],
        [ 0.0000, 0.0000, 0.0000, ..., 0.0069, 0.0138, 0.0097],
        [ 0.0000, 0.0000, 0.0000, ..., -0.0212, -0.0316, 0.0403]],
       device='cuda:0')
```
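Every visible entry of the numerical Jacobian above is a small integer multiple of about 0.0149. A back-of-the-envelope check suggests why. gradcheck's default central difference uses eps=1e-6. An output stored in float32 can only change in steps of one float32 spacing (ulp), so the estimated derivative can only change in steps of ulp / (2 * eps). The assumption here is that outputs lie in [0.25, 1), which is plausible because attention outputs are convex combinations of v drawn from torch.rand. This hedged reading suggests the mismatch may come largely from float32 rounding in the finite differences rather than necessarily from an incorrect backward. The helper names below are hypothetical.

```python
import math


def float32_ulp(x: float) -> float:
    """Spacing between adjacent float32 values near a nonzero, normal-range x."""
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 24)     # float32 has a 24-bit significand


def finite_difference_step(x: float, eps: float = 1e-6) -> float:
    """Smallest nonzero derivative a central difference (f(x+eps) - f(x-eps)) / (2*eps)
    can report when f's output near x is stored in float32."""
    return float32_ulp(x) / (2 * eps)
```

For outputs in [0.25, 0.5), finite_difference_step(0.3) is about 0.0149. For outputs in [0.5, 1), finite_difference_step(0.75) is about 0.0298. These match the quantized values in the numerical tensor, while the analytical values show no such pattern.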