Summary
Repro:
```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config
)

output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

ckpt_id = "llama3-8b-int4wo-128"  # or huggingface hub model id
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto", torch_dtype="auto")
```
First error: we don't add `AutoQuantizableLinearWeight` to torch serialization's safe globals, so the pickled checkpoint can't be loaded back.
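A possible workaround on the loading side (just a sketch; it assumes PyTorch >= 2.4's `torch.serialization.add_safe_globals` and that the class lives at this torchao path) is to allowlist the subclass before calling `from_pretrained`:

```python
# Workaround sketch, not a fix in transformers/torchao: allowlist the tensor
# subclass for the weights-only unpickler before loading the checkpoint.
# Assumes PyTorch >= 2.4 and this torchao class path.
import torch
from torchao.quantization.autoquant import AutoQuantizableLinearWeight

torch.serialization.add_safe_globals([AutoQuantizableLinearWeight])
```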
Second error: transformers on main will try to slice our param: https://github.com/huggingface/transformers/blob/0463901c92e08cefbccf19f409b6cc43c153352d/src/transformers/modeling_utils.py#L907
We don't implement slicing -> and we also just print an error, which is weird to me:
ao/torchao/quantization/autoquant.py, line 313 (at cf45336)
Is this because of how we do quantization? If so, we definitely should not have a catch-all error here and should narrow down the exception type.
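For illustration, here is a rough sketch of what narrowing the exception type could look like in a tensor-subclass dispatch; the class name and op set below are made up for the example and are not the actual autoquant.py code:

```python
import torch

class SketchQuantWeight(torch.Tensor):
    # Hypothetical subclass for illustration only, not torchao's implementation.
    SUPPORTED_OPS = {torch.nn.functional.linear, torch.Tensor.detach, torch.Tensor.to}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in cls.SUPPORTED_OPS:
            # Raise a specific error instead of `except Exception: print(...)`,
            # so callers (e.g. the param slicing in transformers' loader) see
            # exactly which op is unsupported.
            raise NotImplementedError(f"{cls.__name__} does not support {func}")
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
```

With something like this, the failed slice at load time would surface as a clear NotImplementedError rather than a printed message.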