-
Notifications
You must be signed in to change notification settings - Fork 592
Expand file tree
/
Copy pathopenai_compatible.py
More file actions
136 lines (110 loc) · 5.04 KB
/
openai_compatible.py
File metadata and controls
136 lines (110 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import base64
import json
import os
import time
from io import BytesIO
from typing import List, Tuple, Union
import numpy as np
import requests as url_requests
from accelerate import Accelerator, DistributedType
from tqdm import tqdm
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
try:
from decord import VideoReader, cpu
except ImportError:
pass
from dotenv import find_dotenv, load_dotenv
from loguru import logger as eval_logger
from openai import AzureOpenAI, OpenAI
from PIL import Image
from lmms_eval.models.model_utils.gen_metrics import log_metrics
from lmms_eval.models.simple.openai_compatible import (
OpenAICompatible as OpenAICompatibleSimple,
)
from lmms_eval.protocol import ChatMessages
load_dotenv(verbose=True)
@register_model("openai_compatible_chat")
class OpenAICompatible(OpenAICompatibleSimple):
    """Chat-message variant of the OpenAI-compatible model wrapper.

    Registered as ``openai_compatible_chat``. Each request supplies a
    ``doc_to_messages`` callable whose output is wrapped in ``ChatMessages``
    and converted to OpenAI-style messages before being sent through
    ``self.client.chat.completions.create``.

    NOTE(review): ``client``, ``model_version``, ``max_retries``, ``timeout``,
    ``rank``, ``task_dict`` and the response-cache attributes are read from
    ``self`` but not defined in this class; presumably the parent class sets
    them up -- confirm against ``OpenAICompatibleSimple``.
    """

    is_simple = False

    def generate_until(self, requests) -> List[str]:
        """Generate one text response per request.

        Args:
            requests: Sequence of request objects whose ``args`` attribute
                unpacks to ``(ctx, doc_to_messages, gen_kwargs, doc_id, task,
                split)``. ``gen_kwargs`` is read but never modified.

        Returns:
            A list of response strings in request order. A request whose
            attempts all fail contributes an empty string.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        e2e_latency = 0
        total_tokens = 0

        for request in requests:
            _ctx, doc_to_messages, gen_kwargs, doc_id, task, split = request.args
            doc_uuid = f"{task}___{split}___{doc_id}"

            # In resume mode, reuse a previously stored non-empty response.
            if self.continual_mode is True and self.cache_mode == "resume":
                if doc_uuid in self.response_cache:
                    cached_text = self.response_cache[doc_uuid]
                    if cached_text:
                        res.append(cached_text)
                        pbar.update(1)
                        continue

            chat_messages = doc_to_messages(self.task_dict[task][split][doc_id])
            chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages})

            # Resolve generation settings into locals so the caller's
            # gen_kwargs dict is not mutated. The token budget is capped at 4096.
            max_new_tokens = min(gen_kwargs.get("max_new_tokens", 1024), 4096)
            temperature = gen_kwargs.get("temperature", 0)

            payload = {
                "messages": chat_messages.to_openai_messages(),
                "model": self.model_version,
            }
            if "o1" in self.model_version or "o3" in self.model_version:
                # For these models the request carries max_completion_tokens
                # and a reasoning effort instead of max_tokens/temperature.
                payload["reasoning_effort"] = "medium"
                payload["response_format"] = {"type": "text"}
                payload["max_completion_tokens"] = max_new_tokens
            else:
                payload["max_tokens"] = max_new_tokens
                payload["temperature"] = temperature

            # Default covers max_retries == 0, where the loop body never runs.
            response_text = ""
            for attempt in range(self.max_retries):
                try:
                    start_time = time.time()
                    response = self.client.chat.completions.create(**payload)
                    end_time = time.time()
                    content = response.choices[0].message.content
                except Exception as e:
                    error_msg = str(e)
                    eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}")

                    # On last attempt, log error and set empty response
                    if attempt == self.max_retries - 1:
                        eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}")
                        response_text = ""
                    else:
                        time.sleep(self.timeout)
                    continue

                # The request succeeded. Bookkeeping happens outside the try
                # block so that a missing usage field cannot trigger a retry.
                response_text = content if content is not None else ""
                e2e_latency += end_time - start_time

                usage = getattr(response, "usage", None)
                completion_tokens = getattr(usage, "completion_tokens", None)
                if completion_tokens is not None:
                    total_tokens += completion_tokens
                else:
                    # Approximate token count if not provided
                    total_tokens += len(response_text.split())
                break

            res.append(response_text)
            pbar.update(1)

            if self.continual_mode is True:  # Cache the response
                self.response_cache[doc_uuid] = response_text
                with open(self.response_persistent_file, "w") as f:
                    json.dump(self.response_cache, f)

        # Average generation speed in tokens per second of request latency.
        avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0
        metric_dict = {
            "total_tokens": total_tokens,
            "e2e_latency": e2e_latency,
            "avg_speed": avg_speed,
        }
        log_metrics(**metric_dict)

        pbar.close()
        return res