
When using a Hugging Face pretrained model with multiple GPUs, the model parameters are duplicated in RAM for every GPU #17043

Open
@linyubupa

Description


Bug description

When a Hugging Face pretrained model is trained on multiple GPUs (here with `strategy="deepspeed_stage_3"`), the model parameters are duplicated in CPU RAM, once for every GPU.
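For a sense of scale: if every process loads its own full copy, the weight memory on a node grows linearly with the number of GPUs. The sketch below is a back-of-the-envelope estimate, assuming a single node with one process per GPU, roughly 6.05 billion parameters for GPT-J-6B, weights only (no activations, optimizer state, or framework overhead), and idealized ZeRO stage 3 partitioning for the "sharded" case. The helper names are hypothetical.

```python
# Back-of-the-envelope host RAM needed just for model weights on one node.
# Assumptions: single node, one process per GPU, weights only (no activations,
# optimizer state, or framework overhead), idealized ZeRO-3 sharding.

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}
GIB = 2**30

# Approximate parameter count of EleutherAI/gpt-j-6B (assumption for illustration).
GPTJ_PARAMS = 6.05e9


def weights_gib(num_params: float, dtype: str) -> float:
    """Size of one full copy of the weights in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / GIB


def node_weights_gib(num_params: float, dtype: str, num_gpus: int, sharded: bool) -> float:
    """Total weight memory across all processes on the node.

    sharded=False: every process holds the full model (the behaviour in this issue).
    sharded=True: each process holds only 1/num_gpus of the parameters.
    """
    per_process = weights_gib(num_params, dtype)
    if sharded:
        per_process /= num_gpus
    return per_process * num_gpus


if __name__ == "__main__":
    for n in (1, 4, 8):
        full = node_weights_gib(GPTJ_PARAMS, "fp32", n, sharded=False)
        shard = node_weights_gib(GPTJ_PARAMS, "fp32", n, sharded=True)
        print(f"{n} GPU(s): {full:.1f} GiB duplicated vs {shard:.1f} GiB sharded")
```

Under these assumptions one fp32 copy is about 22.5 GiB, so eight unsharded processes need roughly 180 GiB of host RAM for weights alone. In transformers releases of that period, `from_pretrained` loaded weights in fp32 unless `torch_dtype` was given, and `.half()` converts only afterwards, so the fp32 figure is the relevant per-process peak.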

How to reproduce the bug

from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer


class AlpsModule(LightningModule):
    def __init__(
        self,
        model_name_or_path: str = "EleutherAI/gpt-j-6B",
        cache_dir: str = "/mntnlp/yumu/gpt-neo-x/",
        num_labels: int = 2,
        learning_rate: float = 5e-6,
        adam_epsilon: float = 3e-8,
        warmup_steps: int = 30,
        weight_decay: float = 0.01,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        # self.tokenizer is used below, but its creation was not shown in the
        # report; loading it from the same checkpoint is assumed here.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)

        # This runs in every process, so (as reported) each rank ends up
        # holding a full copy of the weights in CPU RAM.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            pad_token_id=self.tokenizer.pad_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            cache_dir=cache_dir,
            # low_cpu_mem_usage=True,
        ).half()


# `args` comes from the reporter's command-line parsing (not shown).
trainer = Trainer(
    max_epochs=1,
    devices=args.num_devices,
    precision=16,
    strategy="deepspeed_stage_3",
    accelerator="gpu",
    num_nodes=args.num_nodes,
)
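To confirm the duplication, each process can log its own peak resident memory right after the model is built; if every rank reports roughly a full fp32 copy, the weights are being materialized per process. A minimal sketch using only the standard library (the Unix-only `resource` module); `peak_rss_gib` and `report` are hypothetical helper names, and `LOCAL_RANK` is assumed to be set by the launcher, defaulting to "0" when it is absent.

```python
import os
import resource
import sys


def peak_rss_gib() -> float:
    """Peak resident set size of the current process in GiB.

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / 2**30
    return peak * 1024 / 2**30


def report(tag: str) -> str:
    """Print and return a one-line memory report tagged with the local rank."""
    # LOCAL_RANK is assumed to be set by the launcher; fall back to "0".
    rank = os.environ.get("LOCAL_RANK", "0")
    msg = f"[local_rank={rank}] {tag}: peak RSS {peak_rss_gib():.2f} GiB"
    print(msg, flush=True)
    return msg


if __name__ == "__main__":
    report("startup")
```

For example, calling `report("after from_pretrained")` at the end of `AlpsModule.__init__` prints one line per process, which makes per-rank copies easy to spot.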

Error messages and logs

None provided.

Environment

Not provided; all environment fields (Lightning, PyTorch, and Python versions, OS, CUDA/cuDNN version, GPU configuration, installation method) were left blank.

More info

No response

cc @awaelchli

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
