-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathclass_finetune_inference.py
More file actions
79 lines (63 loc) · 2.51 KB
/
class_finetune_inference.py
File metadata and controls
79 lines (63 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import sys
import json
from pathlib import Path
from argparse import ArgumentParser
import tiktoken
import torch
from model.language_model import LanguageModel
from utils.model_inference import classify_review
from class_finetune_train import adapt_model_for_classification
if __name__ == "__main__":
"""
初始化模型并进行微调
Args:
--config (str): 模型配置参数文件路径
--model_path (str): 微调后保存模型权重文件路径
"""
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="configs/gpt2_config_355M.json")
parser.add_argument("--model_path", type=str, default="review_classifier.pth")
args = parser.parse_args()
config, model_path = vars(args).values()
# 如果模型权重文件不存在,则提示并退出程序
if not Path(model_path).exists():
print(f"模型权重文件 {model_path} 不存在,请先进行模型微调")
sys.exit()
with open(config) as f:
cfg = json.load(f)
# 如果你有一台支持 CUDA 的 GPU 机器,那么大语言模型将自动在 GPU 上训练且不需要修改代码
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 设置随机种子以保证结果可复现
torch.manual_seed(123)
# 初始化分词器
tokenizer = tiktoken.get_encoding("gpt2")
##############################
# 初始化模型并加载 GPT-2 预训练权重
##############################
model = LanguageModel(cfg)
model.to(device)
##############################
# 修改模型用于分类任务
##############################
adapt_model_for_classification(model, cfg)
# 加载之前训练好的模型权重参数,weights_only=True 表示只加载模型参数,不加载优化器等状态信息
model.load_state_dict(torch.load(model_path, weights_only=True))
# 切换为推理模式,将禁用 dropout 等只在训练时使用的功能
model.eval()
print("开始对话(输入'exit'退出)\n")
while True:
input_text = input("用户: ")
if input_text.lower() == '':
print("输入不能为空!")
continue
if input_text.lower() == 'exit':
break
# 使用模型进行分类评论
label = classify_review(
input_text=input_text,
model=model,
tokenizer=tokenizer,
device=device,
max_length=cfg["context_length"]
)
print(f"模型: {label}\n")