Skip to content

Pruning callback causes GPU memory leak when used iteratively #8542

Open
@MohammedAljahdali

Description

@MohammedAljahdali

Discussed in #8363

Originally posted by MohammedAljahdali July 10, 2021
Hi, I have a script that does the following logic:

# Simplified flow: the model, the data module and the ModelPruning callback
# are created once; every loop iteration builds a fresh Trainer that reuses them.
model = Model()
dm = DataModule()
callback_a = ModelPruning()
for _ in range(N):
     callbacks = [CallbackB(), CallbackC(), callback_a]
     trainer = Trainer(callbacks=callbacks, ...)
     trainer.fit(model, dm)
     trainer.test()

There is much more going on, but to keep it simple, this is the flow that I have. My question is: is there a preferable way to do what I just did? Also, my code has a memory leak that occurs after each loop iteration; could this somehow be related to the trainer object not being deleted properly?

After trying to reproduce the issue with the boring model, it turned out that the cause of the memory leak is not the re-initialization of the trainer, but the pruning callback itself.

Dependencies:

  - python=3.8
  - pip
  - cudatoolkit=10.2
  - pytorch=1.8.1
  - torchvision=0.9.1
  - pytorch-lightning>=1.3.2
  - torchmetrics>=0.3.2
  - wandb>=0.10.30

Code to reproduce:

import os
import gc

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import wandb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
import torchvision
import pytorch_lightning as pl
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pytorch_lightning.utilities.distributed import rank_zero_only
from collections import OrderedDict, defaultdict


class DM(LightningDataModule):
    """Data module serving CIFAR-10 for the train, validation and test loops.

    All three datasets are ``torchvision.datasets.CIFAR10`` instances rooted at
    ``/tmp`` and converted with ``ToTensor``; only the training dataset is
    created with ``download=True``.
    """

    def __init__(self, n_features, n_samples, batch_size):
        """Create the datasets and store the batch size used by every loader.

        ``n_features`` and ``n_samples`` are accepted but not used by this class.
        """
        super().__init__()
        cifar10 = torchvision.datasets.CIFAR10
        to_tensor = torchvision.transforms.ToTensor
        self.train_dataset = cifar10(root='/tmp', download=True, transform=to_tensor())
        self.val_dataset = cifar10(root='/tmp', transform=to_tensor())
        self.test_dataset = cifar10(root='/tmp', transform=to_tensor())
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Do nothing: the datasets are already built in ``__init__``."""
        return None

    def train_dataloader(self):
        """Return a shuffling loader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling loader over the test dataset."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


class BoringModel(LightningModule):
    """ResNet-18 model trained with cross-entropy loss.

    Every logged metric name embeds ``self.run_id``, so each run of the outer
    experiment loop logs under its own key.
    """

    def __init__(self, in_features):
        """Build the network and the loss; ``in_features`` is not used here."""
        super().__init__()
        self.run_id = 0
        self.test_counter = 0
        self.net = torchvision.models.resnet18()
        # Tag the final fully connected layer of the ResNet.
        self.net.fc.is_classifier = True
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Run ``x`` through the wrapped network."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss, returning it under ``"loss"``."""
        inputs, targets = batch[0], batch[1]
        loss = self.loss(self(inputs), targets)
        self.log(f"train_loss {self.run_id}", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        inputs, targets = batch[0], batch[1]
        loss = self.loss(self(inputs), targets)
        self.log(f"valid_loss {self.run_id}", loss)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss."""
        inputs, targets = batch[0], batch[1]
        loss = self.loss(self(inputs), targets)
        self.log(f"test_loss {self.run_id}", loss)

    def configure_optimizers(self):
        """Return plain SGD over the network parameters with ``lr=1``."""
        optimizer = torch.optim.SGD(self.net.parameters(), lr=1)
        return optimizer


def _count_live_tensors():
    """Count gc-tracked objects that are tensors or expose a tensor as ``.data``."""
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                count += 1
        except Exception:
            # Probing attributes on arbitrary objects may raise; skip those
            # objects, but let KeyboardInterrupt / SystemExit propagate.
            continue
    return count


def run():
    """Fit and test the same model ``N`` times while reusing one pruning callback.

    Each iteration builds a new ``Trainer`` with fresh checkpoint and
    early-stopping callbacks plus the shared ``ModelPruning`` instance, then
    fits, tests the best checkpoint, tests the current weights, and prints how
    many tensors the garbage collector still tracks.
    """
    in_features = 224
    N = 20
    dm = DM(n_samples=40, n_features=in_features, batch_size=256)
    model = BoringModel(in_features=in_features)
    loggers = [WandbLogger(project='test_runs', save_dir='/tmp')]
    # Created once, outside the loop: this is the instance reused by every Trainer.
    pruning_callback = pl.callbacks.ModelPruning(
        apply_pruning=True, use_lottery_ticket_hypothesis=True,
        pruning_fn='l1_unstructured', use_global_unstructured=True, verbose=1, make_pruning_permanent=False,
        amount=0.5
    )

    for i in range(N):
        # The monitored key changes each iteration because it embeds run_id.
        monitor_key = f"valid_loss {model.run_id}"
        callbacks = [
            pl.callbacks.ModelCheckpoint(monitor=monitor_key, save_top_k=1, save_last=True),
            pl.callbacks.EarlyStopping(monitor=monitor_key),
            pruning_callback,
        ]
        trainer = Trainer(
            default_root_dir='/tmp',
            limit_train_batches=4,
            limit_val_batches=4,
            limit_test_batches=4,
            num_sanity_val_steps=0,
            max_epochs=5,
            logger=loggers,
            gpus=1,
            callbacks=callbacks
        )

        trainer.fit(model, datamodule=dm)
        trainer.test(model, ckpt_path='best')
        trainer.test(model)
        model.run_id += 1
        # Report against the iteration that just finished (previously printed i - 1).
        print(f"Number of tensors after level {i} is {_count_live_tensors()}")

# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    run()

Metadata

Metadata

Labels

bug (Something isn't working) · good first issue (Good for newcomers) · priority: 1 (Medium priority task) · waiting on author (Waiting on user action, correction, or update)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions