
Memory Leak when instantiating Fabric multiple times #18356

Open
@vkakerbeck

Description


Bug description

As mentioned here, I am trying to instantiate multiple Fabric instances in a loop (one per iteration). However, the memory consumption goes up after every iteration.
Is there a fabric.teardown()-like method that could help here? Or could this memory increase be caused by something else?
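
Stripped down, the pattern looks roughly like this (a minimal sketch with a toy model, not my actual training code; run_once and the explicit cleanup at the end of each iteration are only illustrations of what I would hope a teardown-style API could take care of):

import gc

import lightning as L
import torch
import torch.nn as nn


def run_once():
    # A fresh Fabric instance (and model/optimizer) per iteration
    fabric = L.Fabric(accelerator="auto", devices=1)
    fabric.launch()
    model = nn.Linear(1024, 1024)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = fabric.setup(model, optimizer)
    x = fabric.to_device(torch.randn(8, 1024))
    loss = model(x).pow(2).mean()
    fabric.backward(loss)
    optimizer.step()


for i in range(3):
    run_once()
    # Guessed manual cleanup in place of a fabric.teardown():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()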

This is the memory increase with 3 iterations:
[Screenshot: memory usage increasing over the 3 iterations]

For reference here you can see when a new iteration starts (step count resets):
[Screenshot: step count resetting at the start of each iteration]

What version are you seeing the problem on?

v2.0

How to reproduce the bug

I tried to make a self-contained example of my problem (derived from lit-gpt code, but simplified to illustrate the issue). The parameters are chosen only to make the problem visible (the network just needs to be large enough). devices is set to 1 for debugging, but the leak also shows up with a larger number of devices.


import shutil
import warnings
from pathlib import Path
from types import SimpleNamespace

import lightning as L
import torch
import torch.nn as nn
import wandb
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies import DeepSpeedStrategy


def main():
    config = SimpleNamespace(
        input_size=1024,
        output_size=1024,
        hidden_size=8192,
        num_layers=5,
        devices=1,
        ds_config=None,
        precision="bf16-true",
        entity="x",
        wandb_project="project",
        run_name="memory_test",
        checkpoint_path=None,
        learning_rate=1e-4,
        weight_decay=1e-5,
        max_steps=1000,
        batch_size=64,
        log_interval=1,
        out_dir=Path("./out"),
    )

    batch_size_per_device = config.batch_size / config.devices
    micro_batch_size = 1
    config.gradient_accumulation_steps = int(batch_size_per_device // micro_batch_size)

    for epoch in range(3):
        config.checkpoint_path = run_epoch(config)


def run_epoch(config):
    # Load Fabric
    fabric = L.Fabric(
        devices=config.devices,
        strategy=(
            DeepSpeedStrategy(config=config.ds_config) if config.devices > 1 else "auto"
        ),
        precision=config.precision,
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        wandb.init(
            entity=config.entity,
            project=config.wandb_project,
            name=config.run_name,
            config=config,
        )

    with fabric.device:
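        # Instantiate the weights in half precision directly on the target device
        # (lit-gpt style) so that no full fp32 copy of the model is materialized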
        torch.set_default_tensor_type(torch.HalfTensor)
        model = LargeMODEL(config).bfloat16()
        torch.set_default_tensor_type(torch.FloatTensor)
        # Load checkpoint if this is not the first epoch
        if config.checkpoint_path is not None:
            checkpoint_path = config.checkpoint_path
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint, strict=False)

    # Setup model and optimizer in fabric
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    model, optimizer = fabric.setup(model, optimizer)

    train_data = load_datasets(config)

    # Train the model
    train(
        config,
        fabric,
        model,
        optimizer,
        train_data,
    )

    # Save the final checkpoint at the end of training
    save_path = config.out_dir / "model_trained.pth"
    fabric.print(f"Saving weights to {str(save_path)!r}")
    print(save_path)
    save_model_checkpoint(fabric, model, save_path)
    return save_path


def train(config, fabric, model, optimizer, train_data):
    step_count = 0
    max_iters = int(config.max_steps * config.gradient_accumulation_steps)
    for iter_num in range(max_iters):
        input_ids, targets = get_batch(fabric, train_data, config.batch_size)
        with fabric.no_backward_sync(
            model, enabled=((iter_num + 1) % config.gradient_accumulation_steps != 0)
        ):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            fabric.backward(loss / config.gradient_accumulation_steps)
        if (iter_num + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
            fabric.call("on_train_batch_end", model=model)

        # Report performance to command line and wandb
        if step_count % config.log_interval == 0:
            if fabric.global_rank == 0:
                wandb.log(
                    {
                        "iter": iter_num,
                        "step": step_count,
                        "train/loss": loss,
                        "train/lr": optimizer.param_groups[0]["lr"],
                    }
                )


def save_model_checkpoint(fabric, model, file_path: Path):
    file_path = Path(file_path)
    # Ensure the directory exists
    file_path.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
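        # DeepSpeed saves a sharded checkpoint *directory*; consolidate it into a
        # single fp32 state dict on rank 0, then remove the temporary directory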
        from deepspeed.utils.zero_to_fp32 import (
            get_fp32_state_dict_from_zero_checkpoint,
        )

        tmp_path = file_path.with_suffix(".tmp")
        fabric.save(tmp_path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
            torch.save(state_dict, file_path)
            shutil.rmtree(tmp_path)
    else:
        if fabric.global_rank == 0:
            state_dict = model.state_dict()
            torch.save(state_dict, file_path)
        fabric.barrier()

class LargeMODEL(nn.Module):
    def __init__(self, config):
        super(LargeMODEL, self).__init__()

        layers = []
        input_size = config.input_size
        for _ in range(config.num_layers):
            layers.append(nn.Linear(input_size, config.hidden_size))
            layers.append(nn.ReLU())
            input_size = config.hidden_size

        layers.append(nn.Linear(config.hidden_size, config.output_size))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def load_datasets(config):
    # For the sake of this demonstration, let's return some random data
    data = torch.randn(1000, config.input_size), torch.randn(1000, config.output_size)
    return data


def get_batch(fabric, data, batch_size):
    input_ids, targets = data
    # Sample a random batch of batch_size examples from the dataset
    ix = torch.randint(len(input_ids), (batch_size,))
    x, y = input_ids[ix], targets[ix]
    # Pinned memory is not supported on MPS, so only pin on other accelerators
    if isinstance(fabric.accelerator, MPSAccelerator):
        x, y = fabric.to_device((x, y))
    else:
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y


def loss_fn(logits, targets):
    return ((logits - targets) ** 2).mean()


if __name__ == "__main__":
    # Uncomment this line if you see an error:
    # "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # from jsonargparse.cli import CLI
    warnings.filterwarnings(
        # false positive using deepspeed:
        # https://github.com/Lightning-AI/lightning/pull/17761#discussion_r1219705307
        "ignore",
        message="Remove `.no_backward_sync()` from your code",
    )
    main()


Error messages and logs

There is no error message with the example code above. If I increase the network size, I get the following warning:

[WARNING] [stage3.py:1898:step] 14 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time


If the size increases further, or if I use fewer GPUs, I simply get an OOM error.
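
For reference, the warning's suggestion would translate to something like the following inside train() above (just a sketch of a possible mitigation; I use torch.cuda.empty_cache() in place of DeepSpeed's get_accelerator().empty_cache(), and I would not expect it to address the per-iteration growth itself):

        if (iter_num + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
            fabric.call("on_train_batch_end", model=model)
            # Flush the CUDA caching allocator on every rank at the same point
            # in the loop, as the DeepSpeed warning above suggests
            if torch.cuda.is_available():
                torch.cuda.empty_cache()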


Environment

Current environment:

- PyTorch Lightning Version: 2.0.6
- Lightning App Version: 2.1.0.dev0
- PyTorch Version: 2.1.0.dev20230801+cu118
- Python version: 3.10.10
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration: NVIDIA A100-SXM4-80GB
- How you installed Lightning: pip


cc @carmocca @justusschock @awaelchli @Borda
