diff --git a/src/praisonai-agents/praisonaiagents/llm/llm.py b/src/praisonai-agents/praisonaiagents/llm/llm.py
index 1f9f6ae27..71bc27429 100644
--- a/src/praisonai-agents/praisonaiagents/llm/llm.py
+++ b/src/praisonai-agents/praisonaiagents/llm/llm.py
@@ -422,13 +422,22 @@ def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_
         """
         messages = []
 
+        # Check if this is a Gemini model that supports native structured outputs
+        is_gemini_with_structured_output = False
+        if output_json or output_pydantic:
+            from .model_capabilities import supports_structured_outputs
+            is_gemini_with_structured_output = (
+                self._is_gemini_model() and
+                supports_structured_outputs(self.model)
+            )
+
         # Handle system prompt
         if system_prompt:
-            # Append JSON schema if needed
-            if output_json:
-                system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(output_json.model_json_schema())}"
-            elif output_pydantic:
-                system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(output_pydantic.model_json_schema())}"
+            # Only append JSON schema for non-Gemini models or Gemini models without structured output support
+            if (output_json or output_pydantic) and not is_gemini_with_structured_output:
+                schema_model = output_json or output_pydantic
+                if schema_model and hasattr(schema_model, 'model_json_schema'):
+                    system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(schema_model.model_json_schema())}"
 
             # Skip system messages for legacy o1 models as they don't support them
             if not self._needs_system_message_skip():
@@ -440,7 +449,8 @@ def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_
 
         # Handle prompt modifications for JSON output
         original_prompt = prompt
-        if output_json or output_pydantic:
+        if (output_json or output_pydantic) and not is_gemini_with_structured_output:
+            # Only modify prompt for non-Gemini models
             if isinstance(prompt, str):
                 prompt = prompt + "\nReturn ONLY a valid JSON object. No other text or explanation."
             elif isinstance(prompt, list):
@@ -695,6 +705,8 @@ def get_response(
                         temperature=temperature,
                         stream=False,  # force non-streaming
                         tools=formatted_tools,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 )
@@ -741,6 +753,8 @@ def get_response(
                         tools=formatted_tools,
                         temperature=temperature,
                         stream=True,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -760,6 +774,8 @@ def get_response(
                         tools=formatted_tools,
                         temperature=temperature,
                         stream=True,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -791,6 +807,8 @@ def get_response(
                         tools=formatted_tools,
                         temperature=temperature,
                         stream=False,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 )
@@ -944,6 +962,8 @@ def get_response(
                         temperature=temperature,
                         stream=False,  # Force non-streaming
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 )
@@ -979,6 +999,8 @@ def get_response(
                         temperature=temperature,
                         stream=stream,
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -994,6 +1016,8 @@ def get_response(
                         temperature=temperature,
                         stream=stream,
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -1039,6 +1063,8 @@ def get_response(
                         messages=messages,
                         temperature=temperature,
                         stream=True,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -1053,6 +1079,8 @@ def get_response(
                         messages=messages,
                         temperature=temperature,
                         stream=True,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -1089,6 +1117,12 @@ def get_response(
             total_time = time.time() - start_time
             logging.debug(f"get_response completed in {total_time:.2f} seconds")
 
+    def _is_gemini_model(self) -> bool:
+        """Check if the model is a Gemini model."""
+        if not self.model:
+            return False
+        return any(prefix in self.model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini'])
+
     async def get_response_async(
         self,
         prompt: Union[str, List[Dict]],
@@ -1197,10 +1231,12 @@ async def get_response_async(
                 resp = await litellm.acompletion(
                     **self._build_completion_params(
                         messages=messages,
-                    temperature=temperature,
-                    stream=False,  # force non-streaming
-                    **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
-                )
+                        temperature=temperature,
+                        stream=False,  # force non-streaming
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
+                        **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
+                    )
                 )
                 reasoning_content = resp["choices"][0]["message"].get("provider_specific_fields", {}).get("reasoning_content")
                 response_text = resp["choices"][0]["message"]["content"]
@@ -1239,6 +1275,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=True,
                         tools=formatted_tools,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -1259,6 +1297,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=True,
                         tools=formatted_tools,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **kwargs
                     )
                 ):
@@ -1283,6 +1323,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=False,
                         tools=formatted_tools,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 )
@@ -1343,6 +1385,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=False,  # force non-streaming
                         tools=formatted_tools,  # Include tools
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 )
@@ -1374,6 +1418,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=stream,
                         tools=formatted_tools,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -1389,6 +1435,8 @@ async def get_response_async(
                         messages=messages,
                         temperature=temperature,
                         stream=stream,
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -1471,6 +1519,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=False,  # Force non-streaming
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 )
@@ -1506,6 +1556,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=stream,
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -1521,6 +1573,8 @@ async def get_response_async(
                         temperature=temperature,
                         stream=stream,
                         response_format={"type": "json_object"},
+                        output_json=output_json,
+                        output_pydantic=output_pydantic,
                         **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                     )
                 ):
@@ -1678,11 +1732,33 @@ def _build_completion_params(self, **override_params) -> Dict[str, Any]:
         # Override with any provided parameters
        params.update(override_params)
 
+        # Handle structured output parameters
+        output_json = override_params.get('output_json')
+        output_pydantic = override_params.get('output_pydantic')
+
+        if output_json or output_pydantic:
+            # Always remove these from params as they're not native litellm parameters
+            params.pop('output_json', None)
+            params.pop('output_pydantic', None)
+
+            # Check if this is a Gemini model that supports native structured outputs
+            if self._is_gemini_model():
+                from .model_capabilities import supports_structured_outputs
+                schema_model = output_json or output_pydantic
+
+                if schema_model and hasattr(schema_model, 'model_json_schema') and supports_structured_outputs(self.model):
+                    schema = schema_model.model_json_schema()
+
+                    # Gemini uses response_mime_type and response_schema
+                    params['response_mime_type'] = 'application/json'
+                    params['response_schema'] = schema
+
+                    logging.debug(f"Using Gemini native structured output with schema: {json.dumps(schema, indent=2)}")
+
         # Add tool_choice="auto" when tools are provided (unless already specified)
         if 'tools' in params and params['tools'] and 'tool_choice' not in params:
             # For Gemini models, use tool_choice to encourage tool usage
-            # More comprehensive Gemini model detection
-            if any(prefix in self.model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini']):
+            if self._is_gemini_model():
                 try:
                     import litellm
                     # Check if model supports function calling before setting tool_choice
diff --git a/src/praisonai-agents/praisonaiagents/llm/model_capabilities.py b/src/praisonai-agents/praisonaiagents/llm/model_capabilities.py
index 2cb663694..90edd8640 100644
--- a/src/praisonai-agents/praisonaiagents/llm/model_capabilities.py
+++ b/src/praisonai-agents/praisonaiagents/llm/model_capabilities.py
@@ -30,6 +30,16 @@
     "gpt-4.1-mini",
     "o4-mini",
     "o3",
+
+    # Gemini models that support structured outputs
+    "gemini-2.0-flash",
+    "gemini-2.0-flash-exp",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-latest",
 }
 
 # Models that explicitly DON'T support structured outputs
@@ -57,16 +67,23 @@ def supports_structured_outputs(model_name: str) -> bool:
     if not model_name:
         return False
 
+    # Strip provider prefixes (e.g., 'google/', 'openai/', etc.)
+    model_without_provider = model_name
+    for prefix in ['google/', 'openai/', 'anthropic/', 'gemini/', 'mistral/', 'deepseek/', 'groq/']:
+        if model_name.startswith(prefix):
+            model_without_provider = model_name[len(prefix):]
+            break
+
     # First check if it's explicitly in the NOT supporting list
-    if model_name in MODELS_NOT_SUPPORTING_STRUCTURED_OUTPUTS:
+    if model_without_provider in MODELS_NOT_SUPPORTING_STRUCTURED_OUTPUTS:
         return False
 
     # Then check if it's in the supporting list
-    if model_name in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
+    if model_without_provider in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
        return True
 
     # For models with version suffixes, check the base model name
-    base_model = model_name.split('-2024-')[0].split('-2025-')[0]
+    base_model = model_without_provider.split('-2024-')[0].split('-2025-')[0]
     if base_model in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
         return True