100 changes: 88 additions & 12 deletions src/praisonai-agents/praisonaiagents/llm/llm.py
@@ -422,13 +422,22 @@ def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_
"""
messages = []

# Check if this is a Gemini model that supports native structured outputs
is_gemini_with_structured_output = False
if output_json or output_pydantic:
from .model_capabilities import supports_structured_outputs
is_gemini_with_structured_output = (
self._is_gemini_model() and
supports_structured_outputs(self.model)
)

# Handle system prompt
if system_prompt:
# Append JSON schema if needed
if output_json:
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(output_json.model_json_schema())}"
elif output_pydantic:
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(output_pydantic.model_json_schema())}"
# Only append JSON schema for non-Gemini models or Gemini models without structured output support
if (output_json or output_pydantic) and not is_gemini_with_structured_output:
schema_model = output_json or output_pydantic
if schema_model and hasattr(schema_model, 'model_json_schema'):
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {json.dumps(schema_model.model_json_schema())}"

# Skip system messages for legacy o1 models as they don't support them
if not self._needs_system_message_skip():
@@ -440,7 +449,8 @@ def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_

# Handle prompt modifications for JSON output
original_prompt = prompt
if output_json or output_pydantic:
if (output_json or output_pydantic) and not is_gemini_with_structured_output:
# Only modify the prompt when native structured output is unavailable
if isinstance(prompt, str):
prompt = prompt + "\nReturn ONLY a valid JSON object. No other text or explanation."
elif isinstance(prompt, list):
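
For context, a minimal sketch (not part of the diff) of what the two branches above produce, assuming a hypothetical Pydantic model `Answer`:

```python
# Sketch only: illustrates the two paths added to _build_messages above.
# `Answer` is a hypothetical model, not part of this PR.
import json
from pydantic import BaseModel

class Answer(BaseModel):
    text: str
    confidence: float

# Fallback path (non-Gemini, or Gemini without structured-output support):
# the schema is appended to the system prompt and the user prompt gets an
# explicit "JSON only" instruction.
system_prompt = "You are a helpful assistant."
system_prompt += (
    "\nReturn ONLY a JSON object that matches this Pydantic model: "
    + json.dumps(Answer.model_json_schema())
)

# Native path (Gemini with structured-output support): both prompts are left
# untouched; the schema travels via response_mime_type / response_schema,
# set later in _build_completion_params.
```
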
@@ -695,6 +705,8 @@ def get_response(
temperature=temperature,
stream=False, # force non-streaming
tools=formatted_tools,
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
@@ -741,6 +753,8 @@ def get_response(
tools=formatted_tools,
temperature=temperature,
stream=True,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -760,6 +774,8 @@
tools=formatted_tools,
temperature=temperature,
stream=True,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -791,6 +807,8 @@ def get_response(
tools=formatted_tools,
temperature=temperature,
stream=False,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
)
@@ -944,6 +962,8 @@ def get_response(
temperature=temperature,
stream=False, # Force non-streaming
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
@@ -979,6 +999,8 @@ def get_response(
temperature=temperature,
stream=stream,
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -994,6 +1016,8 @@ def get_response(
temperature=temperature,
stream=stream,
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -1039,6 +1063,8 @@ def get_response(
messages=messages,
temperature=temperature,
stream=True,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -1053,6 +1079,8 @@ def get_response(
messages=messages,
temperature=temperature,
stream=True,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -1089,6 +1117,12 @@ def get_response(
total_time = time.time() - start_time
logging.debug(f"get_response completed in {total_time:.2f} seconds")

def _is_gemini_model(self) -> bool:
"""Check if the model is a Gemini model."""
if not self.model:
return False
return any(prefix in self.model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini'])
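
A quick illustration of the strings this helper matches — note that the bare `'gemini'` substring already covers both prefixed forms, so the extra prefixes are redundant but harmless:

```python
# Sketch: expected behaviour of _is_gemini_model for common model strings.
prefixes = ['gemini', 'gemini/', 'google/gemini']

for name in ("gemini/gemini-2.0-flash", "google/gemini-1.5-pro", "GEMINI-1.5-FLASH"):
    assert any(p in name.lower() for p in prefixes)  # all detected as Gemini

assert not any(p in "gpt-4o".lower() for p in prefixes)  # non-Gemini model
```
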

async def get_response_async(
self,
prompt: Union[str, List[Dict]],
@@ -1197,10 +1231,12 @@ async def get_response_async(
resp = await litellm.acompletion(
**self._build_completion_params(
messages=messages,
temperature=temperature,
stream=False, # force non-streaming
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
temperature=temperature,
stream=False, # force non-streaming
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
reasoning_content = resp["choices"][0]["message"].get("provider_specific_fields", {}).get("reasoning_content")
response_text = resp["choices"][0]["message"]["content"]
@@ -1239,6 +1275,8 @@ async def get_response_async(
temperature=temperature,
stream=True,
tools=formatted_tools,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -1259,6 +1297,8 @@ async def get_response_async(
temperature=temperature,
stream=True,
tools=formatted_tools,
output_json=output_json,
output_pydantic=output_pydantic,
**kwargs
)
):
@@ -1283,6 +1323,8 @@ async def get_response_async(
temperature=temperature,
stream=False,
tools=formatted_tools,
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
@@ -1343,6 +1385,8 @@ async def get_response_async(
temperature=temperature,
stream=False, # force non-streaming
tools=formatted_tools, # Include tools
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
@@ -1374,6 +1418,8 @@ async def get_response_async(
temperature=temperature,
stream=stream,
tools=formatted_tools,
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -1389,6 +1435,8 @@ async def get_response_async(
messages=messages,
temperature=temperature,
stream=stream,
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -1471,6 +1519,8 @@ async def get_response_async(
temperature=temperature,
stream=False, # Force non-streaming
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
)
@@ -1506,6 +1556,8 @@ async def get_response_async(
temperature=temperature,
stream=stream,
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -1521,6 +1573,8 @@ async def get_response_async(
temperature=temperature,
stream=stream,
response_format={"type": "json_object"},
output_json=output_json,
output_pydantic=output_pydantic,
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
)
):
@@ -1678,11 +1732,33 @@ def _build_completion_params(self, **override_params) -> Dict[str, Any]:
# Override with any provided parameters
params.update(override_params)

# Handle structured output parameters
output_json = override_params.get('output_json')
output_pydantic = override_params.get('output_pydantic')

if output_json or output_pydantic:
# Always remove these from params as they're not native litellm parameters
params.pop('output_json', None)
params.pop('output_pydantic', None)

# Check if this is a Gemini model that supports native structured outputs
if self._is_gemini_model():
from .model_capabilities import supports_structured_outputs
schema_model = output_json or output_pydantic

if schema_model and hasattr(schema_model, 'model_json_schema') and supports_structured_outputs(self.model):
schema = schema_model.model_json_schema()

# Gemini uses response_mime_type and response_schema
params['response_mime_type'] = 'application/json'
params['response_schema'] = schema

logging.debug(f"Using Gemini native structured output with schema: {json.dumps(schema, indent=2)}")

# Add tool_choice="auto" when tools are provided (unless already specified)
if 'tools' in params and params['tools'] and 'tool_choice' not in params:
# For Gemini models, use tool_choice to encourage tool usage
# More comprehensive Gemini model detection
if any(prefix in self.model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini']):
if self._is_gemini_model():
try:
import litellm
# Check if model supports function calling before setting tool_choice
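
Taken together, a structured-output call on a supported Gemini model should translate roughly as below — a sketch assuming litellm forwards `response_mime_type` and `response_schema` to the Gemini API, with a hypothetical `Answer` model:

```python
# Sketch of the parameter translation performed by _build_completion_params.
# `Answer` is a hypothetical Pydantic model, not part of this PR.
from pydantic import BaseModel

class Answer(BaseModel):
    text: str

# Conceptual input:  _build_completion_params(messages=..., temperature=0.7,
#                                             output_pydantic=Answer)
# Expected result for e.g. model="gemini/gemini-2.0-flash":
params = {
    "model": "gemini/gemini-2.0-flash",
    "temperature": 0.7,
    "response_mime_type": "application/json",       # Gemini-native JSON mode
    "response_schema": Answer.model_json_schema(),  # enforced server-side
}
# In every case output_json/output_pydantic are popped first, since they are
# not parameters that litellm.completion() accepts.
```
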
23 changes: 20 additions & 3 deletions src/praisonai-agents/praisonaiagents/llm/model_capabilities.py
@@ -30,6 +30,16 @@
"gpt-4.1-mini",
"o4-mini",
"o3",

# Gemini models that support structured outputs
"gemini-2.0-flash",
"gemini-2.0-flash-exp",
"gemini-1.5-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash",
"gemini-1.5-flash-latest",
"gemini-1.5-flash-8b",
"gemini-1.5-flash-8b-latest",
}

# Models that explicitly DON'T support structured outputs
@@ -57,16 +67,23 @@ def supports_structured_outputs(model_name: str) -> bool:
if not model_name:
return False

# Strip provider prefixes (e.g., 'google/', 'openai/', etc.)
model_without_provider = model_name
for prefix in ['google/', 'openai/', 'anthropic/', 'gemini/', 'mistral/', 'deepseek/', 'groq/']:
if model_name.startswith(prefix):
model_without_provider = model_name[len(prefix):]
break

# First check if it's explicitly in the NOT supporting list
if model_name in MODELS_NOT_SUPPORTING_STRUCTURED_OUTPUTS:
if model_without_provider in MODELS_NOT_SUPPORTING_STRUCTURED_OUTPUTS:
return False

# Then check if it's in the supporting list
if model_name in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
if model_without_provider in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
return True

# For models with version suffixes, check the base model name
base_model = model_name.split('-2024-')[0].split('-2025-')[0]
base_model = model_without_provider.split('-2024-')[0].split('-2025-')[0]
if base_model in MODELS_SUPPORTING_STRUCTURED_OUTPUTS:
return True

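
Illustrative expectations for the updated check, assuming the model lists shown in this diff and the package import path `praisonaiagents.llm.model_capabilities`:

```python
from praisonaiagents.llm.model_capabilities import supports_structured_outputs

assert supports_structured_outputs("gemini-2.0-flash")         # direct match
assert supports_structured_outputs("gemini/gemini-1.5-pro")    # provider prefix stripped
assert supports_structured_outputs("google/gemini-1.5-flash")  # alternate prefix
assert supports_structured_outputs("gpt-4.1-mini-2025-04-14")  # date suffix -> base model
assert not supports_structured_outputs("")                     # empty name -> False
```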