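"""quantize.py: GPTQ quantization of a local Llama-3.2-3B checkpoint.

Loads the model, calibrates and quantizes it with the repository's GPTQ
implementation, and saves the compressed result to ./Llama-3.2-3B-quant.
run_inference() can then reload the quantized model for generation.
"""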
import sys

# Make the local GPTQ package and the project root importable.
sys.path.append('/Users/vaishnavip/Projects/Quantization-inference/GPTQ')
sys.path.append('/Users/vaishnavip/Projects/Quantization-inference')

import os
import time
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from helpers import add_quantization_config, compress_and_save, extract_config
from GPTQ.gptq import GPTQ
from GPTQ.QuantizationConfig import QuantizationConfig
from load_datasets import load_data
model_dir = "./Llama-3.2-3B"
config_file_path = "./Llama-3.2-3B/config.json"
output_dir = "./Llama-3.2-3B-quant"
model_path = os.path.abspath(model_dir)
# Module-level logger; basicConfig installs a handler so INFO messages appear.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def GPTQ_quantize():
    # Load the full-precision model.
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Build the calibration dataloader via the local load_datasets helper.
    dataloader = load_data(model_dir, 128, 2048, 128)

    # Sanity-check one calibration batch.
    batch = next(iter(dataloader))
    logger.info('Calibration batch keys: %s', list(batch.keys()))
    # GPTQ settings: quantize the weights of every Linear layer to 4-bit
    # integers, pack the result, and leave lm_head unquantized.
    quant_config = {
        "quantization_config": {
            "group": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 4,
                    "type": "int",
                },
            },
            "format": "pack",
            "scheme": "gptq",
            "ignore": ["lm_head"],
            "status": "unquantized",
        }
    }
    start = time.time()
    quantization_config = QuantizationConfig(quant_config)

    # Attach the quantization config to the model.
    add_quantization_config(model, quantization_config)

    # Depending on the type of quantization, instantiate the required quantizer.
    gptq_quantizer = GPTQ()

    # initialize() runs the calibration forward passes and quantizes the layers.
    gptq_quantizer.initialize(model, dataloader)

    # finalize() releases buffers and removes extra parameters.
    gptq_quantizer.finalize()

    # Compress the weights and save the quantized checkpoint.
    compress_and_save(model, config_file_path, quant_config, output_dir)

    logger.info('Quantization took %.2f s', time.time() - start)
def run_inference(config_file_path, output_dir):
    # Load the quantized model (a full model, not just its AutoConfig, so
    # generate() is available) and re-attach its quantization config.
    model = AutoModelForCausalLM.from_pretrained(output_dir)
    quantization_config_raw = extract_config(config_file_path)
    quantization_config = QuantizationConfig(quantization_config_raw)
    add_quantization_config(model, quantization_config)

    # generate() needs token inputs; use a minimal smoke-test prompt.
    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=32)
    logger.info(tokenizer.decode(outputs[0], skip_special_tokens=True))
if __name__ == "__main__":
    GPTQ_quantize()
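    # Assumed follow-up step, not wired up in the original script:
    # run_inference(config_file_path, output_dir)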