
Per-sample gradient: getting all-zero gradients when using grad(params_tograd, params) with respect to part of a model's parameters #1122

Open
@Ancientshi

Description

Hi PyTorch team, I recently need to compute per-sample gradients with respect to part of a model's parameters. For the toy example below this works, but for a Wide & Deep model the same approach returns all-zero gradients, and I don't know why.

Here is the toy example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional_with_buffers, vmap, grad

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # nll_loss expects log-probabilities, so return the log_softmax output
        output = F.log_softmax(x, dim=1)
        return output
    

device = 'cuda'
batch_size = 64

data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (batch_size,), device=device)

model = SimpleCNN().to(device=device)
model = model.eval()
fmodel, params, buffers = make_functional_with_buffers(model)

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    # Overwrite the selected entries of the full parameter list with the
    # (differentiated) params_tograd values, keyed by parameter index.
    for key, value in params_tograd.items():
        params[key] = value
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss

# Differentiate w.r.t. the first argument only (params_tograd).
ft_compute_grad = grad(compute_loss_stateless_model)

# Vectorize over samples/targets; params_tograd, params, and buffers are shared.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))

# Take per-sample gradients only w.r.t. a subset of the parameters
# (here the parameters at indices -2 and 0).
params_tograd = {}
for i in [-2, 0]:
    params_tograd[i] = params[i]
ft_per_sample_grads = ft_compute_sample_grad(params_tograd, list(params), buffers, data, targets)
print(ft_per_sample_grads)

The result is:

(screenshot: the per-sample gradients for the selected parameters are nonzero, as expected)
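For reference, here is a minimal sketch of an equivalent formulation that passes the trainable subset as its own argument and rebuilds the full parameter list inside the closure, instead of mutating the shared list in place. The name trainable_idx and the choice of indices are illustrative, not part of my actual code:

import torch
import torch.nn.functional as F
from functorch import make_functional_with_buffers, vmap, grad

# Assumes `model`, `data`, `targets`, and `device` from the toy example above.
fmodel, params, buffers = make_functional_with_buffers(model)

trainable_idx = [0, len(params) - 2]   # illustrative subset of parameter indices
frozen = [p.detach() for p in params]  # detached copies of all parameters

def loss_for_subset(subset, buffers, sample, target):
    # Rebuild the full parameter list: trainable entries come from `subset`,
    # everything else from the detached frozen copies, so only `subset`
    # participates in differentiation.
    full = list(frozen)
    for i, p in zip(trainable_idx, subset):
        full[i] = p
    predictions = fmodel(full, buffers, sample.unsqueeze(0))
    return F.nll_loss(predictions, target.unsqueeze(0))

subset = tuple(params[i] for i in trainable_idx)
per_sample_grads = vmap(grad(loss_for_subset), in_dims=(None, None, 0, 0))(
    subset, buffers, data, targets
)
print([g.shape for g in per_sample_grads])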

However, when I apply this method to the real scenario, it doesn't work: all the returned gradients are 0.

import sys
from tqdm import tqdm

# Body of a method in my training class; `self.device`, `dataset`, `w_tao`,
# `train_data_loader`, and `grad_mean` come from the surrounding context.
model.load_state_dict(w_tao)
fmodel, params, buffers = make_functional_with_buffers(model)

def loss_fn(predictions, targets):
    return F.mse_loss(predictions, targets)

def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    # Overwrite the selected entries of the full parameter list with the
    # (differentiated) params_tograd values, keyed by parameter index.
    for key, value in params_tograd.items():
        params[key] = value
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))

# Pick the subset of parameters to differentiate, depending on the dataset.
params_tograd = {}
if dataset == 'lastfm-1k':
    params_tograd[-1] = params[-1]
    params_tograd[4] = params[4]
else:
    params_tograd[0] = params[0]

prod_all = []
for batch_idx, (inputs, targets) in tqdm(enumerate(train_data_loader)):
    inputs, targets = inputs.to(self.device).float(), targets.to(self.device).float()

    ft_per_sample_grads = ft_compute_sample_grad(params_tograd, list(params), buffers, inputs, targets)
    print(ft_per_sample_grads)  # debugging: this prints all-zero gradients
    sys.exit()
    if dataset != 'lastfm-1k':
        params_grads = ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0], -1)
    elif dataset == 'lastfm-1k':
        params_grad_dnn = ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0], -1)
        params_grad_linear = ft_per_sample_grads[4].reshape(ft_per_sample_grads[4].shape[0], -1)
        params_grads = torch.cat([params_grad_dnn, params_grad_linear], -1)

    # Per-sample dot product of the flattened gradients with the mean gradient.
    prod = torch.mm(params_grads, grad_mean.unsqueeze(1)).squeeze().detach().to('cpu').numpy()
    prod_all.extend(prod)
return dict(zip(range(1, len(prod_all) + 1), prod_all))

(screenshots: the printed per-sample gradients are all zeros)
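As a sanity check (a minimal debugging sketch of my own, not part of the pipeline above), one can verify in plain eager mode which parameters actually receive a gradient for this loss; if the selected indices also come back zero here, the problem is in the model or the chosen indices rather than in vmap/grad:

import torch

# Eager-mode sanity check: run one batch through the original (stateful)
# model and inspect per-parameter gradient norms. Assumes `model`, `inputs`,
# `targets`, and `loss_fn` from the snippet above.
model.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
for i, (name, p) in enumerate(model.named_parameters()):
    grad_norm = p.grad.norm().item() if p.grad is not None else float('nan')
    print(i, name, grad_norm)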

Also, in the Wide & Deep model the shape of linear_logit should be (batch_size, 1), but when I apply this method an error happens here: the shapes of linear_logit and sparse_feat_logit do not match. I attached the printed output below. (I suppose X.shape[0] = 0 when using this method, but why?)

        # Inside the Wide & Deep linear model's forward:
        linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
        if len(sparse_embedding_list) > 0:
            # sparse_embedding_cat: torch.Size([1000, 1, 7])
            sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
            if sparse_feat_refine_weight is not None:
                # w_{x,i} = m_{x,i} * w_i (in IFM and DIFM)
                sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)

            sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
            try:
                linear_logit += sparse_feat_logit
            except Exception:
                # Debugging instrumentation: print the mismatched shapes and
                # values, then bail out.
                print(linear_logit.shape)
                print(sparse_feat_logit.shape)
                print('linear_logit\n', linear_logit)
                print('sparse_feat_logit\n', sparse_feat_logit)
                sys.exit()

(screenshot: the printed shapes and tensor values at the failure point)
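If I had to guess (this is my assumption, not confirmed): under vmap, linear_logit is created as an ordinary non-batched tensor by torch.zeros([X.shape[0], 1]), and the in-place += with the vmap-batched sparse_feat_logit is then rejected, which is what sends execution into the except branch. An out-of-place add would sidestep the in-place restriction; a sketch of that workaround, untested against the real model:

# `X`, `self.device`, `sparse_embedding_list`, and `sparse_feat_logit` are
# as in the Wide & Deep snippet above.
linear_logit = torch.zeros([X.shape[0], 1], device=self.device)
if len(sparse_embedding_list) > 0:
    sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
    sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
    # Out-of-place add instead of `linear_logit += sparse_feat_logit`, so the
    # plain zeros tensor is never mutated in place by a batched tensor.
    linear_logit = linear_logit + sparse_feat_logit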
