import os
import sys
import random
import numpy as np
import torch
import utils
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model
from utils.quant_utils import (wrap_to_quant_model, init_weight_quantizer, init_input_quantizer,
                               register_online_had, init_k_quantizer, init_v_quantizer)
import utils.model_utils as model_utils
import utils.rotation_utils as rotation_utils
from main import evaluate
from utils.train_utils import load_json_as_namespace, create_logger

torch.backends.cudnn.benchmark = True
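
# Example invocation (a sketch: the checkpoint directory name below is hypothetical
# and should point at a directory produced by this repo's quantization entry point,
# main.py; all flags are the ones defined in main() below):
#
#   python eval.py --quant_model_path ./output/llama-2-7b-w4a4kv4 \
#       --eval_ppl --eval_tasks piqa,arc_easy,arc_challenge --eval_batch_size 16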
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_model_path", type=str, help="path of the quantized model")
    parser.add_argument("--output_dir", default="./log/test", type=str, help="directory for logging files")
    parser.add_argument("--real_quant", default=False, action="store_true",
                        help="use real quantization instead of fake quantization; can reduce the memory footprint")
    parser.add_argument("--ppl_seqlen", type=int, default=2048, help="sequence length for perplexity evaluation")
    parser.add_argument("--seed", type=int, default=2, help="seed for sampling the calibration data")
    parser.add_argument("--eval_ppl", action="store_true", help="evaluate perplexity on wikitext2 and c4 with 2048 context length")
    parser.add_argument("--eval_tasks", type=str, default="", help="example: piqa,arc_easy,arc_challenge,hellaswag,winogrande")
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--max_memory", type=str, default="70GiB", help="the maximum memory of each GPU")
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    args = parser.parse_args()

    # seed all RNGs for reproducible evaluation
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # init logger
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    output_dir = Path(args.output_dir)
    logger = create_logger(output_dir)

    # load the quantization config and, if present, the pre-computed prefixed KV cache
    quant_config = load_json_as_namespace(os.path.join(args.quant_model_path, 'prefixequant_config.json'))
    if quant_config.set_prefixed_tokens:
        prefixed_key_values = torch.load(os.path.join(args.quant_model_path, 'prefixed_key_values.pth'))
    else:
        prefixed_key_values = None
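
    # For reference, `prefixequant_config.json` must expose at least the fields read in
    # this script (set_prefixed_tokens, down_online_had, qk_online_had, wbits, input_bits,
    # k_bits, v_bits). A minimal illustrative example, with made-up bit widths:
    #
    #   {"set_prefixed_tokens": true, "down_online_had": true, "qk_online_had": true,
    #    "wbits": 4, "input_bits": 4, "k_bits": 4, "v_bits": 4}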
    # init quantized model
    config = AutoConfig.from_pretrained(args.quant_model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.quant_model_path, use_fast=False, legacy=False, trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(args.quant_model_path, config=config, device_map='cpu',
                                                     torch_dtype=torch.float16, trust_remote_code=True)
    wrap_to_quant_model(model)
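
    # Note: under `init_empty_weights()` the model skeleton is created without
    # allocating real weight storage; the quantized weights are loaded later via
    # `load_checkpoint_in_model`. `wrap_to_quant_model` presumably swaps the linear
    # layers for quantization-aware wrappers so the quantizers initialized below
    # can attach to them.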
    # register online Hadamard transformation
    if quant_config.down_online_had:
        register_online_had(model)

    # wrap RoPE for online Hadamard and RoPE output capture
    rope_function_name = model_utils.get_rope_function_name(model)
    layers = model_utils.get_layers(model)
    for layer in layers:
        rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
            layer.self_attn,
            rope_function_name,
            config=model.config,
            online_had=quant_config.qk_online_had)
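
    # The wrapper above hooks each layer's RoPE call so queries and keys can be
    # rotated right after the positional embedding is applied (an online Hadamard
    # transform when `qk_online_had` is set, in the style of QuaRot); keys are then
    # quantized in this rotated basis, which suppresses activation outliers.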
    # init weight quantizer
    if quant_config.wbits < 16:
        logger.info('init weight quantizer')
        init_weight_quantizer(quant_config, model, minmax_init=False)

    # init input quantizer
    if quant_config.input_bits < 16:
        logger.info('init input quantizer')
        init_input_quantizer(quant_config, model, minmax_init=False)

    # init kv quantizers
    if quant_config.v_bits < 16:
        logger.info('init v quantizer')
        init_v_quantizer(quant_config, model, minmax_init=False)
    if quant_config.k_bits < 16:
        # init consistently with the wrapped RoPE above
        logger.info('init k quantizer')
        init_k_quantizer(quant_config, model, minmax_init=False)
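
    # NOTE (inference, not stated in the source): minmax_init=False likely because
    # the quantizer parameters (scales/zero points) are restored from the saved
    # checkpoint below rather than re-calibrated here.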
    # dispatch the checkpoint weights (and quantizer state) onto the available GPUs
    device_map = infer_auto_device_map(model)
    print("Loading pre-computed quantized weights...")
    load_checkpoint_in_model(model, checkpoint=args.quant_model_path, device_map=device_map, dtype=torch.float16)
    model.half()  # ensure the same evaluation results as main.py
    evaluate(model, tokenizer, prefixed_key_values, args, logger)

if __name__ == "__main__":
    print(sys.argv)
    main()