287 changes: 274 additions & 13 deletions camel/models/openai_model.py
@@ -13,7 +13,7 @@
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
import os
import warnings
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Type,
    Union,
)

from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai.lib.streaming.chat import (
@@ -26,6 +26,11 @@
from camel.logger import get_logger
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.models.openai_responses_adapter import (
aiter_response_events_to_chat_chunks,
iter_response_events_to_chat_chunks,
response_to_chat_completion,
)
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
@@ -35,6 +40,7 @@
BaseTokenCounter,
OpenAITokenCounter,
api_keys_required,
get_current_agent_session_id,
is_langfuse_available,
)

@@ -99,6 +105,8 @@ class OpenAIModel(BaseModelBackend):
client instance. If provided, this client will be used instead of
creating a new one. The client should implement the AsyncOpenAI
client interface. (default: :obj:`None`)
        api_mode (str, optional): OpenAI API mode to use. Supported values
            are `"chat_completions"` and `"responses"`.
            (default: :obj:`"chat_completions"`)
**kwargs (Any): Additional arguments to pass to the
OpenAI client initialization. These can include parameters like
'organization', 'default_headers', 'http_client', etc.
@@ -121,13 +129,22 @@ def __init__(
max_retries: int = 3,
client: Optional[Any] = None,
async_client: Optional[Any] = None,
api_mode: str = "chat_completions",
**kwargs: Any,
) -> None:
if model_config_dict is None:
model_config_dict = ChatGPTConfig().as_dict()
api_key = api_key or os.environ.get("OPENAI_API_KEY")
url = url or os.environ.get("OPENAI_API_BASE_URL")
timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
if api_mode not in {"chat_completions", "responses"}:
raise ValueError(
"api_mode must be 'chat_completions' or 'responses', "
f"got: {api_mode}"
)
self._api_mode = api_mode
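        # Per-session chaining state for the Responses API: the id of the
        # last stored response and the message count seen when it was
        # issued, so follow-up turns can send only the delta messages.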
self._responses_previous_response_id_by_session: Dict[str, str] = {}
self._responses_last_message_count_by_session: Dict[str, int] = {}

# Store additional client args for later use
self._max_retries = max_retries
@@ -309,15 +326,27 @@ def _run(

if response_format:
if is_streaming:
                if self._api_mode == "responses":
                    return self._request_responses_stream(
                        messages, response_format, tools
                    )
                return self._request_stream_parse(
                    messages, response_format, tools
                )
else:
if self._api_mode == "responses":
return self._request_responses(
messages, response_format, tools
)
return self._request_parse(messages, response_format, tools)
else:
if self._api_mode == "responses":
if is_streaming:
result = self._request_responses_stream(
messages, None, tools
)
else:
result = self._request_responses(messages, None, tools)
else:
result = self._request_chat_completion(messages, tools)

return result

@@ -362,20 +391,252 @@ async def _arun(

if response_format:
if is_streaming:
if self._api_mode == "responses":
return await self._arequest_responses_stream(
messages, response_format, tools
)
return await self._arequest_stream_parse(
messages, response_format, tools
)
else:
if self._api_mode == "responses":
return await self._arequest_responses(
messages, response_format, tools
)
return await self._arequest_parse(messages, response_format, tools)
else:
if self._api_mode == "responses":
if is_streaming:
result = await self._arequest_responses_stream(
messages, None, tools
)
else:
result = await self._arequest_responses(
messages, None, tools
)
else:
result = await self._arequest_chat_completion(messages, tools)

return result

def _prepare_responses_request_config(
self,
tools: Optional[List[Dict[str, Any]]] = None,
response_format: Optional[Type[BaseModel]] = None,
stream: bool = False,
) -> Dict[str, Any]:
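        """Translate a chat-completions style request config into its
        Responses API equivalent (`max_output_tokens`, `text.format`,
        stored responses for `previous_response_id` chaining)."""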
request_config = self._prepare_request_config(tools)
request_config = self._sanitize_config(request_config)

# Translate chat-completions style parameters to responses style.
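        # e.g. a configured `max_tokens=256` is forwarded as
        # `max_output_tokens=256`.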
max_tokens = request_config.pop("max_tokens", None)
if max_tokens is not None and "max_output_tokens" not in request_config:
request_config["max_output_tokens"] = max_tokens

        # The Responses API does not support `n`; warn and drop it so
        # chat-completions style configs keep working.
if request_config.get("n") not in (None, 1):
warnings.warn(
"OpenAI Responses API does not support `n`; "
"ignoring configured value.",
UserWarning,
)
request_config.pop("n", None)

request_config.pop("response_format", None)
request_config.pop("stream_options", None)
request_config["stream"] = stream
# previous_response_id chaining requires stored responses.
if request_config.get("store") is False:
warnings.warn(
"Overriding `store=False` to `store=True` because "
"`previous_response_id` chaining is enabled.",
UserWarning,
)
request_config["store"] = True

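        # Structured output: the Responses API expresses a JSON-schema
        # constraint via `text.format` instead of `response_format`.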
if response_format is not None:
request_config["text"] = {
"format": {
"type": "json_schema",
"name": response_format.__name__,
"schema": response_format.model_json_schema(),
}
}

return request_config

def _get_response_chain_session_key(self) -> str:
return get_current_agent_session_id() or "__default__"

def _prepare_responses_input_and_chain(
self,
messages: List[OpenAIMessage],
) -> Dict[str, Any]:
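        """Compute the Responses API input for this turn.

        Returns the session key, any chained `previous_response_id`, the
        messages to send (only the delta when a chain is active), and the
        total message count to record after the call succeeds.
        """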
session_key = self._get_response_chain_session_key()
previous_response_id = self._responses_previous_response_id_by_session.get(
session_key
)
last_message_count = self._responses_last_message_count_by_session.get(
session_key, 0
)

# If memory was reset/truncated, reset chain and send full context.
if len(messages) < last_message_count:
previous_response_id = None
last_message_count = 0
self._responses_previous_response_id_by_session.pop(
session_key, None
)

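        # With an active chain, the API reconstructs earlier context from
        # `previous_response_id`, so only newly appended messages are sent.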
if previous_response_id and last_message_count > 0:
delta_messages = messages[last_message_count:]
input_messages = delta_messages if delta_messages else [messages[-1]]
else:
input_messages = messages

return {
"session_key": session_key,
"previous_response_id": previous_response_id,
"input_messages": input_messages,
"message_count": len(messages),
}

def _save_response_chain_state(
self,
session_key: str,
response_id: Optional[str],
message_count: int,
) -> None:
if response_id:
self._responses_previous_response_id_by_session[session_key] = (
response_id
)
self._responses_last_message_count_by_session[session_key] = (
message_count
)

def _request_responses(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
chain_state = self._prepare_responses_input_and_chain(messages)
request_config = self._prepare_responses_request_config(
tools=tools, response_format=response_format, stream=False
)
if chain_state["previous_response_id"]:
request_config["previous_response_id"] = chain_state[
"previous_response_id"
]
response = self._client.responses.create(
input=chain_state["input_messages"],
model=self.model_type,
**request_config,
)
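        # Read the response id defensively: the client may return a typed
        # Response object or a plain dict (e.g. from a custom client).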
self._save_response_chain_state(
session_key=chain_state["session_key"],
response_id=getattr(response, "id", None)
if not isinstance(response, dict)
else response.get("id"),
message_count=chain_state["message_count"],
)
return response_to_chat_completion(
response=response,
model=str(self.model_type),
response_format=response_format,
)

async def _arequest_responses(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
chain_state = self._prepare_responses_input_and_chain(messages)
request_config = self._prepare_responses_request_config(
tools=tools, response_format=response_format, stream=False
)
if chain_state["previous_response_id"]:
request_config["previous_response_id"] = chain_state[
"previous_response_id"
]
response = await self._async_client.responses.create(
input=chain_state["input_messages"],
model=self.model_type,
**request_config,
)
self._save_response_chain_state(
session_key=chain_state["session_key"],
response_id=getattr(response, "id", None)
if not isinstance(response, dict)
else response.get("id"),
message_count=chain_state["message_count"],
)
return response_to_chat_completion(
response=response,
model=str(self.model_type),
response_format=response_format,
)

def _request_responses_stream(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]] = None,
) -> Generator[ChatCompletionChunk, None, None]:
chain_state = self._prepare_responses_input_and_chain(messages)
request_config = self._prepare_responses_request_config(
tools=tools, response_format=response_format, stream=True
)
if chain_state["previous_response_id"]:
request_config["previous_response_id"] = chain_state[
"previous_response_id"
]
event_stream = self._client.responses.create(
input=chain_state["input_messages"],
model=self.model_type,
**request_config,
)
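        # Chain state is persisted only after the stream reports completion,
        # via the adapter's completion callback.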
return iter_response_events_to_chat_chunks(
event_stream=event_stream,
model=str(self.model_type),
on_response_completed=lambda response_id: self._save_response_chain_state(
session_key=chain_state["session_key"],
response_id=response_id,
message_count=chain_state["message_count"],
),
)

async def _arequest_responses_stream(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]] = None,
) -> AsyncGenerator[ChatCompletionChunk, None]:
chain_state = self._prepare_responses_input_and_chain(messages)
request_config = self._prepare_responses_request_config(
tools=tools, response_format=response_format, stream=True
)
if chain_state["previous_response_id"]:
request_config["previous_response_id"] = chain_state[
"previous_response_id"
]
event_stream = await self._async_client.responses.create(
input=chain_state["input_messages"],
model=self.model_type,
**request_config,
)
return aiter_response_events_to_chat_chunks(
event_stream=event_stream,
model=str(self.model_type),
on_response_completed=lambda response_id: self._save_response_chain_state(
session_key=chain_state["session_key"],
response_id=response_id,
message_count=chain_state["message_count"],
),
)

def _request_chat_completion(
self,
messages: List[OpenAIMessage],
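A minimal usage sketch of the new mode (not part of the diff; the direct
`OpenAIModel` construction and `ModelType.GPT_4O` are illustrative
assumptions):

from camel.models.openai_model import OpenAIModel
from camel.types import ModelType

# Route requests through the OpenAI Responses API; consecutive turns in
# the same agent session are chained via `previous_response_id`.
model = OpenAIModel(
    model_type=ModelType.GPT_4O,
    api_mode="responses",
)
response = model.run([{"role": "user", "content": "Hello!"}])
print(response.choices[0].message.content)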