Description
Bug description
I'm adapting the finetuning tutorial to use the Falcon-40B model with DeepSpeed ZeRO Stage 3 Offload on a Slurm cluster. Lightning has 16 V100s available, each with 32 GB of GPU memory, plus roughly 1.2 TB of CPU RAM, but when training starts I still see a memory spike and then get a CUDA OOM error:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 31.75 GiB total capacity; 30.25 GiB already allocated; 62.19 MiB free; 30.67 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
param_coordinator.fetch_sub_module(sub_module, forward=True)
param_buffer = torch.empty(
It seems to be happening at the parameter gathering stage; this frame appears earlier in the error trace:
File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
ret_val = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
^^^^^^^^^^^^^^^^^^^^^
File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
ret_val = func(*args, **kwargs)
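The OOM message above suggests setting max_split_size_mb. A minimal sketch of how that allocator hint could be applied (the 128 MiB value is an arbitrary assumption, and the variable has to be set before torch makes its first CUDA allocation, e.g. at the top of the training script or exported in the Slurm script); I have not verified that it removes the spike:

import os

# Allocator hint from the OOM message; must be in place before the first
# CUDA allocation. The 128 MiB split size is illustrative, not a verified value.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch  # imported only after the allocator setting is in place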
What version are you seeing the problem on?
2.0.5
How to reproduce the bug
from datetime import datetime
from typing import Optional

import datasets
import transformers
import evaluate
import torch
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM,
)
from torch.distributed.fsdp.wrap import wrap
import functools
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from lightning.pytorch.strategies import FSDPStrategy, DeepSpeedStrategy
import os
import lightning.pytorch as pl
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import warnings

warnings.filterwarnings("ignore")


class GLUEDataModule(LightningDataModule):
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    loader_columns = [
        "datasets_idx",
        "input_ids",
        # "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, padding=True, truncation=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset("glue", self.task_name)
        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                batch_size=len(self.dataset[split]),
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size, num_workers=2)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=2) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size, num_workers=2)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=2) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]
        # Tokenize the text/text pairs
        features = self.tokenizer(texts_or_text_pairs, padding=True)  # truncation=True, max_length=self.max_seq_length
        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]
        return features


class FalconTransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels, trust_remote_code=True, revision="main")
        self.metric = evaluate.load(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )
        self.validation_step_outputs = []
        self.model_name = model_name_or_path

    def configure_sharded_model(self):
        self.model = AutoModelForSequenceClassification.from_config(config=self.config, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True, padding=True, truncation=True)
        self.model.config.pad_token_id = self.tokenizer.eos_token_id  # We use the eos_token_id as the pad_token_id because the eos_token_id is not used for causal language modeling.
        # self.model = torch.hub.load('huggingface/pytorch-transformers', 'model', self.model_name, trust_remote_code=True)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()
        labels = batch["labels"]
        self.validation_step_outputs.append({"loss": val_loss, "preds": preds, "labels": labels})
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        # return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        # return FusedAdam(self.parameters())
        return DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)


if __name__ == "__main__":
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()
    seed_everything(42)
    assert torch.cuda.is_available(), "GPU Required"

    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )

    dm = GLUEDataModule(model_name_or_path="tiiuae/falcon-40b", task_name="cola")
    dm.setup("fit")

    model = FalconTransformer(
        model_name_or_path="tiiuae/falcon-40b",
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        train_batch_size=2,
        eval_batch_size=2,
    )

    # strategy = FSDPStrategy(cpu_offload=True, auto_wrap_policy=my_auto_wrap_policy)
    strategy = DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
    )

    trainer = pl.Trainer(
        max_epochs=5,
        accelerator="gpu",
        devices=4,
        num_nodes=int(os.environ.get("SLURM_NNODES")),
        strategy=strategy,
        precision="16-mixed",
    )
    trainer.fit(model, datamodule=dm)
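For reference, this is roughly how I understand the strategy above translates into an explicit DeepSpeed ZeRO-3 offload config. This is a sketch only: the bucket sizes and live-parameter limit are assumptions I would tune to shrink the all-gather buffers, not values I have verified.

from lightning.pytorch.strategies import DeepSpeedStrategy

# Sketch of an explicit ZeRO stage-3 CPU-offload config. The numeric limits
# (bucket sizes, max live parameters) are illustrative assumptions; passing a
# dict via `config=` is how I understand DeepSpeedStrategy accepts a full
# DeepSpeed configuration.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": 2e8,
        "stage3_prefetch_bucket_size": 2e8,
        "stage3_max_live_parameters": 1e9,
        "stage3_param_persistence_threshold": 1e5,
    },
    "train_micro_batch_size_per_gpu": 2,
}
strategy = DeepSpeedStrategy(config=ds_config)

As far as I can tell this should be equivalent to the keyword form used in the script, just with the sizing knobs exposed.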
Slurm Submission Script
#!/bin/bash
#SBATCH --job-name=falcon-finetune
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --output=falcon-finetune.out
#SBATCH --error=falcon-finetune.err

export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE))
echo "WORLD_SIZE="$WORLD_SIZE

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
export MASTER_ADDR=$head_node_ip
echo Node IP: $head_node_ip

export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO

module load miniconda
conda activate fsdp

export OMP_NUM_THREADS=12
export PYTHONFAULTHANDLER=1

srun python fsdp_train.py
Error messages and logs
Environment
Current environment
- CUDA:
- (Note: this dump was generated on the head node; GPUs are available on the compute nodes.)
- GPU: None
- available: False
- version: 11.8
- Lightning:
- lightning: 2.0.5
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.5
- torch: 2.0.1+cu118
- torchaudio: 2.0.2+cu118
- torchmetrics: 1.0.1
- torchvision: 0.15.2+cu118
- Packages:
- absl-py: 1.4.0
- accelerate: 0.21.0
- aiohttp: 3.8.5
- aiosignal: 1.3.1
- anyio: 3.7.1
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- backoff: 2.2.1
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- boto3: 1.28.8
- botocore: 1.31.8
- cachetools: 5.3.1
- certifi: 2022.12.7
- charset-normalizer: 2.1.1
- click: 8.1.6
- cmake: 3.25.0
- croniter: 1.4.1
- datasets: 2.13.1
- dateutils: 0.6.12
- deepdiff: 6.3.1
- deepspeed: 0.10.0
- dill: 0.3.6
- einops: 0.6.1
- evaluate: 0.4.0
- fastapi: 0.100.0
- filelock: 3.9.0
- frozenlist: 1.4.0
- fsspec: 2023.6.0
- google-auth: 2.22.0
- google-auth-oauthlib: 1.0.0
- grpcio: 1.56.2
- h11: 0.14.0
- hjson: 3.1.0
- huggingface-hub: 0.16.4
- idna: 3.4
- importlib-metadata: 6.8.0
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- jmespath: 1.0.1
- joblib: 1.3.1
- lightning: 2.0.5
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- lit: 15.0.7
- markdown: 3.4.3
- markdown-it-py: 3.0.0
- markupsafe: 2.1.2
- mdurl: 0.1.2
- mpmath: 1.2.1
- multidict: 6.0.4
- multiprocess: 0.70.14
- networkx: 3.0
- ninja: 1.11.1
- numpy: 1.24.1
- oauthlib: 3.2.2
- ordered-set: 4.1.0
- packaging: 23.1
- pandas: 2.0.3
- pillow: 9.3.0
- pip: 22.2.2
- protobuf: 4.23.4
- psutil: 5.9.5
- py-cpuinfo: 9.0.0
- pyarrow: 12.0.1
- pyasn1: 0.5.0
- pyasn1-modules: 0.3.0
- pydantic: 1.10.11
- pygments: 2.15.1
- pyjwt: 2.8.0
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.5
- pytz: 2023.3
- pyyaml: 6.0.1
- readchar: 4.0.5
- regex: 2023.6.3
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- responses: 0.18.0
- rich: 13.4.2
- rsa: 4.9
- s3transfer: 0.6.1
- sacremoses: 0.0.53
- safetensors: 0.3.1
- scikit-learn: 1.3.0
- scipy: 1.11.1
- sentencepiece: 0.1.99
- setuptools: 68.0.0
- six: 1.16.0
- snakeviz: 2.2.0
- sniffio: 1.3.0
- soupsieve: 2.4.1
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.11.1
- tensorboard: 2.13.0
- tensorboard-data-server: 0.7.1
- threadpoolctl: 3.2.0
- tokenizers: 0.13.3
- torch: 2.0.1+cu118
- torchaudio: 2.0.2+cu118
- torchmetrics: 1.0.1
- torchvision: 0.15.2+cu118
- tornado: 6.3.2
- tqdm: 4.65.0
- traitlets: 5.9.0
- transformers: 4.31.0
- triton: 2.0.0
- typing-extensions: 4.7.1
- tzdata: 2023.3
- urllib3: 1.26.13
- uvicorn: 0.23.1
- wcwidth: 0.2.6
- websocket-client: 1.6.1
- websockets: 11.0.3
- werkzeug: 2.3.6
- wheel: 0.37.1
- xxhash: 3.2.0
- yarl: 1.9.2
- zipp: 3.16.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.0
- release: 5.4.0-152-generic
- version: #169-Ubuntu SMP Tue Jun 6 22:23:09 UTC 2023
More info
No response
cc @awaelchli