Skip to content

4-bit quantized `model.dequantize()` fails on CPU #1311

@npbool

Description

@npbool

System Info

Ubuntu 22.04, Python 3.10.4, Intel CPU
bitsandbytes==0.43.3
transformers==4.43.3

Reproduction

quantization_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="cpu",
    quantization_config=quantization_config,    
)
base_model.dequantize()

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[24], line 1
----> 1 base_model.dequantize()

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/modeling_utils.py:1394, in PreTrainedModel.dequantize(self)
   1391 if hf_quantizer is None:
   1392     raise ValueError("You need to first quantize your model in order to dequantize it")
-> 1394 return hf_quantizer.dequantize(self)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/quantizers/base.py:202, in HfQuantizer.dequantize(self, model)
    197 def dequantize(self, model):
    198     """
    199     Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance.
    200     Note not all quantization schemes support this.
    201     """
--> 202     model = self._dequantize(model)
    204     # Delete quantizer and quantization config
    205     del model.hf_quantizer

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/quantizers/quantizer_bnb_4bit.py:320, in Bnb4BitHfQuantizer._dequantize(self, model)
    317 def _dequantize(self, model):
    318     from ..integrations import dequantize_and_replace
--> 320     model = dequantize_and_replace(
    321         model, self.modules_to_not_convert, quantization_config=self.quantization_config
    322     )
    323     return model

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:458, in dequantize_and_replace(model, modules_to_not_convert, quantization_config)
    453 def dequantize_and_replace(
    454     model,
    455     modules_to_not_convert=None,
    456     quantization_config=None,
    457 ):
--> 458     model, has_been_replaced = _dequantize_and_replace(
    459         model,
    460         modules_to_not_convert=modules_to_not_convert,
    461         quantization_config=quantization_config,
    462     )
    464     if not has_been_replaced:
    465         logger.warning(
    466             "For some reason the model has not been properly dequantized. You might see unexpected behavior."
    467         )

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

    [... skipping similar frames: _dequantize_and_replace at line 441 (1 times)]

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:425, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    422 else:
    423     state = None
--> 425 new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
    427 if bias is not None:
    428     new_module.bias = bias

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:349, in dequantize_bnb_weight(weight, state)
    346     return weight
    348 if cls_name == "Params4bit":
--> 349     output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    350     logger.warning_once(
    351         f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
    352     )
    353     return output_tensor

File ~/projects/ml/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:1333, in dequantize_4bit(A, quant_state, absmax, out, blocksize, quant_type)
   1330     raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
   1332 if quant_state is None:
-> 1333     assert absmax is not None and out is not None
   1335     quant_state = QuantState(
   1336         absmax=absmax,
   1337         shape=out.shape,
   (...)
   1340         quant_type=quant_type,
   1341     )
   1343 else:

AssertionError: 

Expected behavior

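This code should work on CPU just as it does on an NVIDIA GPU. Instead, the traceback shows what happens when `dequantize_bnb_weight` calls `bnb.functional.dequantize_4bit(weight.data, weight.quant_state)`. On CPU, `weight.quant_state` is `None`, so `dequantize_4bit` takes the branch that requires explicit `absmax` and `out` arguments. Neither argument is passed, so `assert absmax is not None and out is not None` fails.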

Metadata

Metadata

Assignees

No one assigned

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions