### Bug description
As mentioned here, I am trying to instantiate multiple Fabric instances in a loop (one per iteration). However, after every iteration the memory consumption goes up.
Is there some `fabric.teardown()`-like method that could help here? Or could this memory increase be caused by something else?
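For reference, this is the pattern in question, condensed into a minimal sketch (the full reproduction script is further below); the model/optimizer setup and training loop are elided here:

```python
import lightning as L

# Minimal sketch of the pattern described above: a fresh Fabric instance is
# created on every outer iteration, and nothing is explicitly torn down
# between iterations.
for epoch in range(3):
    fabric = L.Fabric(devices=1, precision="bf16-true")
    fabric.launch()
    # ... create model/optimizer, fabric.setup(...), run the training loop ...
```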
This is the memory increase over 3 iterations:
For reference, here you can see when a new iteration starts (the step count resets):
### What version are you seeing the problem on?
v2.0
### How to reproduce the bug
I tried to make a self-contained example of my problem (derived from lit-gpt code but simplified to illustrate the issue). The parameters are chosen just to make the problem visible (a large enough network); `devices` is set to 1 for debugging, but the issue also shows up with a larger number of devices.
```python
import shutil
import warnings
from pathlib import Path
from types import SimpleNamespace

import lightning as L
import torch
import torch.nn as nn
import wandb
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies import DeepSpeedStrategy


def main():
    config = SimpleNamespace(
        input_size=1024,
        output_size=1024,
        hidden_size=8192,
        num_layers=5,
        devices=1,
        ds_config=None,
        precision="bf16-true",
        entity="x",
        wandb_project="project",
        run_name="memory_test",
        checkpoint_path=None,
        learning_rate=1e-4,
        weight_decay=1e-5,
        max_steps=1000,
        batch_size=64,
        log_interval=1,
        out_dir=Path("./out"),
    )

    batch_size_per_device = config.batch_size / config.devices
    micro_batch_size = 1
    config.gradient_accumulation_steps = int(batch_size_per_device // micro_batch_size)

    for epoch in range(3):
        config.checkpoint_path = run_epoch(config)


def run_epoch(config):
    # Load Fabric
    fabric = L.Fabric(
        devices=config.devices,
        strategy=(
            DeepSpeedStrategy(config=config.ds_config) if config.devices > 1 else "auto"
        ),
        precision=config.precision,
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        wandb.init(
            entity=config.entity,
            project=config.wandb_project,
            name=config.run_name,
            config=config,
        )

    with fabric.device:
        torch.set_default_tensor_type(torch.HalfTensor)
        model = LargeMODEL(config).bfloat16()
        torch.set_default_tensor_type(torch.FloatTensor)

    # Load checkpoint if this is not the first epoch
    if config.checkpoint_path is not None:
        checkpoint_path = config.checkpoint_path
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint, strict=False)

    # Setup model and optimizer in fabric
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    model, optimizer = fabric.setup(model, optimizer)

    train_data = load_datasets(config)

    # Train the model
    train(
        config,
        fabric,
        model,
        optimizer,
        train_data,
    )

    # Save the final checkpoint at the end of training
    save_path = config.out_dir / "model_trained.pth"
    fabric.print(f"Saving weights to {str(save_path)!r}")
    print(save_path)
    save_model_checkpoint(fabric, model, save_path)
    return save_path


def train(config, fabric, model, optimizer, train_data):
    step_count = 0
    max_iters = int(config.max_steps * config.gradient_accumulation_steps)

    for iter_num in range(max_iters):
        input_ids, targets = get_batch(fabric, train_data, config.batch_size)

        with fabric.no_backward_sync(
            model, enabled=((iter_num + 1) % config.gradient_accumulation_steps != 0)
        ):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            fabric.backward(loss / config.gradient_accumulation_steps)

        if (iter_num + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
            fabric.call("on_train_batch_end", model=model)

            # Report performance to command line and wandb
            if step_count % config.log_interval == 0:
                if fabric.global_rank == 0:
                    wandb.log(
                        {
                            "iter": iter_num,
                            "step": step_count,
                            "train/loss": loss,
                            "train/lr": optimizer.param_groups[0]["lr"],
                        }
                    )


def save_model_checkpoint(fabric, model, file_path: Path):
    file_path = Path(file_path)
    # Ensure the directory exists
    file_path.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import (
            get_fp32_state_dict_from_zero_checkpoint,
        )

        tmp_path = file_path.with_suffix(".tmp")
        fabric.save(tmp_path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
            torch.save(state_dict, file_path)
            shutil.rmtree(tmp_path)
    else:
        if fabric.global_rank == 0:
            state_dict = model.state_dict()
            torch.save(state_dict, file_path)
        fabric.barrier()


class LargeMODEL(nn.Module):
    def __init__(self, config):
        super(LargeMODEL, self).__init__()
        layers = []
        input_size = config.input_size
        for _ in range(config.num_layers):
            layers.append(nn.Linear(input_size, config.hidden_size))
            layers.append(nn.ReLU())
            input_size = config.hidden_size
        layers.append(nn.Linear(config.hidden_size, config.output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def load_datasets(config):
    # For the sake of this demonstration, let's return some random data
    data = torch.randn(1000, config.input_size), torch.randn(1000, config.output_size)
    return data


def get_batch(fabric, data, batch_size):
    input_ids, targets = data
    # Sample a random batch of indices from the dataset
    ix = torch.randint(len(input_ids), (batch_size,))
    x, y = input_ids[ix], targets[ix]
    # Move the batch to the device (pinning memory on non-MPS accelerators)
    if isinstance(fabric.accelerator, MPSAccelerator):
        x, y = fabric.to_device((x, y))
    else:
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y


def loss_fn(logits, targets):
    return ((logits - targets) ** 2).mean()


if __name__ == "__main__":
    # Uncomment this line if you see an error:
    # "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # from jsonargparse.cli import CLI
    warnings.filterwarnings(
        # false positive using deepspeed:
        # https://github.com/Lightning-AI/lightning/pull/17761#discussion_r1219705307
        "ignore",
        message="Remove `.no_backward_sync()` from your code",
    )
    main()
```
### Error messages and logs
There is no error message with the example code above. If the network size is increased, I get the following warning:
[WARNING] [stage3.py:1898:step] 14 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
If the size increases further, or I use fewer GPUs, I simply get an OOM error.
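For completeness, this is a sketch of how the cache flush suggested by that warning could be hooked into the training loop, e.g. right after `optimizer.step()` / `optimizer.zero_grad()` in `train()`. The helper name is hypothetical; `get_accelerator()` is the DeepSpeed utility the warning refers to, and `torch.cuda.empty_cache()` is used as a plain-PyTorch fallback. It is shown only to illustrate the warning's suggestion:

```python
import torch


def flush_allocator_cache():
    # Hypothetical helper illustrating the DeepSpeed warning's suggestion:
    # flush allocator caches on every rank at the same point in the loop.
    try:
        from deepspeed.accelerator import get_accelerator
        get_accelerator().empty_cache()
    except ImportError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```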
### Environment
<details>
<summary>Current environment</summary>
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.6
#- Lightning App Version (e.g., 0.5.2): 2.1.0.dev0
#- PyTorch Version (e.g., 2.0): 2.1.0.dev20230801+cu118
#- Python version (e.g., 3.9): 3.10.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 11.8
#- GPU models and configuration: NVIDIA A100-SXM4-80GB
#- How you installed Lightning(`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud):
```

</details>
### More info
No response