Commit a722b71

Adds o-series model support (#299)
1 parent 29722fb commit a722b71

File tree

3 files changed: +51 -16 lines changed

controller/attribute/llm_response_tmpl.py

Lines changed: 30 additions & 10 deletions
@@ -61,6 +61,13 @@ class LLMProvider_A2VYBG(Enum):
     "presence_penalty": float("@@PRESENCE_PENALTY@@"),
 }
 
+IS_O_SERIES_A2VYBG = bool("@@IS_O_SERIES@@")
+
+if IS_O_SERIES_A2VYBG:
+    del LLM_KWARGS_A2VYBG["temperature"]
+    LLM_KWARGS_A2VYBG["max_completion_tokens"] = LLM_KWARGS_A2VYBG.pop("max_tokens")
+
+
 SYSTEM_PROMPT_A2VYBG = (
     """@@SYSTEM_PROMPT@@ You must only output valid JSON. """
     "If there is not yet a schema defined for the JSON output, "
@@ -291,16 +298,29 @@ async def get_llm_response(record: dict, cached_records: dict):
     if curr_running_id in cached_records:
         return cached_records[curr_running_id]
 
-    messages = [
-        {
-            "role": "system",
-            "content": SYSTEM_PROMPT_A2VYBG,
-        },
-        {
-            "role": "user",
-            "content": USER_PROMPT_A2VYBG,
-        },
-    ]
+    if IS_O_SERIES_A2VYBG:
+        # doesn't have a system prompt
+        messages = [
+            {
+                "role": "user",
+                "content": f"""Instructions:
+{SYSTEM_PROMPT_A2VYBG}
+Further information:
+{USER_PROMPT_A2VYBG}
+""",
+            },
+        ]
+    else:
+        messages = [
+            {
+                "role": "system",
+                "content": SYSTEM_PROMPT_A2VYBG,
+            },
+            {
+                "role": "user",
+                "content": USER_PROMPT_A2VYBG,
+            },
+        ]
     exception = None
     for _ in range(int(MAX_RETRIES_A2VYBG)):
         try:
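Taken together, the two hunks make the generated attribute code o-series safe: once the @@...@@ placeholders are substituted, an o-series run drops "temperature", renames "max_tokens" to "max_completion_tokens", and folds the system prompt into the user turn. Below is a minimal, self-contained sketch of the resolved template; the prompt strings and numeric values are made up for illustration, only the keys and the mutation logic come from the commit:

LLM_KWARGS_A2VYBG = {
    "temperature": 0.0,       # illustrative value
    "max_tokens": 256,        # illustrative value
    "presence_penalty": 0.0,  # illustrative value
}

# After substitution the template holds a bare boolean here, because the
# replaced token includes the surrounding quotes (see util.py below).
IS_O_SERIES_A2VYBG = bool(True)

if IS_O_SERIES_A2VYBG:
    # o-series models reject "temperature" and use "max_completion_tokens"
    del LLM_KWARGS_A2VYBG["temperature"]
    LLM_KWARGS_A2VYBG["max_completion_tokens"] = LLM_KWARGS_A2VYBG.pop("max_tokens")

SYSTEM_PROMPT_A2VYBG = "You must only output valid JSON."  # stand-in prompt
USER_PROMPT_A2VYBG = "Classify the following record: ..."  # stand-in prompt

if IS_O_SERIES_A2VYBG:
    # o-series models have no system role, so both prompts share one user turn
    messages = [
        {
            "role": "user",
            "content": f"""Instructions:
{SYSTEM_PROMPT_A2VYBG}
Further information:
{USER_PROMPT_A2VYBG}
""",
        },
    ]
else:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_A2VYBG},
        {"role": "user", "content": USER_PROMPT_A2VYBG},
    ]

print(LLM_KWARGS_A2VYBG)    # {'presence_penalty': 0.0, 'max_completion_tokens': 256}
print(messages[0]["role"])  # user

The same instructions still reach the model; they just arrive in a single user message instead of a system/user pair.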

controller/attribute/util.py

Lines changed: 20 additions & 5 deletions
@@ -80,19 +80,22 @@ def prepare_sample_records_doc_bin(
     return prefixed_doc_bin
 
 
-def test_openai_llm_connection(api_key: str, model: str):
+def test_openai_llm_connection(api_key: str, model: str, is_o_series: bool = False):
     # more here: https://platform.openai.com/docs/api-reference/making-requests
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {api_key}",
     }
-
+    if is_o_series:
+        add_payload = {"max_completion_tokens": 5}
+    else:
+        add_payload = {"max_tokens": 5}
     payload = {
         "model": model,
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 5,
+        **add_payload,
     }
 
     response = requests.post(
@@ -124,7 +127,11 @@ def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str):
 
 
 def test_azure_llm_connection(
-    api_key: str, base_endpoint: str, api_version: str, model: str
+    api_key: str,
+    base_endpoint: str,
+    api_version: str,
+    model: str,
+    is_o_series: bool = False,
 ):
     # more here: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview
     base_endpoint = base_endpoint.rstrip("/")
@@ -146,11 +153,15 @@ def test_azure_llm_connection(
         "api-key": api_key,
     }
 
+    if is_o_series:
+        add_payload = {"max_completion_tokens": 5}
+    else:
+        add_payload = {"max_tokens": 5}
     payload = {
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 5,
+        **add_payload,
     }
 
     response = requests.post(final_endpoint, headers=headers, json=payload)
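Both connection tests keep their defaults, so existing callers are unaffected; only o-series configs need the new flag. A hedged usage sketch before the validation hunks below (the API key and model name are placeholders, not values from the commit):

from controller.attribute.util import test_openai_llm_connection

# Probe the API with a 5-token completion; for o-series models the payload
# carries "max_completion_tokens" instead of "max_tokens".
test_openai_llm_connection(
    api_key="sk-...",  # placeholder key
    model="o3-mini",   # placeholder o-series model name
    is_o_series=True,
)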
@@ -190,13 +201,15 @@ def validate_llm_config(llm_config: Dict[str, Any]):
         test_openai_llm_connection(
             api_key=llm_config["apiKey"],
             model=llm_config["model"],
+            is_o_series=llm_config.get("openAioSeries", False),
         )
     elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE.value:
         test_azure_llm_connection(
             api_key=llm_config["apiKey"],
             model=llm_config["model"],
             base_endpoint=llm_config["apiBase"],
             api_version=llm_config["apiVersion"],
+            is_o_series=llm_config.get("openAioSeries", False),
         )
     elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE_FOUNDRY.value:
         test_azure_foundry_llm_connection(
@@ -291,6 +304,8 @@ async def ac(record):
             "@@CACHE_FILE_UPLOAD_LINK@@": llm_config.get(
                 "llmAcCacheFileUploadLink", ""
             ),
+            # string quotes are replaced since bool("False") == True
+            '"@@IS_O_SERIES@@"': str(llm_config.get("openAioSeries", False)),
         }
     except KeyError:
         raise LlmResponseError(
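The comment in the final hunk is the crux of the template wiring: bool() on any non-empty string is True, so a flag substituted inside quotes could never come out False. Because the replacement key '"@@IS_O_SERIES@@"' includes the surrounding quotes, the template line bool("@@IS_O_SERIES@@") resolves to a bare bool(True) or bool(False). A quick demonstration:

# Why the quotes must be part of the replaced token:
print(bool("False"))  # True: any non-empty string is truthy
print(bool(""))       # False: only the empty string is falsy
print(bool(False))    # False: the bare boolean behaves as intended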
