diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index c6337408d7..40f70cbee8 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -2525,122 +2525,79 @@ def _try_format_message( except ValidationError: return False - def _check_tools_strict_compatibility(self) -> bool: - r"""Check if all tools are compatible with OpenAI strict mode. - - Returns: - bool: True if all tools are strict mode compatible, - False otherwise. - """ - tool_schemas = self._get_full_tool_schemas() - for schema in tool_schemas: - if not schema.get("function", {}).get("strict", True): - return False - return True - - def _convert_response_format_to_prompt( - self, response_format: Type[BaseModel] - ) -> str: - r"""Convert a Pydantic response format to a prompt instruction. + @staticmethod + def _collect_tool_calls_from_completion( + tool_calls: List[Any], + accumulated_tool_calls: Dict[str, Any], + ) -> None: + r"""Convert tool calls from a ChatCompletion into the accumulated + tool-call dictionary format used by the streaming pipeline. Args: - response_format (Type[BaseModel]): The Pydantic model class. - - Returns: - str: A prompt instruction requesting the specific format. + tool_calls: Tool call objects from + ``completion.choices[0].message.tool_calls``. + accumulated_tool_calls: Mutable dict that will be populated with + the converted entries. """ - try: - # Get the JSON schema from the Pydantic model - schema = response_format.model_json_schema() - - # Create a prompt based on the schema - format_instruction = ( - "\n\nPlease respond in the following JSON format:\n{\n" - ) - - properties = schema.get("properties", {}) - for field_name, field_info in properties.items(): - field_type = field_info.get("type", "string") - description = field_info.get("description", "") - - if field_type == "array": - format_instruction += ( - f' "{field_name}": ["array of values"]' - ) - elif field_type == "object": - format_instruction += f' "{field_name}": {{"object"}}' - elif field_type == "boolean": - format_instruction += f' "{field_name}": true' - elif field_type == "number": - format_instruction += f' "{field_name}": 0' - else: - format_instruction += f' "{field_name}": "string value"' - - if description: - format_instruction += f' // {description}' - - # Add comma if not the last item - if field_name != list(properties.keys())[-1]: - format_instruction += "," - format_instruction += "\n" - - format_instruction += "}" - return format_instruction - - except Exception as e: - logger.warning( - f"Failed to convert response_format to prompt: {e}. " - f"Using generic format instruction." - ) - return ( - "\n\nPlease respond in a structured JSON format " - "that matches the requested schema." - ) + for tc in tool_calls: + accumulated_tool_calls[tc.id] = { + 'id': tc.id, + 'function': { + 'name': tc.function.name, + 'arguments': tc.function.arguments, + }, + 'complete': True, + } - def _handle_response_format_with_non_strict_tools( + def _record_and_build_display_message( self, - input_message: Union[BaseMessage, str], - response_format: Optional[Type[BaseModel]] = None, - ) -> Tuple[Union[BaseMessage, str], Optional[Type[BaseModel]], bool]: - r"""Handle response format when tools are not strict mode compatible. + final_content: str, + parsed_object: Any, + final_reasoning: Optional[str], + response_format: Optional[Type[BaseModel]], + ) -> BaseMessage: + r"""Record the full message to memory and build a display message. + + In delta mode the display message has empty content because all + content was already yielded incrementally. In accumulate mode the + display message carries the full content. Args: - input_message: The original input message. - response_format: The requested response format. + final_content: The full final content string. + parsed_object: The parsed object from structured output stream. + final_reasoning: The reasoning content, if any. + response_format: The (possibly modified) response format. Returns: - Tuple: (modified_message, modified_response_format, - used_prompt_formatting) + BaseMessage: The display message to yield to the caller. """ - if response_format is None: - return input_message, response_format, False - - # Check if tools are strict mode compatible - if self._check_tools_strict_compatibility(): - return input_message, response_format, False + parsed_cast = cast("BaseModel | dict[str, Any] | None", parsed_object) # type: ignore[arg-type] - # Tools are not strict compatible, convert to prompt - logger.info( - "Non-strict tools detected. Converting response_format to " - "prompt-based formatting." + # Record full content to memory + record_msg = BaseMessage( + role_name=self.role_name, + role_type=self.role_type, + meta_dict={}, + content=final_content, + parsed=parsed_cast, + reasoning_content=final_reasoning, ) - - format_prompt = self._convert_response_format_to_prompt( - response_format + if response_format: + self._try_format_message(record_msg, response_format) + self.record_message(record_msg) + + # Build display message (empty content in delta mode) + display_content = final_content if self.stream_accumulate else "" + display_reasoning = final_reasoning if self.stream_accumulate else None + display_msg = BaseMessage( + role_name=self.role_name, + role_type=self.role_type, + meta_dict={}, + content=display_content, + parsed=record_msg.parsed, + reasoning_content=display_reasoning, ) - - # Modify the message to include format instruction - modified_message: Union[BaseMessage, str] - if isinstance(input_message, str): - modified_message = input_message + format_prompt - else: - modified_message = input_message.create_new_instance( - input_message.content + format_prompt - ) - - # Return None for response_format to avoid strict mode conflicts - # and True to indicate we used prompt formatting - return modified_message, None, True + return display_msg def _is_called_from_registered_toolkit(self) -> bool: r"""Check if current step/astep call originates from a @@ -2669,66 +2626,6 @@ def _is_called_from_registered_toolkit(self) -> bool: return False - def _apply_prompt_based_parsing( - self, - response: ModelResponse, - original_response_format: Type[BaseModel], - ) -> None: - r"""Apply manual parsing when using prompt-based formatting. - - Args: - response: The model response to parse. - original_response_format: The original response format class. - """ - for message in response.output_messages: - if message.content: - try: - # Try to extract JSON from the response content - import json - - from pydantic import ValidationError - - # Try to find JSON in the content - content = message.content.strip() - - # Try direct parsing first - try: - parsed_json = json.loads(content) - message.parsed = ( - original_response_format.model_validate( - parsed_json - ) - ) - continue - except (json.JSONDecodeError, ValidationError): - pass - - # Try to extract JSON from text - json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' - json_matches = re.findall(json_pattern, content, re.DOTALL) - - for json_str in json_matches: - try: - parsed_json = json.loads(json_str) - message.parsed = ( - original_response_format.model_validate( - parsed_json - ) - ) - # Update content to just the JSON for consistency - message.content = json.dumps(parsed_json) - break - except (json.JSONDecodeError, ValidationError): - continue - - if not message.parsed: - logger.warning( - f"Failed to parse JSON from response: {content}" - ) - - except Exception as e: - logger.warning(f"Error during prompt-based parsing: {e}") - def _format_response_if_needed( self, response: ModelResponse, @@ -2867,14 +2764,6 @@ def _step_impl( # use disable_tools = self._is_called_from_registered_toolkit() - # Handle response format compatibility with non-strict tools - original_response_format = response_format - input_message, response_format, used_prompt_formatting = ( - self._handle_response_format_with_non_strict_tools( - input_message, response_format - ) - ) - # Convert input message to BaseMessage if necessary if isinstance(input_message, str): input_message = BaseMessage.make_user_message( @@ -3051,12 +2940,6 @@ def _step_impl( self._format_response_if_needed(response, response_format) - # Apply manual parsing if we used prompt-based formatting - if used_prompt_formatting and original_response_format: - self._apply_prompt_based_parsing( - response, original_response_format - ) - # Only record final output if we haven't already recorded tool calls # for this response (to avoid duplicate assistant messages) if not recorded_tool_calls: @@ -3170,14 +3053,6 @@ async def _astep_non_streaming_task( # use disable_tools = self._is_called_from_registered_toolkit() - # Handle response format compatibility with non-strict tools - original_response_format = response_format - input_message, response_format, used_prompt_formatting = ( - self._handle_response_format_with_non_strict_tools( - input_message, response_format - ) - ) - if isinstance(input_message, str): input_message = BaseMessage.make_user_message( role_name="User", content=input_message @@ -3353,12 +3228,6 @@ async def _astep_non_streaming_task( await self._aformat_response_if_needed(response, response_format) - # Apply manual parsing if we used prompt-based formatting - if used_prompt_formatting and original_response_format: - self._apply_prompt_based_parsing( - response, original_response_format - ) - # Only record final output if we haven't already recorded tool calls # for this response (to avoid duplicate assistant messages) if not recorded_tool_calls: @@ -4176,13 +4045,6 @@ def _stream( content, tool calls, and other information as they become available. """ - # Handle response format compatibility with non-strict tools - input_message, response_format, _ = ( - self._handle_response_format_with_non_strict_tools( - input_message, response_format - ) - ) - # Convert input message to BaseMessage if necessary if isinstance(input_message, str): input_message = BaseMessage.make_user_message( @@ -4203,7 +4065,9 @@ def _stream( # Start streaming response yield from self._stream_response( - openai_messages, num_tokens, response_format + openai_messages, + num_tokens, + response_format, ) def _get_token_count(self, content: str) -> int: @@ -4374,28 +4238,72 @@ def _stream_response( # Get final completion and record final message try: final_completion = stream.get_final_completion() - final_content = ( - final_completion.choices[0].message.content or "" + + # Check if the model wants to call tools + final_choice = final_completion.choices[0] + final_tool_calls = getattr( + final_choice.message, 'tool_calls', None ) + if final_tool_calls: + self._collect_tool_calls_from_completion( + final_tool_calls, + accumulated_tool_calls, + ) + + # Execute tools + for status_response in ( + self + )._execute_tools_sync_with_status_accumulator( + accumulated_tool_calls, + tool_call_records, + ): + yield status_response + + if tool_call_records: + logger.info("Sending back result to model") + + # Update usage + if final_completion.usage: + self._update_token_usage_tracker( + step_token_usage, + safe_model_dump(final_completion.usage), + ) + + # Continue the loop for next model call + accumulated_tool_calls.clear() + if tool_call_records and ( + self.max_iteration is None + or iteration_count < self.max_iteration + ): + try: + openai_messages, num_tokens = ( + self.memory.get_context() + ) + except RuntimeError as e: + yield self._step_terminate( + e.args[1], + tool_call_records, + "max_tokens_exceeded", + ) + return + content_accumulator.reset_streaming_content() + continue + else: + break + + final_content = final_choice.message.content or "" final_reasoning = ( content_accumulator.get_full_reasoning_content() or None ) - final_message = BaseMessage( - role_name=self.role_name, - role_type=self.role_type, - meta_dict={}, - content=final_content, - parsed=cast( - "BaseModel | dict[str, Any] | None", - parsed_object, - ), # type: ignore[arg-type] - reasoning_content=final_reasoning, + final_message = self._record_and_build_display_message( + final_content, + parsed_object, + final_reasoning, + response_format, ) - self.record_message(final_message) - # Create final response final_response = ChatAgentResponse( msgs=[final_message], @@ -5154,7 +5062,9 @@ async def _astream( # Start async streaming response last_response = None async for response in self._astream_response( - openai_messages, num_tokens, response_format + openai_messages, + num_tokens, + response_format, ): last_response = response yield response @@ -5313,28 +5223,74 @@ async def _astream_response( # Get final completion and record final message try: final_completion = await stream.get_final_completion() - final_content = ( - final_completion.choices[0].message.content or "" + + # Check if the model wants to call tools + final_choice = final_completion.choices[0] + final_tool_calls = getattr( + final_choice.message, 'tool_calls', None ) + if final_tool_calls: + self._collect_tool_calls_from_completion( + final_tool_calls, + accumulated_tool_calls, + ) + + # Execute tools + async for status_response in ( + self + )._execute_tools_async_with_status_accumulator( + accumulated_tool_calls, + content_accumulator, + step_token_usage, + tool_call_records, + ): + yield status_response + + if tool_call_records: + logger.info("Sending back result to model") + + # Update usage + if final_completion.usage: + self._update_token_usage_tracker( + step_token_usage, + safe_model_dump(final_completion.usage), + ) + + # Continue the loop for next model call + accumulated_tool_calls.clear() + if tool_call_records and ( + self.max_iteration is None + or iteration_count < self.max_iteration + ): + try: + openai_messages, num_tokens = ( + self.memory.get_context() + ) + except RuntimeError as e: + yield self._step_terminate( + e.args[1], + tool_call_records, + "max_tokens_exceeded", + ) + return + content_accumulator.reset_streaming_content() + continue + else: + break + + final_content = final_choice.message.content or "" final_reasoning = ( content_accumulator.get_full_reasoning_content() or None ) - final_message = BaseMessage( - role_name=self.role_name, - role_type=self.role_type, - meta_dict={}, - content=final_content, - parsed=cast( - "BaseModel | dict[str, Any] | None", - parsed_object, - ), # type: ignore[arg-type] - reasoning_content=final_reasoning, + final_message = self._record_and_build_display_message( + final_content, + parsed_object, + final_reasoning, + response_format, ) - self.record_message(final_message) - # Create final response final_response = ChatAgentResponse( msgs=[final_message], diff --git a/camel/models/anthropic_model.py b/camel/models/anthropic_model.py index a362a99df5..cd868392bd 100644 --- a/camel/models/anthropic_model.py +++ b/camel/models/anthropic_model.py @@ -14,7 +14,6 @@ import json import os import time -import warnings from typing import Any, Dict, List, Optional, Type, Union, cast from openai import AsyncStream, Stream @@ -636,6 +635,37 @@ def _convert_anthropic_stream_to_openai_chunk( usage=usage, ) + @staticmethod + def _add_additional_properties_false(schema: Dict[str, Any]) -> None: + r"""Recursively add additionalProperties: false to all object types.""" + if schema.get("type") == "object": + schema["additionalProperties"] = False + for value in schema.values(): + if isinstance(value, dict): + AnthropicModel._add_additional_properties_false(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + AnthropicModel._add_additional_properties_false(item) + + @staticmethod + def _build_output_config( + response_format: Type[BaseModel], + ) -> Dict[str, Any]: + r"""Convert a Pydantic model to Anthropic's output_config format.""" + schema = response_format.model_json_schema() + # Remove unsupported fields + schema.pop("$defs", None) + schema.pop("definitions", None) + # Anthropic requires additionalProperties: false on all object types + AnthropicModel._add_additional_properties_false(schema) + return { + "format": { + "type": "json_schema", + "schema": schema, + } + } + def _convert_openai_tools_to_anthropic( self, tools: Optional[List[Dict[str, Any]]] ) -> Optional[List[Dict[str, Any]]]: @@ -678,7 +708,7 @@ def _run( messages (List[OpenAIMessage]): Message list with the chat history in OpenAI API format. response_format (Optional[Type[BaseModel]]): The format of the - response. (Not supported by Anthropic API directly) + response. tools (Optional[List[Dict[str, Any]]]): The schema of the tools to use for the request. @@ -687,15 +717,6 @@ def _run( `ChatCompletion` in the non-stream mode, or `Stream[ChatCompletionChunk]` in the stream mode. """ - if response_format is not None: - warnings.warn( - "The 'response_format' parameter is not supported by the " - "Anthropic API and will be ignored. Consider using tools " - "for structured output instead.", - UserWarning, - stacklevel=2, - ) - # Update Langfuse trace with current agent session and metadata agent_session_id = get_current_agent_session_id() if agent_session_id: @@ -764,6 +785,12 @@ def _run( if key in self.model_config_dict: request_params[key] = self.model_config_dict[key] + # Add structured output via output_config + if response_format is not None: + request_params["output_config"] = self._build_output_config( + response_format + ) + # Convert tools anthropic_tools = self._convert_openai_tools_to_anthropic(tools) if anthropic_tools: @@ -803,7 +830,7 @@ async def _arun( messages (List[OpenAIMessage]): Message list with the chat history in OpenAI API format. response_format (Optional[Type[BaseModel]]): The format of the - response. (Not supported by Anthropic API directly) + response. tools (Optional[List[Dict[str, Any]]]): The schema of the tools to use for the request. @@ -812,15 +839,6 @@ async def _arun( `ChatCompletion` in the non-stream mode, or `AsyncStream[ChatCompletionChunk]` in the stream mode. """ - if response_format is not None: - warnings.warn( - "The 'response_format' parameter is not supported by the " - "Anthropic API and will be ignored. Consider using tools " - "for structured output instead.", - UserWarning, - stacklevel=2, - ) - # Update Langfuse trace with current agent session and metadata agent_session_id = get_current_agent_session_id() if agent_session_id: @@ -889,6 +907,12 @@ async def _arun( if key in self.model_config_dict: request_params[key] = self.model_config_dict[key] + # Add structured output via output_config + if response_format is not None: + request_params["output_config"] = self._build_output_config( + response_format + ) + # Convert tools anthropic_tools = self._convert_openai_tools_to_anthropic(tools) if anthropic_tools: diff --git a/camel/models/gemini_model.py b/camel/models/gemini_model.py index e4a2d60e13..f4b5491f1a 100644 --- a/camel/models/gemini_model.py +++ b/camel/models/gemini_model.py @@ -24,6 +24,10 @@ ) from openai import AsyncStream, Stream +from openai.lib.streaming.chat import ( + AsyncChatCompletionStreamManager, + ChatCompletionStreamManager, +) from pydantic import BaseModel from camel.configs import GeminiConfig @@ -451,6 +455,47 @@ async def async_thought_preserving_generator(): return async_thought_preserving_generator() + @staticmethod + def _clean_gemini_tools( + tools: Optional[List[Dict[str, Any]]], + ) -> Optional[List[Dict[str, Any]]]: + r"""Clean tools for Gemini API compatibility. + + Removes unsupported fields like strict, anyOf, and restricts + enum/format to allowed types. + """ + if not tools: + return tools + import copy + + tools = copy.deepcopy(tools) + for tool in tools: + function_dict = tool.get('function', {}) + function_dict.pop("strict", None) + + if 'parameters' in function_dict: + params = function_dict['parameters'] + if 'properties' in params: + for prop_name, prop_value in params['properties'].items(): + if 'anyOf' in prop_value: + first_type = prop_value['anyOf'][0] + params['properties'][prop_name] = first_type + if 'description' in prop_value: + params['properties'][prop_name][ + 'description' + ] = prop_value['description'] + + if prop_value.get('type') != 'string': + prop_value.pop('enum', None) + + if prop_value.get('type') not in [ + 'string', + 'integer', + 'number', + ]: + prop_value.pop('format', None) + return tools + @observe() def _run( self, @@ -479,19 +524,18 @@ def _run( "response_format", None ) messages = self._process_messages(messages) + is_streaming = self.model_config_dict.get("stream", False) + if response_format: - if tools: - raise ValueError( - "Gemini does not support function calling with " - "response format." + tools = self._clean_gemini_tools(tools) + if is_streaming: + return self._request_stream_parse( # type: ignore[return-value] + messages, response_format, tools ) - result: Union[ChatCompletion, Stream[ChatCompletionChunk]] = ( - self._request_parse(messages, response_format) - ) + else: + return self._request_parse(messages, response_format, tools) else: - result = self._request_chat_completion(messages, tools) - - return result + return self._request_chat_completion(messages, tools) @observe() async def _arun( @@ -521,67 +565,90 @@ async def _arun( "response_format", None ) messages = self._process_messages(messages) + is_streaming = self.model_config_dict.get("stream", False) + if response_format: - if tools: - raise ValueError( - "Gemini does not support function calling with " - "response format." + tools = self._clean_gemini_tools(tools) + if is_streaming: + return await self._arequest_stream_parse( # type: ignore[return-value] + messages, response_format, tools + ) + else: + return await self._arequest_parse( + messages, response_format, tools ) - result: Union[ - ChatCompletion, AsyncStream[ChatCompletionChunk] - ] = await self._arequest_parse(messages, response_format) else: - result = await self._arequest_chat_completion(messages, tools) + return await self._arequest_chat_completion(messages, tools) + + @staticmethod + def _build_gemini_response_format( + response_format: Type[BaseModel], + ) -> Dict[str, Any]: + r"""Convert a Pydantic model to Gemini-compatible response_format.""" + schema = response_format.model_json_schema() + # Remove $defs and other unsupported fields for Gemini + schema.pop("$defs", None) + schema.pop("definitions", None) + return { + "type": "json_schema", + "json_schema": { + "name": response_format.__name__, + "schema": schema, + }, + } + + def _request_stream_parse( + self, + messages: List[OpenAIMessage], + response_format: Type[BaseModel], + tools: Optional[List[Dict[str, Any]]] = None, + ) -> ChatCompletionStreamManager[BaseModel]: + r"""Gemini-specific streaming structured output. + + Uses regular streaming with response_format as JSON schema + instead of OpenAI's beta streaming API which is incompatible + with Gemini's tool call delta format. + """ + request_config = self._prepare_request_config(tools) + request_config["stream"] = True + request_config["response_format"] = self._build_gemini_response_format( + response_format + ) + + response = self._client.chat.completions.create( + messages=messages, + model=self.model_type, + **request_config, + ) + return self._preserve_thought_signatures(response) # type: ignore[return-value] - return result + async def _arequest_stream_parse( + self, + messages: List[OpenAIMessage], + response_format: Type[BaseModel], + tools: Optional[List[Dict[str, Any]]] = None, + ) -> AsyncChatCompletionStreamManager[BaseModel]: + r"""Gemini-specific async streaming structured output.""" + request_config = self._prepare_request_config(tools) + request_config["stream"] = True + request_config["response_format"] = self._build_gemini_response_format( + response_format + ) + + response = await self._async_client.chat.completions.create( + messages=messages, + model=self.model_type, + **request_config, + ) + return self._preserve_thought_signatures(response) # type: ignore[return-value] def _request_chat_completion( self, messages: List[OpenAIMessage], tools: Optional[List[Dict[str, Any]]] = None, ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: - import copy - - request_config = copy.deepcopy(self.model_config_dict) - # Remove strict and anyOf from each tool's function parameters since - # Gemini does not support them - if tools: - for tool in tools: - function_dict = tool.get('function', {}) - function_dict.pop("strict", None) - - # Process parameters to remove anyOf and handle enum/format - if 'parameters' in function_dict: - params = function_dict['parameters'] - if 'properties' in params: - for prop_name, prop_value in params[ - 'properties' - ].items(): - if 'anyOf' in prop_value: - # Replace anyOf with the first type in the list - first_type = prop_value['anyOf'][0] - params['properties'][prop_name] = first_type - # Preserve description if it exists - if 'description' in prop_value: - params['properties'][prop_name][ - 'description' - ] = prop_value['description'] - - # Handle enum and format restrictions for Gemini - # API enum: only allowed for string type - if prop_value.get('type') != 'string': - prop_value.pop('enum', None) - - # format: only allowed for string, integer, and - # number types - if prop_value.get('type') not in [ - 'string', - 'integer', - 'number', - ]: - prop_value.pop('format', None) - - request_config["tools"] = tools + tools = self._clean_gemini_tools(tools) + request_config = self._prepare_request_config(tools) response = self._client.chat.completions.create( messages=messages, @@ -597,48 +664,8 @@ async def _arequest_chat_completion( messages: List[OpenAIMessage], tools: Optional[List[Dict[str, Any]]] = None, ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]: - import copy - - request_config = copy.deepcopy(self.model_config_dict) - # Remove strict and anyOf from each tool's function parameters since - # Gemini does not support them - if tools: - for tool in tools: - function_dict = tool.get('function', {}) - function_dict.pop("strict", None) - - # Process parameters to remove anyOf and handle enum/format - if 'parameters' in function_dict: - params = function_dict['parameters'] - if 'properties' in params: - for prop_name, prop_value in params[ - 'properties' - ].items(): - if 'anyOf' in prop_value: - # Replace anyOf with the first type in the list - first_type = prop_value['anyOf'][0] - params['properties'][prop_name] = first_type - # Preserve description if it exists - if 'description' in prop_value: - params['properties'][prop_name][ - 'description' - ] = prop_value['description'] - - # Handle enum and format restrictions for Gemini - # API enum: only allowed for string type - if prop_value.get('type') != 'string': - prop_value.pop('enum', None) - - # format: only allowed for string, integer, and - # number types - if prop_value.get('type') not in [ - 'string', - 'integer', - 'number', - ]: - prop_value.pop('format', None) - - request_config["tools"] = tools + tools = self._clean_gemini_tools(tools) + request_config = self._prepare_request_config(tools) response = await self._async_client.chat.completions.create( messages=messages, diff --git a/examples/agents/chatagent_stream_structured_output.py b/examples/agents/chatagent_stream_structured_output.py new file mode 100644 index 0000000000..5f332d308e --- /dev/null +++ b/examples/agents/chatagent_stream_structured_output.py @@ -0,0 +1,112 @@ +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= + +""" +Example: Streaming + Tool Calls + Structured Output + +Demonstrates using ChatAgent in streaming mode with both tools and +response_format (structured output) simultaneously, in both sync and async. +""" + +import asyncio + +from pydantic import BaseModel, Field + +from camel.agents import ChatAgent +from camel.models import ModelFactory +from camel.toolkits import MathToolkit +from camel.types import ModelPlatformType, ModelType + + +class Result(BaseModel): + sum_result: str = Field(description="Only the result of the addition") + product_result: str = Field( + description="Only the result of the multiplication" + ) + division_result: str = Field(description="Only the result of the division") + capital_result: str = Field( + description="Only the result of the capital search" + ) + + +USER_MESSAGE = ( + "Calculate: 1) 123.45 + 678.90 2) 100 * 3.14159 3) 1000 / 7, " + "also search what is the capital of Germany" +) + + +def create_agent() -> ChatAgent: + streaming_model = ModelFactory.create( + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, + model_config_dict={ + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + return ChatAgent( + system_message="You are a helpful assistant.", + model=streaming_model, + tools=MathToolkit().get_tools(), + stream_accumulate=False, # Delta mode + ) + + +def sync_example(): + """Sync streaming with tools + structured output.""" + print("=== Sync Example ===") + agent = create_agent() + + streaming_response = agent.step(USER_MESSAGE, response_format=Result) + + content_parts = [] + for chunk in streaming_response: + if chunk.msgs[0].content: + content_parts.append(chunk.msgs[0].content) + print(chunk.msgs[0].content, end="", flush=True) + + print() + + # Print tool call records + tool_calls = streaming_response.info.get("tool_calls", []) + if tool_calls: + print(f"\nTool calls made: {len(tool_calls)}") + for i, tc in enumerate(tool_calls, 1): + print(f" {i}. {tc.tool_name}({tc.args}) = {tc.result}") + + # Check parsed output + final_msg = streaming_response.msgs[0] + if final_msg.parsed: + print(f"\nParsed result: {final_msg.parsed}") + + +async def async_example(): + """Async streaming with tools + structured output.""" + print("\n=== Async Example ===") + agent = create_agent() + + content_parts = [] + async for chunk in await agent.astep(USER_MESSAGE, response_format=Result): + if chunk.msgs[0].content: + content_parts.append(chunk.msgs[0].content) + print(chunk.msgs[0].content, end="", flush=True) + + print() + full_content = "".join(content_parts) + print(f"\nFull content: {full_content}") + + +if __name__ == "__main__": + sync_example() + asyncio.run(async_example())