
CUDA OOM with DeepSpeed ZeRO Stage 3 Offload #18134

Open
@rggs

Bug description

I'm adapting the finetuning tutorial to use the Falcon-40B model with DeepSpeed ZeRO Stage 3 Offload on a Slurm cluster. Available to Lightning are 16 V100s, each with 32 GB of memory, plus roughly 1.2 TB of CPU RAM, but when training starts I still see a memory spike and then get a CUDA OOM error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 31.75 GiB total capacity; 30.25 GiB already allocated; 62.19 MiB free; 30.67 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    param_coordinator.fetch_sub_module(sub_module, forward=True)
    param_buffer = torch.empty(
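
For reference, the allocator hint from the error message can be set before the first CUDA allocation; a minimal sketch (the 128 MiB value is an illustrative guess, not something tuned for this setup):

import os

# Illustrative only: cap the size of split blocks to reduce fragmentation.
# This has to be set before torch makes its first CUDA allocation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")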

It seems to be happening at the parameter-gathering stage; this appears earlier in the error trace:

  File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
              ^^^^^^^^^^^^^^^^^^^^^
  File ".../miniconda/envs/fsdp/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 424, in __all_gather_params
    ret_val = func(*args, **kwargs)

What version are you seeing the problem on?

2.0.5

How to reproduce the bug

from datetime import datetime
from typing import Optional

import datasets
import transformers
import evaluate
import torch
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM
)
import functools
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam

from accelerate import init_empty_weights, load_checkpoint_and_dispatch

from lightning.pytorch.strategies import FSDPStrategy, DeepSpeedStrategy
import os
import lightning.pytorch as pl
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

import warnings
warnings.filterwarnings("ignore")


class GLUEDataModule(LightningDataModule):
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    loader_columns = [
        "datasets_idx",
        "input_ids",
        # "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, padding=True, truncation=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                batch_size=len(self.dataset[split]),
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)
            
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size, num_workers=2)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=2) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size, num_workers=2)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=2) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer(texts_or_text_pairs, padding=True)  # truncation=True, max_length=self.max_seq_length

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
    
    
class FalconTransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels, trust_remote_code=True, revision="main")
        self.metric = evaluate.load(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )
        self.validation_step_outputs = []
        self.model_name = model_name_or_path

    def configure_sharded_model(self):
        self.model = AutoModelForSequenceClassification.from_config(config=self.config, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True, padding=True, truncation=True)
        self.model.config.pad_token_id = self.tokenizer.eos_token_id  # Falcon's tokenizer has no pad token, so reuse the EOS token id for padding
        # self.model = torch.hub.load('huggingface/pytorch-transformers', 'model', self.model_name, trust_remote_code=True)


    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]
        self.validation_step_outputs.append({"loss": val_loss, "preds": preds, "labels": labels})

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss
        

        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        
        # return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        # return FusedAdam(self.parameters())
        return DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
    
if __name__ == "__main__":
    
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()
    
    
    seed_everything(42)
    
    assert torch.cuda.is_available(), "GPU Required"    
    
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    
    dm = GLUEDataModule(model_name_or_path="tiiuae/falcon-40b", task_name="cola")
    dm.setup("fit")
    
    model = FalconTransformer(
        model_name_or_path="tiiuae/falcon-40b",
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        train_batch_size=2,
        eval_batch_size=2
    )
    
    # strategy = FSDPStrategy(cpu_offload=True, auto_wrap_policy=my_auto_wrap_policy)
    strategy = DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
    )

    trainer = pl.Trainer(
        max_epochs=5,
        accelerator="gpu",
        devices=4,
        num_nodes=int(os.environ.get("SLURM_NNODES")),
        strategy=strategy,
        precision="16-mixed"
    )
    
    trainer.fit(model, datamodule=dm)

Slurm Submission Script

#!/bin/bash
#SBATCH --job-name=falcon-finetune
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --output=falcon-finetune.out
#SBATCH --error=falcon-finetune.err

export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE))
echo "WORLD_SIZE="$WORLD_SIZE

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

export MASTER_ADDR=$head_node_ip

echo Node IP: $head_node_ip

export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO

module load miniconda
conda activate fsdp

export OMP_NUM_THREADS=12
export PYTHONFAULTHANDLER=1

srun python fsdp_train.py

Error messages and logs

falcon-finetune-err.txt

Environment

Current environment
  • CUDA:
    - This dump is from the head node, which has no GPU; GPUs are available on the compute nodes
    - GPU: None
    - available: False
    - version: 11.8
  • Lightning:
    - lightning: 2.0.5
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.0.5
    - torch: 2.0.1+cu118
    - torchaudio: 2.0.2+cu118
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2+cu118
  • Packages:
    - absl-py: 1.4.0
    - accelerate: 0.21.0
    - aiohttp: 3.8.5
    - aiosignal: 1.3.1
    - anyio: 3.7.1
    - arrow: 1.2.3
    - async-timeout: 4.0.2
    - attrs: 23.1.0
    - backoff: 2.2.1
    - beautifulsoup4: 4.12.2
    - blessed: 1.20.0
    - boto3: 1.28.8
    - botocore: 1.31.8
    - cachetools: 5.3.1
    - certifi: 2022.12.7
    - charset-normalizer: 2.1.1
    - click: 8.1.6
    - cmake: 3.25.0
    - croniter: 1.4.1
    - datasets: 2.13.1
    - dateutils: 0.6.12
    - deepdiff: 6.3.1
    - deepspeed: 0.10.0
    - dill: 0.3.6
    - einops: 0.6.1
    - evaluate: 0.4.0
    - fastapi: 0.100.0
    - filelock: 3.9.0
    - frozenlist: 1.4.0
    - fsspec: 2023.6.0
    - google-auth: 2.22.0
    - google-auth-oauthlib: 1.0.0
    - grpcio: 1.56.2
    - h11: 0.14.0
    - hjson: 3.1.0
    - huggingface-hub: 0.16.4
    - idna: 3.4
    - importlib-metadata: 6.8.0
    - inquirer: 3.1.3
    - itsdangerous: 2.1.2
    - jinja2: 3.1.2
    - jmespath: 1.0.1
    - joblib: 1.3.1
    - lightning: 2.0.5
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.9.0
    - lit: 15.0.7
    - markdown: 3.4.3
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.2
    - mdurl: 0.1.2
    - mpmath: 1.2.1
    - multidict: 6.0.4
    - multiprocess: 0.70.14
    - networkx: 3.0
    - ninja: 1.11.1
    - numpy: 1.24.1
    - oauthlib: 3.2.2
    - ordered-set: 4.1.0
    - packaging: 23.1
    - pandas: 2.0.3
    - pillow: 9.3.0
    - pip: 22.2.2
    - protobuf: 4.23.4
    - psutil: 5.9.5
    - py-cpuinfo: 9.0.0
    - pyarrow: 12.0.1
    - pyasn1: 0.5.0
    - pyasn1-modules: 0.3.0
    - pydantic: 1.10.11
    - pygments: 2.15.1
    - pyjwt: 2.8.0
    - python-dateutil: 2.8.2
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.5
    - pytz: 2023.3
    - pyyaml: 6.0.1
    - readchar: 4.0.5
    - regex: 2023.6.3
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - responses: 0.18.0
    - rich: 13.4.2
    - rsa: 4.9
    - s3transfer: 0.6.1
    - sacremoses: 0.0.53
    - safetensors: 0.3.1
    - scikit-learn: 1.3.0
    - scipy: 1.11.1
    - sentencepiece: 0.1.99
    - setuptools: 68.0.0
    - six: 1.16.0
    - snakeviz: 2.2.0
    - sniffio: 1.3.0
    - soupsieve: 2.4.1
    - starlette: 0.27.0
    - starsessions: 1.3.0
    - sympy: 1.11.1
    - tensorboard: 2.13.0
    - tensorboard-data-server: 0.7.1
    - threadpoolctl: 3.2.0
    - tokenizers: 0.13.3
    - torch: 2.0.1+cu118
    - torchaudio: 2.0.2+cu118
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2+cu118
    - tornado: 6.3.2
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - transformers: 4.31.0
    - triton: 2.0.0
    - typing-extensions: 4.7.1
    - tzdata: 2023.3
    - urllib3: 1.26.13
    - uvicorn: 0.23.1
    - wcwidth: 0.2.6
    - websocket-client: 1.6.1
    - websockets: 11.0.3
    - werkzeug: 2.3.6
    - wheel: 0.37.1
    - xxhash: 3.2.0
    - yarl: 1.9.2
    - zipp: 3.16.2
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.11.0
    - release: 5.4.0-152-generic
    - version: #169-Ubuntu SMP Tue Jun 6 22:23:09 UTC 2023

More info

No response

cc @awaelchli
