-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLLMTransfer.py
More file actions
147 lines (131 loc) · 5.37 KB
/
LLMTransfer.py
File metadata and controls
147 lines (131 loc) · 5.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
import os
import random
import re
from http import HTTPStatus
from time import sleep
from typing import Optional, Sequence, Union

import dashscope
import requests
from openai import OpenAI

from Deployment import Deployer
class LLMTransfer:
    """Route a prompt to one of several LLM back ends, chosen by model name.

    Supported back ends:

    * DashScope (Alibaba Qwen API)      -> ``call_with_prompt_qw``
    * Qianfan (Baidu wenxinworkshop)    -> ``call_with_prompt_qf``
    * Remote OpenAI-compatible endpoint -> ``call_with_prompt_sft``
    * Project-local ``Deployer``        -> ``call_with_prompt_local``

    A prompt is either a plain string or a two-item ``(system, user)``
    sequence; ``call_with_prompt`` is the entry point that dispatches.
    """

    # Model names served by each remote back end; anything else is local.
    _DASHSCOPE_MODELS = ('qwen1.5-7b-chat', 'qwen1.5-14b-chat', 'baichuan2-13b-chat-v1')
    _QIANFAN_MODELS = ('yi-34b-chat', 'qianfan-chinese-llama-2-7b', 'qianfan-chinese-llama-2-13b')
    _SFT_MODELS = ('sft', 'sft-qwen1.5-14b', 'sft-baichuan2-13b')

    # Seconds to wait on any remote HTTP call before giving up.
    _REQUEST_TIMEOUT = 60

    def __init__(self, model_name: str, temperature=10e-4, b_local: bool = False):
        """Store the model selection and optionally build the local deployer.

        Args:
            model_name: Name used both for routing and as the remote model id.
            temperature: Sampling temperature forwarded to every back end.
            b_local: When true, every prompt is served by a local ``Deployer``
                which is created immediately.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.b_local = b_local
        # Sent as the ``model`` field to the fine-tuned (SFT) endpoint.
        self.sft_model_path = "/data/zonepg/models/Qwen/Qwen1.5-7B-Chat-LoRA"
        if model_name == "sft":
            print("微调模型路径:", self.sft_model_path)
        # Created eagerly for local mode, otherwise lazily on first local call
        # so that the fall-through branch of call_with_prompt cannot hit a
        # missing attribute.
        self.deployer = None
        if b_local:
            self.deployer = Deployer(model_name=model_name, temperature=temperature)

    def call_with_prompt_qw(self, prompt: str) -> Optional[str]:
        """Call the Alibaba Qwen model through the DashScope API.

        Args:
            prompt: Full prompt text, sent as a single user message.

        Returns:
            The generated text, or ``None`` when the API reports a non-OK
            status (the error details are printed).
        """
        messages = [{
            'role': 'user',
            'content': prompt
        }]
        response = dashscope.Generation.call(
            model=self.model_name,
            messages=messages,
            # Fixed seed so repeated runs are comparable.
            seed=2024,
            result_format='text',  # the format include 'text' and 'message'
            temperature=self.temperature,
        )
        if response.status_code == HTTPStatus.OK:
            return response.output['text']
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return None

    def _get_qf_access_token(self) -> str:
        """Exchange the Qianfan AK/SK pair for an OAuth access token.

        Returns:
            The access token as a string. When the response carries no
            ``access_token`` field this is the literal string ``"None"``,
            which makes the follow-up chat request fail with a server-side
            error that the caller prints.
        """
        # NOTE(review): credentials committed to source control should be
        # rotated; the environment variables take precedence when set.
        api_key = os.environ.get("QIANFAN_AK", "jvyY9raEufISxdFTJVN1h229")
        secret_key = os.environ.get("QIANFAN_SK", "5bgNjaX1oIETYqOb1d7wNzGsA6ZDyndp")
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
        reply = requests.post(url, params=params, timeout=self._REQUEST_TIMEOUT)
        return str(reply.json().get("access_token"))

    def call_with_prompt_qf(self, prompt: str) -> str:
        """Call a Baidu Qianfan hosted model.

        Args:
            prompt: Full prompt text, sent as a single user message.

        Returns:
            The ``result`` field of the reply, or ``''`` when the request
            fails or the reply cannot be parsed (the reason is printed).
        """
        # The endpoint path uses underscores where the model name has hyphens.
        model_name = self.model_name.replace('-', '_')
        payload = json.dumps({
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": self.temperature,
        })
        headers = {
            'Content-Type': 'application/json'
        }
        try:
            url = (
                "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
                f"{model_name}?access_token=" + self._get_qf_access_token()
            )
            response = requests.post(url, headers=headers, data=payload,
                                     timeout=self._REQUEST_TIMEOUT)
        except (requests.RequestException, ValueError) as exc:
            # Network failure, timeout, or an unparsable token reply.
            print(f"请求失败:{exc}")
            return ''
        # Brief pause between consecutive calls to the hosted API.
        sleep(0.5)
        try:
            return json.loads(response.text)["result"]
        except (ValueError, KeyError, TypeError):
            # Not JSON, or JSON without a "result" field (e.g. an error reply).
            print(f"无效回复:{response.text}")
            return ''

    def call_with_prompt_sft(self, prompt: Union[str, Sequence[str]]) -> str:
        """Call the remote fine-tuned model over an OpenAI-compatible API.

        Args:
            prompt: Either a ``(system, user)`` pair, or a plain string that
                is sent as a single user message.

        Returns:
            The stripped reply text, or ``"模型回复超时!"`` when the request
            raises (the underlying error is printed).
        """
        if isinstance(prompt, str):
            # A bare string must not be indexed: prompt[0] / prompt[1] would
            # be its first two characters.
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = [
                {"role": "system", "content": prompt[0]},
                {"role": "user", "content": prompt[1]},
            ]
        # NOTE(review): endpoint and key are committed to source control;
        # the environment variables take precedence when set.
        client = OpenAI(
            base_url=os.environ.get("SFT_API_BASE", "http://llmsapi.vip.cpolar.cn/v1"),
            api_key=os.environ.get("SFT_API_KEY", "sk-coaihv832rfj0qaj09"),
        )
        try:
            response = client.chat.completions.create(
                model=self.sft_model_path,
                messages=messages,
                stream=False,
                temperature=self.temperature,
                timeout=self._REQUEST_TIMEOUT,
            )
        except Exception as exc:
            # Keep the historical sentinel reply, but surface the real cause.
            print(f"微调模型请求失败:{exc}")
            return "模型回复超时!"
        content = response.choices[0].message.content
        # content may be None (e.g. an empty completion); treat it as "".
        return (content or "").strip()

    def call_with_prompt_local(self, prompt: str) -> str:
        """Answer the prompt with the local ``Deployer``.

        The deployer is created on first use when it was not built in
        ``__init__`` (i.e. ``b_local`` was false but the model name is not
        served by any remote back end).
        """
        if getattr(self, "deployer", None) is None:
            self.deployer = Deployer(model_name=self.model_name, temperature=self.temperature)
        return self.deployer.response(prompt)

    def call_with_prompt(self, prompt: Union[str, Sequence[str]]) -> Optional[str]:
        """Dispatch the prompt to the back end matching ``model_name``.

        Args:
            prompt: Either a ``(system, user)`` pair or a plain string. All
                back ends except the SFT one receive the two parts
                concatenated into one string.

        Returns:
            Whatever the selected back end returns.
        """
        if isinstance(prompt, str):
            whole_prompt = prompt
        else:
            whole_prompt = prompt[0] + prompt[1]

        if self.b_local:
            return self.call_with_prompt_local(whole_prompt)
        if self.model_name in self._DASHSCOPE_MODELS:
            return self.call_with_prompt_qw(whole_prompt)
        if self.model_name in self._QIANFAN_MODELS:
            return self.call_with_prompt_qf(whole_prompt)
        if self.model_name in self._SFT_MODELS:
            # The SFT endpoint takes system and user messages separately.
            return self.call_with_prompt_sft(prompt)
        # baichuan2-7b-chat, chatglm3-6b-32k, chinese-alpaca-2-7b, chinese-alpaca-2-13b
        return self.call_with_prompt_local(whole_prompt)
if __name__ == '__main__':
    # Manual smoke test against the remote fine-tuned (SFT) endpoint.
    llm = LLMTransfer("chinese-alpaca-2-13b")
    # call_with_prompt_sft reads prompt[0] as the system message and prompt[1]
    # as the user message, so pass an explicit (system, user) pair; a bare
    # string would be cut down to its first two characters.
    result = llm.call_with_prompt_sft(('', '如何做西红柿炒鸡蛋?'))
    print(result)