load_from_checkpoint leads to CUDA errors while trying multi-gpu training with SLURM #18759

Open
@ashar-wfr

Bug description

I am trying to train a model on multiple GPUs with SLURM, as suggested in the official documentation: Lightning with SLURM

Once I have a trained model checkpoint, I try to load it with similar code (only the load_from_checkpoint call added). This works if I set --ntasks-per-node=1, but then the 2 GPUs are not used effectively. If I set --ntasks-per-node=2, as the documentation recommends, I get CUDA errors. Has anyone else hit a similar issue when loading from a checkpoint with multiple GPUs on SLURM?

PyTorch Lightning version: 2.0.9
SLURM Command used to run this: srun -u -c 20 --mem=500G --gres=gpu:2 --ntasks-per-node=2 python train.py
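
Since the behavior differs between --ntasks-per-node=1 and --ntasks-per-node=2, it may help to confirm what each SLURM task actually sees. The snippet below is a minimal diagnostic sketch, not part of the original report: `describe_task` is a hypothetical helper name, and it only reads standard SLURM and CUDA environment variables. Whether `CUDA_VISIBLE_DEVICES` is set per task depends on the cluster configuration.

```python
import os

# Environment variables that SLURM (and the CUDA runtime) commonly expose.
# Any of them may be unset, depending on how the job was launched.
SLURM_KEYS = (
    "SLURM_PROCID",           # global rank of this task
    "SLURM_LOCALID",          # rank of this task on its node
    "SLURM_NTASKS",           # total number of tasks in the job step
    "SLURM_NTASKS_PER_NODE",  # set when --ntasks-per-node is given
    "CUDA_VISIBLE_DEVICES",   # GPUs visible to this process
)


def describe_task(env=None):
    """Return the SLURM/CUDA variables visible to the current process.

    `env` defaults to os.environ; passing a dict makes the helper easy to test.
    """
    env = os.environ if env is None else env
    return {key: env.get(key, "<unset>") for key in SLURM_KEYS}


if __name__ == "__main__":
    # Under srun, each task prints its own view of the environment.
    print(describe_task())
```

Calling this at the top of train.py, before the checkpoint is loaded, would show whether both tasks see both GPUs or only one each.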

What version are you seeing the problem on?

master

How to reproduce the bug

# Based on https://lightning.ai/docs/pytorch/stable/starter/introduction.html

import os
import torch
import numpy as np
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl

# define the LightningModule
class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 1))

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x = batch
        z = self.encoder(x)
        loss = nn.functional.mse_loss(z, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

class Dataset(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.dataset = torch.rand(64, 784)
    
    def train_dataloader(self):
        # setup data
        dist_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
        dataloader = torch.utils.data.DataLoader(self.dataset, sampler=dist_sampler)
        return dataloader
    
    def val_dataloader(self):
        # setup data
        dist_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
        dataloader = torch.utils.data.DataLoader(self.dataset, sampler=dist_sampler)
        return dataloader

def main():
    # init the autoencoder
    # Used for training
    # model = LitModule()
    checkpoint_path="trained_model.ckpt"
    model = LitModule().load_from_checkpoint(checkpoint_path=checkpoint_path)

    datamodule = Dataset()

    # train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
    trainer = pl.Trainer(strategy="ddp",
                         use_distributed_sampler=False,
                         devices=2,
                         accelerator="gpu", 
                         limit_train_batches=100, 
                         max_epochs=1)
    trainer.fit(model=model, datamodule=datamodule)

if __name__ == '__main__':
    main()

# SLURM Command used to run this: srun -u -c 20 --mem=500G  --gres=gpu:2 --ntasks-per-node=2 python train.py

Error messages and logs

Traceback (most recent call last):
  File "/debugging_pl/train.py", line 68, in <module>
    main()
  File "/debugging_pl/train.py", line 54, in main
    model = LitModule().load_from_checkpoint(checkpoint_path=checkpoint_path)
  File "/miniconda3/envs/train/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1543, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/miniconda3/envs/train/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 63, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/miniconda3/envs/train/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py", line 52, in _load
    return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
  File "/miniconda3/envs/train/lib/python3.9/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/miniconda3/envs/train/lib/python3.9/pickle.py", line 1253, in load_binpersid
    self.append(self.persistent_load(pid))
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 1142, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 1116, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 217, in default_restore_location
    result = fn(storage, location)
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/serialization.py", line 187, in _cuda_deserialize
    return obj.cuda(device)
  File "/miniconda3/envs/train/lib/python3.9/site-packages/torch/_utils.py", line 84, in _cuda
    untyped_storage.copy_(self, non_blocking)
RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
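
The traceback ends in `_cuda_deserialize`, which suggests that `torch.load` is restoring the checkpoint tensors directly onto the CUDA device they were saved from. One thing that may be worth trying, offered as an untested workaround rather than a confirmed fix, is to deserialize onto the CPU first and let the Trainer move the model to the correct GPU for each rank. The sketch below shows the idea with plain `torch.load` and its `map_location` argument; `load_checkpoint_on_cpu` and `demo` are hypothetical helper names, and the tiny `nn.Linear` model stands in for the real checkpoint.

```python
import os
import tempfile

import torch
from torch import nn


def load_checkpoint_on_cpu(path):
    """Deserialize a checkpoint so that all tensors land on the CPU.

    map_location="cpu" remaps storages regardless of the device
    they were saved from, so no CUDA device is touched while loading.
    """
    return torch.load(path, map_location="cpu")


def demo():
    """Save a small stand-in checkpoint and load it back on the CPU."""
    model = nn.Linear(4, 1)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.ckpt")
        torch.save({"state_dict": model.state_dict()}, path)
        checkpoint = load_checkpoint_on_cpu(path)
    return checkpoint


if __name__ == "__main__":
    ckpt = demo()
    for name, tensor in ckpt["state_dict"].items():
        print(name, tensor.device)
```

In the reproduction script, the corresponding change would be to call the classmethod on the class and pass the same argument: `LitModule.load_from_checkpoint(checkpoint_path, map_location="cpu")`. Whether this resolves the error on a given cluster is an assumption that still needs to be verified.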

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

Labels: bug, ver: 2.1.x, waiting on author
