-
Notifications
You must be signed in to change notification settings - Fork 264
Expand file tree
/
Copy pathllm_hub.py
More file actions
148 lines (121 loc) · 4.57 KB
/
llm_hub.py
File metadata and controls
148 lines (121 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
import os
import random
import sys
import time
import yaml
from openai import InternalServerError, OpenAI, RateLimitError
# ---- CLI arguments ------------------------------------------------------
# argv: 1 = JSON list of [file_path, file_type] pairs (consumed below as the
#           context files), 2 = question text, 3 = model name, 4 = model type
#           ("text" / "image" / "multimodal"), 5 = temperature (empty string
#           means "provider default"), 6 = provider key into the config's
#           "servers" mapping.
context_files = json.loads(sys.argv[1])
question = sys.argv[2]
model = sys.argv[3]
model_type = sys.argv[4]
temperature_arg = sys.argv[5]
# An empty temperature argument yields None, which later skips the parameter
# entirely so the provider's default is used.
temperature = float(temperature_arg) if temperature_arg else None
provider = sys.argv[6]

# ---- LiteLLM configuration ----------------------------------------------
# The YAML config is located via LITELLM_CONFIG_FILE; it either has a
# "servers" mapping keyed by provider, or (legacy layout) the credentials at
# the top level.
litellm_config_file = os.environ.get("LITELLM_CONFIG_FILE")
if not litellm_config_file:
    sys.exit("LITELLM_CONFIG_FILE environment variable is not set.")
with open(litellm_config_file, "r") as f:
    config = yaml.safe_load(f)
servers = config.get("servers", {})
if servers and provider not in servers:
    sys.exit(f"Provider '{provider}' not found in configuration.")
# Select the source: specific provider config if servers exist, otherwise global config (backward compatibility)
source = servers[provider] if servers else config
litellm_api_key = source.get("LITELLM_API_KEY")
litellm_base_url = source.get("LITELLM_BASE_URL")
if not litellm_api_key:
    sys.exit(
        "LiteLLM API key is not configured! Please set LITELLM_API_KEY environment variable."
    )
if not litellm_base_url:
    sys.exit(
        "LiteLLM base URL is not configured! Please set LITELLM_BASE_URL environment variable."
    )
# OpenAI-compatible client pointed at the LiteLLM proxy.
client = OpenAI(
    api_key=litellm_api_key,
    base_url=litellm_base_url,
)
def read_text_file(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
except UnicodeDecodeError:
try:
with open(file_path, "r", encoding="latin-1") as f:
return f.read()
except Exception:
sys.exit(f"Could not read file {file_path} as text")
def get_image_mime_type(image_path):
    """Best-effort MIME type for an image file, derived from its extension.

    Uses ``mimetypes`` first; when that yields nothing image-like, falls back
    to a small known-extension map, and finally to ``image/jpeg``.
    """
    import mimetypes

    guessed, _encoding = mimetypes.guess_type(image_path)
    if guessed is not None and guessed.startswith("image/"):
        return guessed

    lowered = image_path.lower()
    if lowered.endswith((".png", ".jpg", ".jpeg", ".gif", ".tiff", ".bmp")):
        suffix = lowered.rsplit(".", 1)[-1]
        # Normalize "jpg" to the canonical "jpeg" subtype.
        return "image/jpeg" if suffix == "jpg" else f"image/{suffix}"

    # Unknown extension: default to JPEG.
    return "image/jpeg"
def encode_image_to_base64(image_path):
    """Return *image_path* encoded as a ``data:`` URL with a base64 payload.

    Exits the script with a user-facing message if the file cannot be
    processed.
    """
    import base64

    try:
        with open(image_path, "rb") as image_file:
            payload = base64.b64encode(image_file.read()).decode("utf-8")
        mime = get_image_mime_type(image_path)
        return f"data:{mime};base64,{payload}"
    except Exception:
        sys.exit(f"Could not process image file: {image_path}")
# Allowed context-file types for each model_type.
valid_model_types = {
    "text": {"text"},
    "image": {"image"},
    "multimodal": {"text", "image"},
}
if model_type not in valid_model_types:
    sys.exit(
        f"Invalid model_type '{model_type}'. Must be one of: {', '.join(valid_model_types)}"
    )

# Build the user message content: one part per context file (image parts as
# base64 data URLs, text parts inlined with their path), plus the question
# for text-capable models.
contents = []
for file_path, file_type in context_files:
    if file_type not in valid_model_types[model_type]:
        sys.exit(f"File type '{file_type}' not allowed for model_type '{model_type}'.")
    if file_type == "image":
        contents.append(
            {
                "type": "image_url",
                "image_url": {"url": encode_image_to_base64(file_path)},
            }
        )
    else:
        contents.append(
            {
                "type": "text",
                "text": f"File: {file_path}\nContent:\n{read_text_file(file_path)}",
            }
        )
if question and "text" in valid_model_types[model_type]:
    contents.append({"type": "text", "text": question})
if not contents:
    sys.exit("No input content provided.")
# Single user turn carrying all parts, in OpenAI chat-completions format.
messages = [{"role": "user", "content": contents}]
# ---- API call with retry ------------------------------------------------
# Retries transient failures (server errors, rate limits) with exponential
# backoff + jitter, honoring a numeric Retry-After header when present.
# Success writes the completion text to output.md and stops.
max_retries = config.get("MAX_RETRIES", 3)
max_delay = config.get("MAX_DELAY", 900)
for attempt in range(max_retries):
    try:
        api_params = {"model": model, "messages": messages}
        # Omit temperature entirely when unset so the provider default applies.
        if temperature is not None:
            api_params["temperature"] = temperature
        response = client.chat.completions.create(**api_params)
        with open("output.md", "w") as f:
            f.write(response.choices[0].message.content or "")
        break
    except (InternalServerError, RateLimitError) as e:
        if attempt == max_retries - 1:
            sys.exit("Max retries reached. Exiting.")
        # Exponential backoff with jitter, capped at max_delay.
        sleep_time = min(2**attempt + random.uniform(0, 1), max_delay)
        if isinstance(e, RateLimitError) and hasattr(e, "response") and e.response is not None:
            retry_after = e.response.headers.get("retry-after")
            if retry_after:
                try:
                    sleep_time = min(float(retry_after), max_delay)
                except ValueError:
                    # Bug fix: Retry-After may be an HTTP-date rather than a
                    # number of seconds (RFC 9110 §10.2.3); previously
                    # float() raised and crashed the script. Keep the
                    # backoff value instead.
                    pass
        print(f"Error encountered ({e}). Retrying in {sleep_time:.2f} seconds...")
        time.sleep(sleep_time)