Open
Description
System Info
bitsandbytes=0.42.0
transformers=4.37.1
Reproduction
import torch
import torch.nn as nn
import transformers
from transformers import BitsAndBytesConfig
class Wrapper(nn.Module):
    """Single-tensor wrapper around a 4-bit quantized ``transformers.LlamaModel``.

    ``forward`` accepts one tensor of token ids and returns only element 0 of
    the wrapped model's output, which gives ``torch.jit.trace`` a simple
    tensor-in / tensor-out signature to work with.
    """

    def __init__(self):
        """Load the Llama checkpoint using a 4-bit NF4 bitsandbytes config."""
        super().__init__()

        # 4-bit (QLoRA-style) quantization settings.
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,  # double quantization is disabled
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load the base model weights with the quantization config applied.
        # NOTE(review): device_map=0 presumably places the whole model on
        # CUDA device 0, and torchscript=True presumably makes the model
        # return plain tuples -- confirm against the transformers docs.
        self.model = transformers.LlamaModel.from_pretrained(
            "Llama_weights/llama-2-7b-hf-weights",
            low_cpu_mem_usage=True,
            device_map=0,
            quantization_config=quant_cfg,
            torchscript=True,
        )
        self.model.output_hidden_states = False

    def forward(self, tokens_tensor):
        """Run the wrapped model in eval mode and return output element 0.

        Args:
            tokens_tensor: Tensor passed positionally to the wrapped model
                (the reproduction feeds it integer token ids).

        Returns:
            Element 0 of the wrapped model's output.
        """
        # Switch to eval mode on every call, exactly as the reproduction does.
        self.model.eval()
        return self.model(tokens_tensor, output_hidden_states=False)[0]
# Build the wrapped, quantized model and put it in inference mode.
model = Wrapper()
model.eval()

with torch.no_grad():
    # Random token ids on the GPU, used as the example input for tracing.
    dummy_tokens_tensor = torch.randint(
        0,
        1000,
        (1, 50),
        dtype=torch.long,
    ).to("cuda")

    # Plain eager forward pass before tracing.
    outputs = model(dummy_tokens_tensor)

    # Tracing succeeds, although it emits some tracer warnings.
    trace_model = torch.jit.trace(model, [dummy_tokens_tensor])

print("traced_model done")

# Serializing the traced module is the step that raises the RuntimeError.
torch.jit.save(trace_model, "llama_4bit.pt")
The line `torch.jit.save(trace_model, "llama_4bit.pt")`
gives me the following error:
RuntimeError:
Could not export Python function call 'MatMul4Bit'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py(577): matmul_4bit
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/nn/modules.py(256): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(386): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(798): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(1070): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home///test/test_torchscript_llama.py(62): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
Expected behavior
`torch.jit.save` should work properly for a quantized model.