diff --git a/aisuite/framework/chat_completion_stream_response.py b/aisuite/framework/chat_completion_stream_response.py
new file mode 100644
index 00000000..e31ca1fa
--- /dev/null
+++ b/aisuite/framework/chat_completion_stream_response.py
@@ -0,0 +1,37 @@
+# aisuite/framework/chat_completion_stream_response.py
+
+from typing import Optional
+
+class ChatCompletionStreamResponseDelta:
+    """
+    Mimics the 'delta' object returned by OpenAI streaming chunks.
+    Example usage in code:
+        chunk.choices[0].delta.content
+    """
+    def __init__(self, role: Optional[str] = None, content: Optional[str] = None):
+        self.role = role
+        self.content = content
+
+
+class ChatCompletionStreamResponseChoice:
+    """
+    Holds the 'delta' for a single chunk choice.
+    Example usage in code:
+        chunk.choices[0].delta
+    """
+    def __init__(self, delta: ChatCompletionStreamResponseDelta):
+        self.delta = delta
+
+
+class ChatCompletionStreamResponse:
+    """
+    Container for streaming response chunks.
+    Each chunk has a 'choices' list, each with a 'delta'.
+    Example usage in code:
+        chunk.choices[0].delta.content
+    """
+    def __init__(self, choices: list[ChatCompletionStreamResponseChoice]):
+        self.choices = choices
+
+
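For orientation, here is a minimal sketch (not part of the diff) of how these three classes nest. It uses only the constructors defined above; the `"Hel"` token is an illustrative value:

```python
from aisuite.framework.chat_completion_stream_response import (
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    ChatCompletionStreamResponseDelta,
)

# Build one chunk the way the providers below do...
chunk = ChatCompletionStreamResponse(
    choices=[
        ChatCompletionStreamResponseChoice(
            delta=ChatCompletionStreamResponseDelta(role="assistant", content="Hel")
        )
    ]
)

# ...and read it the way a streaming caller would.
assert chunk.choices[0].delta.role == "assistant"
assert chunk.choices[0].delta.content == "Hel"
```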
+ """ + Extract system message if present, otherwise return empty string. + If there are multiple system messages, or the system message is not the first, + you may need to adapt this approach. + """ if messages and messages[0]["role"] == "system": system_message = messages[0]["content"] messages.pop(0) return system_message - return [] + return "" def _get_finish_reason(self, response): """Get the normalized finish reason.""" return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop") def _get_usage_stats(self, response): - """Get the usage statistics.""" + """Get the usage statistics from Anthropic response.""" return { "prompt_tokens": response.usage.input_tokens, "completion_tokens": response.usage.output_tokens, @@ -135,9 +142,8 @@ def _get_message(self, response): tool_message = self.convert_response_with_tool_use(response) if tool_message: return tool_message - return Message( - content=response.content[0].text, + content=response.content[0].text if response.content else "", role="assistant", tool_calls=None, refusal=None, @@ -146,26 +152,22 @@ def _get_message(self, response): def convert_response_with_tool_use(self, response): """Convert Anthropic tool use response to the framework's format.""" tool_call = next( - (content for content in response.content if content.type == "tool_use"), + (c for c in response.content if c.type == "tool_use"), None, ) - if tool_call: function = Function( - name=tool_call.name, arguments=json.dumps(tool_call.input) + name=tool_call.name, + arguments=json.dumps(tool_call.input), ) tool_call_obj = ChatCompletionMessageToolCall( - id=tool_call.id, function=function, type="function" + id=tool_call.id, + function=function, + type="function", ) text_content = next( - ( - content.text - for content in response.content - if content.type == "text" - ), - "", + (c.text for c in response.content if c.type == "text"), "" ) - return Message( content=text_content or None, tool_calls=[tool_call_obj] if tool_call else None, @@ -177,11 +179,9 @@ def convert_response_with_tool_use(self, response): def convert_tool_spec(self, openai_tools): """Convert OpenAI tool specification to Anthropic format.""" anthropic_tools = [] - for tool in openai_tools: if tool.get("type") != "function": continue - function = tool["function"] anthropic_tool = { "name": function["name"], @@ -193,7 +193,6 @@ def convert_tool_spec(self, openai_tools): }, } anthropic_tools.append(anthropic_tool) - return anthropic_tools @@ -204,17 +203,70 @@ def __init__(self, **config): self.converter = AnthropicMessageConverter() def chat_completions_create(self, model, messages, **kwargs): - """Create a chat completion using the Anthropic API.""" + """ + Create a chat completion using the Anthropic API. + + If 'stream=True' is passed, return a generator that yields + `ChatCompletionStreamResponse` objects shaped like OpenAI's streaming chunks. 
+ """ + stream = kwargs.pop("stream", False) + + if not stream: + # Non-streaming call + kwargs = self._prepare_kwargs(kwargs) + system_message, converted_messages = self.converter.convert_request(messages) + response = self.client.messages.create( + model=model, + system=system_message, + messages=converted_messages, + **kwargs + ) + return self.converter.convert_response(response) + else: + # Streaming call + return self._streaming_chat_completions_create(model, messages, **kwargs) + + def _streaming_chat_completions_create(self, model, messages, **kwargs): + """ + Generator that yields chunk objects in the shape: + chunk.choices[0].delta.content + """ kwargs = self._prepare_kwargs(kwargs) system_message, converted_messages = self.converter.convert_request(messages) - - response = self.client.messages.create( - model=model, system=system_message, messages=converted_messages, **kwargs - ) - return self.converter.convert_response(response) + first_chunk = True + + with self.client.messages.stream( + model=model, + system=system_message, + messages=converted_messages, + **kwargs + ) as stream_resp: + + for partial_text in stream_resp.text_stream: + # For the first token, include `role='assistant'`. + if first_chunk: + chunk = ChatCompletionStreamResponse(choices=[ + ChatCompletionStreamResponseChoice( + delta=ChatCompletionStreamResponseDelta( + role="assistant", + content=partial_text + ) + ) + ]) + first_chunk = False + else: + chunk = ChatCompletionStreamResponse(choices=[ + ChatCompletionStreamResponseChoice( + delta=ChatCompletionStreamResponseDelta( + content=partial_text + ) + ) + ]) + + yield chunk def _prepare_kwargs(self, kwargs): - """Prepare kwargs for the API call.""" + """Prepare kwargs for the Anthropic API call.""" kwargs = kwargs.copy() kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS) @@ -222,3 +274,7 @@ def _prepare_kwargs(self, kwargs): kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"]) return kwargs + + + + diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py index 8cb1b6c5..cc8111b2 100644 --- a/aisuite/providers/openai_provider.py +++ b/aisuite/providers/openai_provider.py @@ -1,7 +1,6 @@ import openai import os from aisuite.provider import Provider, LLMError -from aisuite.providers.message_converter import OpenAICompliantMessageConverter class OpenaiProvider(Provider): @@ -14,27 +13,63 @@ def __init__(self, **config): config.setdefault("api_key", os.getenv("OPENAI_API_KEY")) if not config["api_key"]: raise ValueError( - "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." + "OpenAI API key is missing. Please provide it in the config " + "or set the OPENAI_API_KEY environment variable." ) - # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically - # infer certain values from the environment variables. - # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc. - # Pass the entire config to the OpenAI client constructor + # (Note: This assumes openai.OpenAI(...) is valid in your environment. + # If you typically do `openai.api_key = ...`, adapt as needed.) self.client = openai.OpenAI(**config) - self.transformer = OpenAICompliantMessageConverter() def chat_completions_create(self, model, messages, **kwargs): - # Any exception raised by OpenAI will be returned to the caller. - # Maybe we should catch them and raise a custom LLMError. 
diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
index 8cb1b6c5..cc8111b2 100644
--- a/aisuite/providers/openai_provider.py
+++ b/aisuite/providers/openai_provider.py
@@ -1,7 +1,6 @@
 import openai
 import os
 from aisuite.provider import Provider, LLMError
-from aisuite.providers.message_converter import OpenAICompliantMessageConverter
 
 
 class OpenaiProvider(Provider):
@@ -14,27 +13,63 @@ def __init__(self, **config):
         config.setdefault("api_key", os.getenv("OPENAI_API_KEY"))
         if not config["api_key"]:
             raise ValueError(
-                "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
+                "OpenAI API key is missing. Please provide it in the config "
+                "or set the OPENAI_API_KEY environment variable."
             )
-        # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically
-        # infer certain values from the environment variables.
-        # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc.
 
-        # Pass the entire config to the OpenAI client constructor
+        # (Note: This assumes `openai.OpenAI(...)` is valid in your environment.
+        # If you typically do `openai.api_key = ...`, adapt as needed.)
         self.client = openai.OpenAI(**config)
-        self.transformer = OpenAICompliantMessageConverter()
 
     def chat_completions_create(self, model, messages, **kwargs):
-        # Any exception raised by OpenAI will be returned to the caller.
-        # Maybe we should catch them and raise a custom LLMError.
-        try:
-            transformed_messages = self.transformer.convert_request(messages)
-            response = self.client.chat.completions.create(
+        """
+        Create a chat completion using the OpenAI API.
+
+        If 'stream=True' is passed via kwargs, return a generator that yields
+        chunked responses in the OpenAI streaming format.
+        """
+        stream = kwargs.pop("stream", False)
+
+        if not stream:
+            # Non-streaming call
+            return self.client.chat.completions.create(
                 model=model,
-                messages=transformed_messages,
-                **kwargs,  # Pass any additional arguments to the OpenAI API
+                messages=messages,
+                **kwargs,
             )
-            return response
-        except Exception as e:
-            raise LLMError(f"An error occurred: {e}")
+        else:
+            # Streaming call: return a generator that yields each chunk
+            return self._streaming_chat_completions_create(model, messages, **kwargs)
+
+    def _streaming_chat_completions_create(self, model, messages, **kwargs):
+        """
+        Internal helper that yields chunked responses for streaming.
+        Each chunk already arrives in the OpenAI streaming format:
+
+            {
+                "id": ...,
+                "object": "chat.completion.chunk",
+                "created": ...,
+                "model": ...,
+                "choices": [
+                    {"delta": {"role": "assistant", "content": "..."}}
+                ]
+            }
+        """
+        response_gen = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=True,
+            **kwargs,
+        )
+
+        # Yield chunks as they arrive from the OpenAI SDK.
+        for chunk in response_gen:
+            yield chunk
+
+
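Taken together, both providers now honor `stream=True` with the same chunk access pattern, so one consumption loop serves either backend. A sketch through aisuite's top-level client (assumptions: `Client` routes the `stream` kwarg through to `chat_completions_create` unchanged, and the model ids are placeholders):

```python
import aisuite as ai

client = ai.Client()

for model in ("openai:gpt-4o-mini", "anthropic:claude-3-5-sonnet-latest"):
    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Count to five."}],
        stream=True,
    )
    for chunk in stream:
        # OpenAI chunks may carry empty choices or a None delta.content; guard both.
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    print()
```

One trade-off worth flagging: the OpenAI provider now yields the SDK's native chunk objects, while the Anthropic provider yields the new `ChatCompletionStreamResponse` wrappers. They agree on the `choices[0].delta` access path, but callers relying on other chunk fields (`id`, `usage`, `finish_reason`) will see them only on the OpenAI side.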