-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
146 lines (123 loc) · 4.63 KB
/
model.py
File metadata and controls
146 lines (123 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import openai
from openai import OpenAI
import pdb
from pprint import pprint
import time
import json
from util import get_from_cache, save_to_cache
LOG_MAX_LENGTH = 300
# one example for post_process_fn
def identity(res):
    """Default post-processing hook: hand the raw response back unchanged."""
    return res
def get_model(args):
    """Instantiate the LLM backend selected by ``args``.

    Args:
        args: namespace with ``model``, ``temperature``, ``api_key`` and,
            optionally, ``openai_proxy`` (a custom API base URL) attributes.

    Returns:
        A ``Model`` subclass instance (currently only ``GPT``).

    Raises:
        KeyError: if ``args.model`` names a backend that is not implemented.
    """
    model_name, temperature = args.model, args.temperature
    # getattr with a default replaces the original hasattr/ternary dance.
    base_url = getattr(args, 'openai_proxy', None)
    print('base_url: ', base_url)
    if 'gpt' in model_name:
        return GPT(args.api_key, model_name, temperature, base_url)
    raise KeyError(f"Model {model_name} not implemented")
class Model(object):
    """Base class for LLM wrappers; carries a pluggable post-processing hook."""

    def __init__(self):
        # Applied to every raw response; pass-through unless overridden.
        self.post_process_fn = identity

    def set_post_process_fn(self, post_process_fn):
        """Replace the hook applied to responses before they are returned."""
        self.post_process_fn = post_process_fn
class GPT(Model):
    """Chat-completion wrapper around the OpenAI client.

    Retries transient API failures, counts non-retryable bad requests, and
    caches responses keyed on (model, messages) via util.get_from_cache /
    util.save_to_cache.
    """

    def __init__(self, api_key, model_name, temperature=1.0, base_url=None):
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature
        # Number of BadRequestError responses seen so far (not retried).
        self.badrequest_count = 0
        # Module-level configuration kept for backward compatibility with any
        # code that still reads openai.api_key / openai.base_url directly.
        openai.api_key = api_key
        if base_url:
            openai.base_url = base_url
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    # used in forward()
    def get_response(self, **kwargs):
        """Call chat.completions.create, retrying transient failures.

        Returns:
            The API response object, or None on BadRequestError (which is
            not retryable).
        """
        # Retry loop replaces the original unbounded recursion, which could
        # hit the recursion limit during a long outage. BUG FIX: the original
        # also had a duplicate, unreachable APIConnectionError handler.
        # NOTE: APITimeoutError subclasses APIConnectionError, so the two are
        # handled together; the printed name matches the concrete type.
        while True:
            try:
                return self.client.chat.completions.create(**kwargs)
            except openai.RateLimitError:
                print('RateLimitError')
                time.sleep(10)
            except (openai.APIConnectionError, openai.APITimeoutError) as e:
                print(type(e).__name__)
                time.sleep(30)
            except openai.BadRequestError:
                print('BadRequestError')
                self.badrequest_count += 1
                print('badrequest_count', self.badrequest_count)
                return None

    def forward(self, head=None, prompt=None, use_cache=True, logger=None, forward_use_logger=True, use_json_format=False):
        """Send one chat turn and return (post-processed response, info).

        Args:
            head: optional system message content.
            prompt: user message content.
            use_cache: consult the response cache before calling the API.
            logger: destination for progress messages; when None, messages
                fall back to print and forward_use_logger is forced off.
            forward_use_logger: route progress messages through ``logger``.
            use_json_format: request a JSON-object response from the API.

        Returns:
            (post_processed_text, info) on success, where info holds the
            token usage plus 'response' and 'message' keys;
            (cached_text, None) on a cache hit;
            (None, info) when the API rejected the request.
        """
        messages = []
        info = {}
        if logger is None:  # `is None` instead of the original `== None`
            forward_use_logger = False
        if head is not None:
            messages.append(
                {"role": "system", "content": head}
            )
        messages.append(
            {"role": "user", "content": prompt}
        )
        # Cache key is deterministic over the model name + full message list.
        key = json.dumps([self.model_name, messages])
        if forward_use_logger:
            logger.info(f"Messages: {str(messages)[:LOG_MAX_LENGTH]}")
        if use_cache:
            cached_value = get_from_cache(key, logger, forward_use_logger)
            if cached_value is not None:
                if forward_use_logger:
                    logger.info("Cache Hit")
                else:
                    print("Cache Hit")
                return cached_value, None
            if forward_use_logger:
                logger.info("Cache Miss")
            else:
                print("Cache Miss")
        # Build one request dict instead of duplicating the call site.
        request = dict(
            model=self.model_name,
            messages=messages,
            temperature=self.temperature,
        )
        if use_json_format:
            request["response_format"] = {"type": "json_object"}
        response = self.get_response(**request)
        if response is None:
            # BadRequestError path: no usable completion was produced.
            info['response'] = None
            info['message'] = None
            return None, info
        messages.append(
            {"role": "assistant", "content": response.choices[0].message.content}
        )
        info = dict(response.usage)  # completion_tokens, prompt_tokens, total_tokens
        info['response'] = messages[-1]["content"]
        info['message'] = messages
        if forward_use_logger:
            logger.info(f"Response: {str(info['response'])[:LOG_MAX_LENGTH]}")
        # NOTE: saved unconditionally — the original `if use_cache:` guard was
        # commented out upstream, so that behavior is preserved deliberately.
        save_to_cache(key, info['response'], logger, forward_use_logger)
        if forward_use_logger:
            logger.info("Cache Saved")
        else:
            print("Cache Saved")  # BUG FIX: previously printed "Cache Miss"
        return self.post_process_fn(info['response']), info