85 changes: 77 additions & 8 deletions src/praisonai-agents/praisonaiagents/llm/adapters/__init__.py
@@ -7,9 +7,9 @@
This demonstrates the protocol-driven approach for Gap 2.
"""

from ..protocols import LLMProviderProtocol
from ..protocols import LLMProviderAdapterProtocol
from ..model_capabilities import GEMINI_INTERNAL_TOOLS
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional


class DefaultAdapter:
@@ -32,28 +32,86 @@ def supports_structured_output(self) -> bool:

def supports_streaming(self) -> bool:
return True # Most providers support streaming

def supports_streaming_with_tools(self) -> bool:
return True # Most providers support streaming with tools

def get_max_iteration_threshold(self) -> int:
return 10 # Conservative default

def format_tool_result_message(self, function_name: str, tool_result: Any, tool_call_id: Optional[str] = None) -> Dict[str, Any]:
# Standard OpenAI-style tool result message
message = {
"role": "tool",
"content": str(tool_result),
}
if tool_call_id is not None:
message["tool_call_id"] = tool_call_id
else:
# Fallback for backward compatibility
message["tool_call_id"] = f"call_{function_name}"
return message

def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
return False # No special handling by default

def get_default_settings(self) -> Dict[str, Any]:
return {} # No provider-specific defaults


class OllamaAdapter(DefaultAdapter):
"""
Ollama-specific provider adapter.

Demonstrates how to extract Ollama-specific logic from llm.py's
scattered provider dispatch into a clean adapter.
Handles Ollama's specific quirks:
- Doesn't support streaming with tools reliably
- Needs tool summarization after iteration 1
- Uses natural language tool result format
- Handles empty responses after tool execution
"""

def should_summarize_tools(self, iter_count: int) -> bool:
# Replaces: OLLAMA_SUMMARY_ITERATION_THRESHOLD logic
# Must match LLM.OLLAMA_SUMMARY_ITERATION_THRESHOLD = 1
return iter_count >= 1

def supports_streaming_with_tools(self) -> bool:
# Ollama doesn't reliably support streaming with tools
return False

def get_max_iteration_threshold(self) -> int:
return 1 # Ollama-specific threshold

def format_tool_result_message(self, function_name: str, tool_result: Any, tool_call_id: Optional[str] = None) -> Dict[str, Any]:
# Ollama uses natural language format for tool results
return {
"role": "user",
"content": f"Tool '{function_name}' returned: {tool_result}"
}

def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
# Handle Ollama's tendency to return empty responses after tool execution
iteration_count = state.get('iteration_count', 0)
has_tool_results = bool(state.get('accumulated_tool_results'))
response_text = state.get('response_text', '').strip()

if iteration_count >= 1 and has_tool_results and not response_text:
return True # Signal that special handling is needed
return False

def post_tool_iteration(self, state: Dict[str, Any]) -> None:
# Replaces: Ollama-specific post-tool summary branches
if (not state.get('response_text', '').strip() and
state.get('formatted_tools') and
state.get('iteration_count') == 0):
# Add Ollama-specific summary logic here
state['needs_summary'] = True

def get_default_settings(self) -> Dict[str, Any]:
return {
'max_tool_repairs': 2,
'force_tool_usage': 'auto'
}


class AnthropicAdapter(DefaultAdapter):
@@ -67,7 +125,14 @@ def supports_structured_output(self) -> bool:


class GeminiAdapter(DefaultAdapter):
"""Google Gemini provider adapter."""
"""
Google Gemini provider adapter.

Handles Gemini's specific quirks:
- Has internal tools that need special formatting
- Doesn't support streaming with tools reliably
- Supports structured output
"""

def format_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Replaces: gemini_internal_tools handling in llm.py
@@ -84,12 +149,16 @@ def format_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
formatted.append(tool)
return formatted

def supports_streaming_with_tools(self) -> bool:
# Gemini has issues with streaming + tools
return False

def supports_structured_output(self) -> bool:
return True


# Provider adapter registry - public for extension
_provider_adapters: Dict[str, LLMProviderProtocol] = {}
_provider_adapters: Dict[str, LLMProviderAdapterProtocol] = {}

# Register core adapters at import time
_default_adapter = DefaultAdapter()
@@ -100,7 +169,7 @@ def supports_structured_output(self) -> bool:
_provider_adapters['gemini'] = GeminiAdapter()


def add_provider_adapter(name: str, adapter: LLMProviderProtocol) -> None:
def add_provider_adapter(name: str, adapter: LLMProviderAdapterProtocol) -> None:
"""
Register a provider adapter by name.

@@ -113,7 +182,7 @@ def add_provider_adapter(name: str, adapter: LLMProviderAdapterProtocol) -> None:
_provider_adapters[name] = adapter


def get_provider_adapter(name: str) -> LLMProviderProtocol:
def get_provider_adapter(name: str) -> LLMProviderAdapterProtocol:
"""
Get provider adapter by name with fallback to default.

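As a quick orientation for reviewers, a minimal usage sketch of the registry API added above; the MyLocalAdapter class and the 'my-local' key are illustrative only, and the import path assumes the package layout shown in this diff:

from praisonaiagents.llm.adapters import (
    DefaultAdapter,
    add_provider_adapter,
    get_provider_adapter,
)


class MyLocalAdapter(DefaultAdapter):
    """Hypothetical adapter for a local backend that cannot stream with tools."""

    def supports_streaming_with_tools(self) -> bool:
        return False


# Register once at startup, then resolve by provider name.
add_provider_adapter("my-local", MyLocalAdapter())
assert isinstance(get_provider_adapter("my-local"), MyLocalAdapter)
# Unknown names fall back to the default adapter rather than raising.
assert isinstance(get_provider_adapter("no-such-provider"), DefaultAdapter)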
146 changes: 101 additions & 45 deletions src/praisonai-agents/praisonaiagents/llm/llm.py
@@ -412,8 +412,11 @@ def __init__(
self.max_tool_repairs = extra_settings.get('max_tool_repairs', 0) # Will be set to 2 for Ollama if not explicit
self.force_tool_usage = extra_settings.get('force_tool_usage', 'never') # Will be set to 'auto' for Ollama if not explicit

# Apply Ollama-specific defaults after model is set
self._apply_ollama_defaults()
# Initialize provider adapter for dispatch logic
self._provider_adapter = self._initialize_provider_adapter()

# Apply provider-specific defaults after adapter is initialized
self._apply_provider_defaults()

# Token tracking
self.last_token_metrics: Optional[TokenMetrics] = None
@@ -474,15 +477,75 @@ def console(self):
self._console = _get_console()()
return self._console

def _apply_ollama_defaults(self):
"""Apply Ollama-specific defaults for tool calling reliability."""
def _detect_provider(self) -> str:
"""
Detect provider from model name.

Consolidates all provider detection logic into a single method
that replaces scattered _is_X_provider() calls.

Returns:
Provider name (e.g., "ollama", "anthropic", "gemini", "openai")
"""
if not self.model:
return "default"

model_lower = self.model.lower()

# Parse route prefix for explicit provider routing
provider_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else None

# Explicit provider prefixes take priority
if provider_prefix == "ollama":
return "ollama"
if provider_prefix in {"anthropic", "claude"}:
return "anthropic"
if provider_prefix in {"gemini", "google"} and "gemini" in model_lower:
return "gemini"

# Use existing robust Ollama detection logic first
if self._is_ollama_provider():
# Apply defaults only if not explicitly set by user
if not self._max_tool_repairs_explicit:
self.max_tool_repairs = 2
if not self._force_tool_usage_explicit:
self.force_tool_usage = 'auto'
logging.debug(f"[OLLAMA_RELIABILITY] Applied Ollama defaults: max_tool_repairs={self.max_tool_repairs}, force_tool_usage={self.force_tool_usage}")
return "ollama"

# Check for direct model name patterns (not substrings)
if model_lower.startswith("claude"):
return "anthropic"

if model_lower.startswith("gemini") or model_lower.startswith("google/gemini"):
return "gemini"

# Check base_url for provider hints
base_urls = [self.base_url, os.getenv("OPENAI_BASE_URL", ""), os.getenv("OPENAI_API_BASE", "")]
if any(url and ("ollama" in url.lower() or ":11434" in url) for url in base_urls):
return "ollama"

# Default fallback
return "openai"
Comment on lines +480 to +523
Contributor


⚠️ Potential issue | 🟠 Major

Tighten _detect_provider() to the route prefix/endpoints, not model substrings.

These checks scan the whole model id for family names, so routed ids like groq/llama-3.3-70b-versatile become Ollama, and any .../google/gemini-* or .../anthropic/claude-* route becomes Gemini/Anthropic even when the transport is OpenAI-compatible. Once that happens, _apply_ollama_defaults(), _supports_streaming_tools(), and _supports_prompt_caching() all pick the wrong adapter, so callers can suddenly get Ollama tool coercion, Gemini no-streaming, or Anthropic cache_control payloads on the wrong backend. This also drops the OPENAI_BASE_URL / OPENAI_API_BASE / :11434 Ollama checks that _is_ollama_provider() still uses, so the same instance can be classified differently in different branches.
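To make the intended routing concrete, a small test-style sketch; the import path and constructor call are assumptions based on this diff, and the expected values follow the reviewer's reasoning rather than the current implementation:

from praisonaiagents.llm.llm import LLM  # module path as laid out in this PR

# Classification should follow the route prefix / endpoint, never substrings.
# (Assumes OPENAI_BASE_URL / OPENAI_API_BASE do not point at an Ollama endpoint.)
cases = {
    "ollama/llama3.1": "ollama",                     # explicit Ollama route
    "claude-3-5-sonnet-20241022": "anthropic",       # direct Anthropic model id
    "gemini/gemini-1.5-pro": "gemini",               # explicit Gemini route
    "groq/llama-3.3-70b-versatile": "openai",        # "llama" substring must NOT imply Ollama
    "openrouter/google/gemini-2.0-flash": "openai",  # nested "gemini" must NOT imply Gemini
}

for model, expected in cases.items():
    detected = LLM(model=model)._detect_provider()
    assert detected == expected, f"{model}: got {detected}, expected {expected}"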

🎯 Safer provider detection
     def _detect_provider(self) -> str:
         if not self.model:
             return "default"

         model_lower = self.model.lower()
+        provider_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else None
+
+        if provider_prefix == "ollama":
+            return "ollama"
+        if provider_prefix in {"anthropic", "claude"}:
+            return "anthropic"
+        if provider_prefix in {"gemini", "google"} and "gemini" in model_lower:
+            return "gemini"

-        # Check base_url for provider hints
-        if self.base_url and "ollama" in self.base_url.lower():
+        base_urls = [self.base_url, os.getenv("OPENAI_BASE_URL", ""), os.getenv("OPENAI_API_BASE", "")]
+        if any(url and ("ollama" in url.lower() or ":11434" in url) for url in base_urls):
             return "ollama"

-        # Check model name patterns
-        if any(prefix in model_lower for prefix in ['claude', 'anthropic/']):
+        if model_lower.startswith("claude"):
             return "anthropic"

-        if any(prefix in model_lower for prefix in ['gemini', 'gemini/', 'google/gemini']):
+        if model_lower.startswith("gemini") or model_lower.startswith("google/gemini"):
             return "gemini"
-
-        # Check for Ollama models without prefix
-        ...
 
         return "openai"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/praisonai-agents/praisonaiagents/llm/llm.py` around lines 480 - 521,
_detect_provider currently classifies providers by scanning substrings across
the whole model id which misroutes models like "groq/llama-..." or
"openai/google/gemini-..." to the wrong backend; change it to detect by
route/prefix and transport first: call and honor the existing
_is_ollama_provider() early (so base_url / :11434 / OPENAI_* base checks remain
authoritative), then parse the model route prefix (e.g., take
model.split('/',1)[0] or model_lower.startswith("provider/") ) and only map
providers when the provider token is the route prefix (e.g., "ollama/",
"anthropic/", "google", "gemini/"); fall back to checking explicit base_url
hints (ollama in base_url) and otherwise default to "openai"; this prevents
substring matches from misclassifying and ensures _apply_ollama_defaults(),
_supports_streaming_tools(), and _supports_prompt_caching() get the correct
provider.


def _initialize_provider_adapter(self):
"""Initialize provider adapter based on detected provider."""
provider = self._detect_provider()

# Import here to avoid circular imports
from .adapters import get_provider_adapter
adapter = get_provider_adapter(provider)

logging.debug(f"[ADAPTER] Initialized {provider} adapter: {adapter.__class__.__name__}")
return adapter

def _apply_provider_defaults(self):
"""Apply provider-specific defaults via adapter pattern."""
if hasattr(self, '_provider_adapter') and self._provider_adapter:
defaults = self._provider_adapter.get_default_settings()
if defaults:
# Apply defaults only if not explicitly set by user
if not self._max_tool_repairs_explicit and 'max_tool_repairs' in defaults:
self.max_tool_repairs = defaults['max_tool_repairs']
if not self._force_tool_usage_explicit and 'force_tool_usage' in defaults:
self.force_tool_usage = defaults['force_tool_usage']

if defaults: # Only log if there were actual defaults
logging.debug(f"[PROVIDER_DEFAULTS] Applied {self._provider_adapter.__class__.__name__} defaults: {defaults}")

def _is_ollama_provider(self) -> bool:
"""Detect if this is an Ollama provider regardless of naming convention"""
@@ -753,14 +816,22 @@ def _supports_web_fetch(self) -> bool:

def _supports_prompt_caching(self) -> bool:
"""
Check if the current model supports prompt caching via LiteLLM.
Check if the current model supports prompt caching.

Prompt caching allows caching parts of prompts to reduce costs and latency.
Supported by OpenAI, Anthropic, Bedrock, and Deepseek.
Uses provider adapter to eliminate scattered provider logic.

Returns:
bool: True if the model supports prompt caching, False otherwise
"""
if hasattr(self, '_provider_adapter') and self._provider_adapter:
adapter_support = self._provider_adapter.supports_prompt_caching()
if adapter_support:
return True
# If adapter says False and is not the default, trust the adapter
if self._provider_adapter.__class__.__name__ != 'DefaultAdapter':
return False

# Fallback to model capabilities for DefaultAdapter or uninitialized
from .model_capabilities import supports_prompt_caching
return supports_prompt_caching(self.model)

@@ -1270,12 +1341,20 @@ def _handle_ollama_sequential_logic(self, iteration_count: int, accumulated_tool
- final_response_text: Text to use as final response (None if continuing)
- iteration_count: Updated iteration count
"""
if not (self._is_ollama_provider() and iteration_count >= self.OLLAMA_SUMMARY_ITERATION_THRESHOLD):
# Use provider adapter to determine if tool summarization should occur
# Adapter should always be initialized, but provide safe fallback
if hasattr(self, '_provider_adapter') and self._provider_adapter:
should_summarize = self._provider_adapter.should_summarize_tools(iteration_count)
else:
# Conservative fallback without provider detection
should_summarize = iteration_count >= 5 # Conservative default

if not should_summarize:
return False, None, iteration_count

# For Ollama: if we have meaningful tool results, generate summary immediately
# Don't wait for more iterations as Ollama tends to repeat tool calls
if accumulated_tool_results and iteration_count >= self.OLLAMA_SUMMARY_ITERATION_THRESHOLD:
# If we should summarize: if we have meaningful tool results, generate summary immediately
# Don't wait for more iterations as some providers tend to repeat tool calls
if accumulated_tool_results and should_summarize:
# Generate summary from tool results
tool_summary = self._generate_ollama_tool_summary(accumulated_tool_results, response_text)
if tool_summary:
@@ -1301,42 +1380,19 @@ def _supports_streaming_tools(self) -> bool:
"""
Check if the current provider supports streaming with tools.

Most providers that support tool calling also support streaming with tools,
but some providers (like Ollama and certain local models) require non-streaming
calls when tools are involved.
Uses provider adapter to eliminate scattered provider logic.

Returns:
bool: True if provider supports streaming with tools, False otherwise
"""
if not self.model:
return False

# Ollama doesn't reliably support streaming with tools
if self._is_ollama_provider():
return False

# Import the capability check function
from .model_capabilities import supports_streaming_with_tools

# Check if this model supports streaming with tools
if supports_streaming_with_tools(self.model):
return True

# Anthropic Claude models support streaming with tools
if self.model.startswith("claude-"):
return True

# Google Gemini models support streaming with tools
if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
return True

# Models with XML tool format support streaming with tools
if self._supports_xml_tool_format():
return True
# Use provider adapter for streaming with tools support
if hasattr(self, '_provider_adapter') and self._provider_adapter:
return self._provider_adapter.supports_streaming_with_tools()

# For other providers, default to False to be safe
# This ensures we make a single non-streaming call rather than risk
# missing tool calls or making duplicate calls
# Fallback to conservative default if adapter not initialized
return False

def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_json=None, output_pydantic=None, tools=None):
24 changes: 22 additions & 2 deletions src/praisonai-agents/praisonaiagents/llm/protocols.py
@@ -13,7 +13,7 @@
@runtime_checkable
class LLMProviderProtocol(Protocol):
"""
Protocol defining the interface that LLM providers must implement.
Protocol defining the interface that LLM clients must implement.

This enables switching between different LLM backends (litellm, openai,
anthropic, local models, etc.) without modifying core agent code.
@@ -256,7 +256,7 @@ def __init__(self, tokens: int, max_tokens: int, provider: Optional[str] = None,


@runtime_checkable
class LLMProviderProtocol(Protocol):
class LLMProviderAdapterProtocol(Protocol):
"""
Protocol for provider-specific LLM adaptations.

@@ -306,6 +306,26 @@ def supports_structured_output(self) -> bool:
def supports_streaming(self) -> bool:
"""Check if provider supports streaming responses."""
...

def supports_streaming_with_tools(self) -> bool:
"""Check if provider supports streaming with tools enabled."""
...

def get_max_iteration_threshold(self) -> int:
"""Get provider-specific maximum iteration count."""
...

def format_tool_result_message(self, function_name: str, tool_result: Any, tool_call_id: Optional[str] = None) -> Dict[str, Any]:
"""Format tool result message for this provider's requirements."""
...

def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
"""Handle provider-specific empty response with tools logic."""
...

def get_default_settings(self) -> Dict[str, Any]:
"""Get provider-specific default settings."""
...


@runtime_checkable
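Because the renamed protocol stays runtime_checkable, adapter conformance can be spot-checked structurally; a brief sketch, assuming the import paths in this PR and that DefaultAdapter defines every protocol method (parts of the class are collapsed in this diff):

from praisonaiagents.llm.protocols import LLMProviderAdapterProtocol
from praisonaiagents.llm.adapters import DefaultAdapter, OllamaAdapter, GeminiAdapter

# runtime_checkable protocols verify that the methods exist; they do not
# validate signatures or return types, so this is a smoke test only.
for adapter in (DefaultAdapter(), OllamaAdapter(), GeminiAdapter()):
    assert isinstance(adapter, LLMProviderAdapterProtocol), type(adapter).__name__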