Skip to content

ValueError: Per sample gradient is not initialized. Not updated in backward pass? #645

Open
@Lhsakfheudsj35407

Description

🐛 Bug

Hello!
I try to run a deep learning task on CIFAR10 using Opacus, but I get "ValueError: Per sample gradient is not initialized. Not updated in backward pass?" error. I have seen the discussion about this issue but I still can't figure out what is wrong with my code.
Here, I provide the code in the hope that others can reproduce it and help fix it.

Code

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from tqdm import tqdm
from opacus import PrivacyEngine
from opacus.accountants.utils import get_noise_multiplier

class Net(nn.Module):
    def __init__(self, dataset, strategy):
        super(Net, self).__init__()

        self.dataset = dataset
        self.strategy = strategy            
        self.num_classes = 10               

        self.module_splits = nn.ModuleList()
        self.head_splits = []
        self.classifier = None

        self.ind = -1
        self.enc = None
        self.head = None

        self.classifier = nn.Linear(64, 16)

        # conv1
        self.module_splits.append(nn.Sequential(nn.Conv2d(3, 32, 3),
                                                nn.ReLU(),
                                                nn.MaxPool2d((2, 2))))
        # conv2_1
        self.module_splits.append(nn.Sequential(nn.Conv2d(32, 64, 3),
                                                nn.ReLU(),
                                                nn.MaxPool2d(2, 2)))
        # conv2_2
        self.module_splits.append(nn.Sequential(nn.Conv2d(64, 64, 3),
                                                nn.ReLU()))
        # fc1
        self.module_splits.append(nn.Sequential(nn.Flatten(),
                                                nn.Linear(4 * 4 * 64, 64),
                                                nn.ReLU()))

        # head for conv1
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(32, 10)))
        # head for conv2_1
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(64, 10)))
        # head for conv2_2
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(64, 10)))
        # head for fc1
        self.head_splits.append(self.classifier)

    def forward(self, x):
        out = x
        for m in self.module_splits:
            out = m(out)
        out = self.classifier(out)
        return out

    def set_submodel(self, ind, strategy=None):
        # mnist models only has three stages
        # if 'mnist' in self.dataset and ind == 3:
        #     ind = 2
        self.ind = ind

        assert ind <= 3

        if strategy == None:
            strategy = self.strategy
        print(strategy, ind, self.dataset)

        if strategy == 'progressive':
            modules = []
            for i in range(ind + 1):
                modules.append(self.module_splits[i])
            self.enc = nn.Sequential(*modules)
            self.head = self.head_splits[ind]

        elif strategy == 'baseline':
            modules = []
            for i in range(len(self.module_splits)):
                modules.append(self.module_splits[i])
            self.enc = nn.Sequential(*modules)
            self.head = self.classifier
        else:
            raise NotImplementedError()

    def gen_submodel(self):
        return SingleSubModel(self.enc, self.head, self.strategy, self.ind)

def convnet(args):
    return Net(args.dataset, args.strategy)

class SingleSubModel(nn.Module):
    """Submodels that produce an only single output"""

    def __init__(self, enc, head, strategy, ind):
        super(SingleSubModel, self).__init__()
        m_list = nn.ModuleList()
        for m in enc:
            m_list.append(m)

        self.enc = m_list
        self.head = head
        self.strategy = strategy
        self.ind = ind

    def forward(self, x, verbose=False):
        feats = []
        out = x
        for m in self.enc:
            out = m(out)
            feats.append(out)

        if not verbose:
            return self.head(feats[-1])
        else:
            return self.head(feats[-1]), feats

    def print_weight(self):
        for n, p in self.named_parameters():
            print(n, p)

class UpdateScheduler(object):
    def __init__(self, update_cycles, num_stages=4, update_strategy=None):
        self.update_cycles = update_cycles
        self.num_stages = num_stages
        self.update_strategy = update_strategy
        if isinstance(update_cycles, int):
            self.update_cycles =[update_cycles for _ in range(num_stages-1)]

        self.accumulate()

    def __getitem__(self, index):
        assert index < self.num_stages
        return self.update_cycles[index]

    def __str__(self):
        return f'update_cycles: {self.update_cycles}; update_strategy: {self.update_strategy}'

    def accumulate(self):
        for i in range(1, len(self.update_cycles)):
            self.update_cycles[i] += self.update_cycles[i-1]
        self.update_cycles = np.append(self.update_cycles, [1e9])
        self.update_cycles = self.update_cycles.astype(int)

def main():
    use_cuda = True if torch.cuda.is_available() else False
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

    transform_train = transforms.Compose([
        transforms.RandomCrop(
            size=32,
            padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=transform_train, download=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, **kwargs)

    testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=transform_test, download=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=True, **kwargs)

    # prepare the model
    model = Net(datasets,strategy).to(device)

    if strategy == 'baseline':
        layer_cnt = 3
    else:
        layer_cnt = 0
    model.set_submodel(layer_cnt)

    # progressive setting
    update_scheduler = UpdateScheduler(update_cycles=update_cycle, num_stages=num_stage, update_strategy=None)

    # loss function
    metric = nn.CrossEntropyLoss()

    # optimizer
    optimizer = optim.SGD(params=model.parameters(), lr=learning_rate, momentum=0.9)

    privacy_engine = PrivacyEngine()
    model, optimizer, trainloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=trainloader,
        target_epsilon=epsilon,
        target_delta=delta,
        epochs=epochs,
        max_grad_norm=1.0,  
    )

    # train
    for epoch in tqdm(range(1, epochs + 1)):
        if (strategy != 'baseline' and epoch != 0 and epoch == update_scheduler[layer_cnt] and layer_cnt < num_stage-1):
            layer_cnt += 1
            model.set_submodel(layer_cnt)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = metric(outputs, labels)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss/2000))
            running_loss = 0.0

    # test
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

if __name__ == '__main__':

    # list of parameters
    epochs = 300
    learning_rate = 0.01
    train_batch_size = 256
    test_batch_size = 256
    datasets = 'cifar10'
    strategy = 'progressive'
    data_path = './data/'
    update_cycle = 25
    num_stage = 4
    epsilon = 1
    delta = 0.00001

    main()

And let me introduce my code breifly:
I try to use the method of "Progressive Learning" to train my model. Progressive learning refers to the process where we initially divide the model into smaller blocks and ensure they can be trained by adding some small structures called "head" . At the beginning, only a small part of the model is used. After training for some epochs, subsequent parts are connected in sequence until the model is complete. In the code, the parameter 'update_cycle' controls this process.

Expected behavior

First, if you run the code, it will come up with several warnings:
UserWarning: Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with secure_mode turned on.
UserWarning: Optimal order is the largest alpha. Please consider expanding the range of alphas to get a tighter privacy bound.
RuntimeWarning: invalid value encountered in log
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.

And then, if you choose the strategy "baseline", the code can run perfectly. But if you choose the strategy "progressive", it will throw "ValueError: Per sample gradient is not initialized. Not updated in backward pass?" error like this:
"Traceback (most recent call last):
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 265, in
main()
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 231, in main
optimizer.step()
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 513, in step
if self.pre_step():
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 494, in pre_step
self.clip_and_accumulate()
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 397, in clip_and_accumulate
if len(self.grad_samples[0]) == 0:
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 345, in grad_samples
ret.append(self._get_flat_grad_sample(p))
File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 282, in _get_flat_grad_sample
raise ValueError(
ValueError: Per sample gradient is not initialized. Not updated in backward pass?"

Environment

python=3.10.13
pytorch=2.1.0
cuda=12.1
opacus=1.4.0
numpy=1.26.0

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions