
Commit 81d1be7

feat(platform): Add OpenAI reasoning models (#8152)
1 parent 6da8007 commit 81d1be7

File tree

  • autogpt_platform/backend/backend/blocks

1 file changed, +25 -7 lines

autogpt_platform/backend/backend/blocks/llm.py

Lines changed: 25 additions & 7 deletions
@@ -30,6 +30,8 @@ class ModelMetadata(NamedTuple):
 
 class LlmModel(str, Enum):
     # OpenAI models
+    O1_PREVIEW = "o1-preview"
+    O1_MINI = "o1-mini"
     GPT4O_MINI = "gpt-4o-mini"
     GPT4O = "gpt-4o"
     GPT4_TURBO = "gpt-4-turbo"
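
Because LlmModel subclasses both str and Enum, the two new members compare equal to their raw model-name strings and can be constructed from them. A minimal standalone sketch (the stub class below stands in for the full enum in llm.py):

from enum import Enum

class LlmModel(str, Enum):   # stand-in for the class defined in llm.py
    O1_PREVIEW = "o1-preview"
    O1_MINI = "o1-mini"

assert LlmModel.O1_MINI == "o1-mini"                  # str comparison works
assert LlmModel("o1-preview") is LlmModel.O1_PREVIEW  # round-trips from the raw string
print(LlmModel.O1_MINI.value)                         # "o1-mini", the value sent to the API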
@@ -57,6 +59,8 @@ def metadata(self) -> ModelMetadata:
 
 
 MODEL_METADATA = {
+    LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
+    LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
     LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
     LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
     LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
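
Since ModelMetadata is a NamedTuple and cost_factor is passed by keyword above, the table can be queried per model via attribute access. A rough sketch, assuming the module is importable as backend.blocks.llm (path inferred from the file location, not shown in this diff):

from backend.blocks.llm import MODEL_METADATA, LlmModel  # assumed import path

meta = MODEL_METADATA[LlmModel.O1_PREVIEW]
print(meta.cost_factor)   # 60 per the table above (gpt-4o is 12 by comparison)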
@@ -84,7 +88,10 @@ def metadata(self) -> ModelMetadata:
 class AIStructuredResponseGeneratorBlock(Block):
     class Input(BlockSchema):
         prompt: str
-        expected_format: dict[str, str]
+        expected_format: dict[str, str] = SchemaField(
+            description="Expected format of the response. If provided, the response will be validated against this format. "
+            "The keys should be the expected fields in the response, and the values should be the description of the field.",
+        )
         model: LlmModel = LlmModel.GPT4_TURBO
         api_key: BlockSecret = SecretField(value="")
         sys_prompt: str = ""
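
Per the new SchemaField description, the keys of expected_format name the fields the model must return and the values describe them. A purely hypothetical input value, for illustration only:

expected_format = {
    "answer": "the final answer to the user's question",
    "confidence": "a number between 0 and 1 indicating certainty",
}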
@@ -132,7 +139,18 @@ def llm_call(
 
        if provider == "openai":
            openai.api_key = api_key
-            response_format = {"type": "json_object"} if json_format else None
+            response_format = None
+
+            if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
+                sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
+                usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
+                prompt = [
+                    {"role": "user", "content": "\n".join(sys_messages)},
+                    {"role": "user", "content": "\n".join(usr_messages)},
+                ]
+            elif json_format:
+                response_format = {"type": "json_object"}
+
            response = openai.chat.completions.create(
                model=model.value,
                messages=prompt,  # type: ignore
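
The new branch handles the o1 reasoning models, which at launch did not accept system-role messages or the response_format parameter, so the change folds any system messages into a leading user message and leaves response_format as None. A standalone sketch of the same folding, using a made-up prompt:

prompt = [
    {"role": "system", "content": "Reply strictly only in the following JSON format: ..."},
    {"role": "user", "content": "Summarize the quarterly report in two sentences."},
]

# Same transformation as in the diff above, shown outside the block for illustration.
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
    {"role": "user", "content": "\n".join(sys_messages)},
    {"role": "user", "content": "\n".join(usr_messages)},
]
# prompt now holds two user-role messages and no system role; no response_format
# is sent for o1-preview / o1-mini.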
@@ -207,11 +225,11 @@ def trim_prompt(s: str) -> str:
        format_prompt = ",\n ".join(expected_format)
        sys_prompt = trim_prompt(
            f"""
-              |Reply in json format:
-              |{{
-              |  {format_prompt}
-              |}}
-              """
+              |Reply strictly only in the following JSON format:
+              |{{
+              |  {format_prompt}
+              |}}
+              """
        )
        prompt.append({"role": "system", "content": sys_prompt})
 
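
For a hypothetical format_prompt, the rewritten f-string produces a system message along these lines (shown before trim_prompt strips the leading whitespace and "|" margins; trim_prompt itself is outside this hunk):

format_prompt = '"answer": "the final answer",\n  "confidence": "0 to 1"'  # hypothetical
raw = f"""
          |Reply strictly only in the following JSON format:
          |{{
          |  {format_prompt}
          |}}
          """
print(raw)  # trim_prompt would then remove the margins before raw is used as sys_prompt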
217235
