Saving a 4-bit Llama model to TorchScript fails #1009

Open
@dcy0577

Description

System Info

bitsandbytes=0.42.0
transformers=4.37.1

Reproduction

import torch
import torch.nn as nn
import transformers
from transformers import BitsAndBytesConfig

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()

        # get the llama model
        llama_model = transformers.LlamaModel 

        bnb_config_4bit = BitsAndBytesConfig(  
                    load_in_4bit=True, 
                    bnb_4bit_use_double_quant=False, # whether to use double quantization
                    bnb_4bit_quant_type="nf4",  
                    bnb_4bit_compute_dtype=torch.float16) # 4 bits qlora

        # load the base model weights with quantization
        self.model = llama_model.from_pretrained(
                        "Llama_weights/llama-2-7b-hf-weights", 
                        low_cpu_mem_usage=True, 
                        device_map= 0, 
                        quantization_config=bnb_config_4bit,
                        torchscript=True,)
        
        self.model.output_hidden_states = False
        
    def forward(self, tokens_tensor):
        self.model.eval()
        o = self.model(tokens_tensor, output_hidden_states=False)
        return o[0]


model = Wrapper()
model.eval()    
with torch.no_grad():
    dummy_tokens_tensor = torch.randint(0, 1000, (1, 50), dtype=torch.long).to("cuda")
    outputs = model(dummy_tokens_tensor)
    trace_model = torch.jit.trace(model, [dummy_tokens_tensor]) # this works, but with some trace warnings
    print("traced_model done")
    torch.jit.save(trace_model, "llama_4bit.pt") # --> error!

The line torch.jit.save(trace_model, "llama_4bit.pt") fails with this error:

RuntimeError: 
Could not export Python function call 'MatMul4Bit'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py(577): matmul_4bit
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/nn/modules.py(256): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(386): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(798): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(1070): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home///test/test_torchscript_llama.py(62): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
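
The trace above points at MatMul4Bit, which is invoked through torch.autograd.Function.apply. As far as I can tell, the same behavior can be reproduced without bitsandbytes or Llama: tracing a module that calls any custom autograd Function succeeds, but saving the traced module fails. The sketch below is a CPU-only illustration under that assumption; Double, Tiny and trace_and_save are hypothetical names made up for this example and are not part of bitsandbytes or transformers.

```python
import io

import torch


class Double(torch.autograd.Function):
    """Hypothetical stand-in for a Python-level autograd Function such as MatMul4Bit."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class Tiny(torch.nn.Module):
    def forward(self, x):
        # Goes through torch.autograd.Function.apply, like matmul_4bit in the trace above.
        return Double.apply(x)


def trace_and_save():
    """Trace Tiny, then try to serialize it to an in-memory buffer.

    Returns (node_kinds, error_message); error_message is None if saving worked.
    """
    traced = torch.jit.trace(Tiny(), torch.ones(2))
    # The custom Function is expected to show up as an opaque Python op in the graph.
    node_kinds = [node.kind() for node in traced.graph.nodes()]
    try:
        torch.jit.save(traced, io.BytesIO())
    except RuntimeError as exc:
        return node_kinds, str(exc)
    return node_kinds, None


if __name__ == "__main__":
    kinds, err = trace_and_save()
    print("graph node kinds:", kinds)
    print("save error:", err)
```

If this small case fails in the same way, the problem is probably that the traced graph still contains a call into Python code that TorchScript cannot serialize, rather than anything specific to the Llama weights or the 4-bit config.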

Expected behavior

torch.jit.save should work properly for the 4-bit quantized model.

Labels

bug: Something isn't working
huggingface-related: A bug that is likely due to the interaction between bnb and HF libs (transformers, accelerate, peft)
medium priority: (will be worked on after all high priority issues)
