feat: implement LLM provider adapter pattern to eliminate scattered provider branching #1307
Changes from 1 commit
@@ -32,21 +32,65 @@ def supports_structured_output(self) -> bool:

     def supports_streaming(self) -> bool:
         return True  # Most providers support streaming

     def supports_streaming_with_tools(self) -> bool:
         return True  # Most providers support streaming with tools

     def get_max_iteration_threshold(self) -> int:
         return 10  # Conservative default

     def format_tool_result_message(self, function_name: str, tool_result: Any) -> Dict[str, Any]:
         # Standard OpenAI-style tool result message
         return {
             "role": "tool",
             "content": str(tool_result),
             "tool_call_id": getattr(tool_result, 'tool_call_id', f"call_{function_name}")
         }
Contributor

🧩 Analysis chain (scripts executed against the repository):

    #!/bin/bash
    set -euo pipefail
    # Show the adapter/protocol contract
    rg -n -C2 'def format_tool_result_message\(' src/praisonai-agents/praisonaiagents/llm
    # Show whether llm.py has the original tool_call_id available at the call site
    rg -n -C4 'format_tool_result_message\(|tool_call_id' src/praisonai-agents/praisonaiagents/llm/llm.py
    # Check if _create_tool_message exists and how it uses the hook
    rg -n -A10 'def _create_tool_message' src/praisonai-agents/praisonaiagents/llm/llm.py
    # Check if format_tool_result_message is actually called anywhere in llm.py
    rg -n 'format_tool_result_message' src/praisonai-agents/praisonaiagents/llm/llm.py
    # Show the full OllamaAdapter and GeminiAdapter format_tool_result_message implementations
    rg -n -A8 'class OllamaAdapter' src/praisonai-agents/praisonaiagents/llm/adapters/__init__.py | head -30
    rg -n -A8 'def format_tool_result_message' src/praisonai-agents/praisonaiagents/llm/adapters/__init__.py

The hook signature lacks the actual tool_call_id: getattr(tool_result, 'tool_call_id', ...) rarely produces a real id, because tool_result is typically the raw return value of the tool rather than an object carrying the call id. Update the protocol and the default adapter so the call site can pass the id explicitly.

Suggested contract update

-    def format_tool_result_message(self, function_name: str, tool_result: Any) -> Dict[str, Any]:
+    def format_tool_result_message(
+        self,
+        function_name: str,
+        tool_result: Any,
+        tool_call_id: str | None = None,
+    ) -> Dict[str, Any]:
         # Standard OpenAI-style tool result message
-        return {
+        message = {
             "role": "tool",
             "content": str(tool_result),
-            "tool_call_id": getattr(tool_result, 'tool_call_id', f"call_{function_name}")
         }
+        if tool_call_id is not None:
+            message["tool_call_id"] = tool_call_id
+        return message

Ensure any call site passes the actual tool_call_id from the originating tool call.
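A minimal standalone sketch of the suggested contract in use; the function below mirrors the proposed hook, and the call site, tool name, and ids are hypothetical:

    from typing import Any, Dict, Optional

    def format_tool_result_message(function_name: str, tool_result: Any,
                                   tool_call_id: Optional[str] = None) -> Dict[str, Any]:
        # Build the OpenAI-style tool message; attach an id only when the caller supplies one.
        message: Dict[str, Any] = {"role": "tool", "content": str(tool_result)}
        if tool_call_id is not None:
            message["tool_call_id"] = tool_call_id
        return message

    # Hypothetical call site inside the tool-execution loop: the id comes from the
    # model's tool_call object, not from the tool's return value.
    tool_call = {"id": "call_abc123", "function": {"name": "get_weather"}}
    tool_result = {"temp_c": 21}
    print(format_tool_result_message(tool_call["function"]["name"], tool_result, tool_call["id"]))
    # {'role': 'tool', 'content': "{'temp_c': 21}", 'tool_call_id': 'call_abc123'}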
     def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
         return False  # No special handling by default


 class OllamaAdapter(DefaultAdapter):
     """
     Ollama-specific provider adapter.

     Demonstrates how to extract Ollama-specific logic from llm.py's
     scattered provider dispatch into a clean adapter.
     Handles Ollama's specific quirks:
     - Doesn't support streaming with tools reliably
     - Needs tool summarization after iteration 1
     - Uses natural language tool result format
     - Handles empty responses after tool execution
     """

     def should_summarize_tools(self, iter_count: int) -> bool:
         # Replaces: OLLAMA_SUMMARY_ITERATION_THRESHOLD logic
         # Must match LLM.OLLAMA_SUMMARY_ITERATION_THRESHOLD = 1
         return iter_count >= 1

     def supports_streaming_with_tools(self) -> bool:
         # Ollama doesn't reliably support streaming with tools
         return False

     def get_max_iteration_threshold(self) -> int:
         return 1  # Ollama-specific threshold

     def format_tool_result_message(self, function_name: str, tool_result: Any) -> Dict[str, Any]:
Contributor

Update this method signature as well to accept the explicit tool_call_id parameter proposed for the protocol above, so the Ollama-specific formatting stays compatible with the updated contract.
         # Ollama uses natural language format for tool results
         return {
             "role": "user",
             "content": f"Tool '{function_name}' returned: {tool_result}"
         }

     def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
         # Handle Ollama's tendency to return empty responses after tool execution
         iteration_count = state.get('iteration_count', 0)
         has_tool_results = bool(state.get('accumulated_tool_results'))
         response_text = state.get('response_text', '').strip()

         if iteration_count >= 1 and has_tool_results and not response_text:
             return True  # Signal that special handling is needed
         return False

     def post_tool_iteration(self, state: Dict[str, Any]) -> None:
         # Replaces: Ollama-specific post-tool summary branches
         if (not state.get('response_text', '').strip() and

@@ -67,7 +111,14 @@ def supports_structured_output(self) -> bool:
 class GeminiAdapter(DefaultAdapter):
-    """Google Gemini provider adapter."""
+    """
+    Google Gemini provider adapter.
+
+    Handles Gemini's specific quirks:
+    - Has internal tools that need special formatting
+    - Doesn't support streaming with tools reliably
+    - Supports structured output
+    """

     def format_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         # Replaces: gemini_internal_tools handling in llm.py

@@ -84,6 +135,10 @@ def format_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
             formatted.append(tool)
         return formatted

+    def supports_streaming_with_tools(self) -> bool:
+        # Gemini has issues with streaming + tools
+        return False
+
     def supports_structured_output(self) -> bool:
         return True
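llm.py (below) imports get_provider_adapter from this adapters package, but the factory itself is not part of this hunk. A plausible shape, reduced to a self-contained sketch with the adapter classes trimmed to a single capability, might be:

    from typing import Dict

    class DefaultAdapter:
        def supports_streaming_with_tools(self) -> bool:
            return True

    class OllamaAdapter(DefaultAdapter):
        def supports_streaming_with_tools(self) -> bool:
            return False

    class GeminiAdapter(DefaultAdapter):
        def supports_streaming_with_tools(self) -> bool:
            return False

    _ADAPTERS: Dict[str, DefaultAdapter] = {
        "ollama": OllamaAdapter(),
        "gemini": GeminiAdapter(),
    }

    def get_provider_adapter(provider: str) -> DefaultAdapter:
        # Unknown providers fall back to the conservative defaults.
        return _ADAPTERS.get(provider, DefaultAdapter())

    adapter = get_provider_adapter("ollama")
    print(adapter.__class__.__name__, adapter.supports_streaming_with_tools())  # OllamaAdapter False

This is the core of the pattern: provider-specific behaviour is queried from one object chosen at initialization instead of branching on the provider at every call site.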
@@ -415,6 +415,9 @@ def __init__(
         # Apply Ollama-specific defaults after model is set
         self._apply_ollama_defaults()

+        # Initialize provider adapter for dispatch logic
+        self._provider_adapter = self._initialize_provider_adapter()
+
         # Token tracking
         self.last_token_metrics: Optional[TokenMetrics] = None
         self.session_token_metrics: Optional[TokenMetrics] = None

@@ -474,9 +477,63 @@ def console(self):
         self._console = _get_console()()
         return self._console
     def _detect_provider(self) -> str:
         """
         Detect provider from model name.

         Consolidates all provider detection logic into a single method
         that replaces scattered _is_X_provider() calls.

         Returns:
             Provider name (e.g., "ollama", "anthropic", "gemini", "openai")
         """
         if not self.model:
             return "default"

         model_lower = self.model.lower()

         # Direct model prefixes
         if self.model.startswith("ollama/"):
             return "ollama"

         # Check base_url for provider hints
         if self.base_url and "ollama" in self.base_url.lower():
             return "ollama"

         # Check model name patterns
         if any(prefix in model_lower for prefix in ['claude', 'anthropic/']):
             return "anthropic"

         if any(prefix in model_lower for prefix in ['gemini', 'gemini/', 'google/gemini']):
             return "gemini"

         # Check for Ollama models without prefix
         ollama_patterns = [
             'llama', 'llama2', 'llama3', 'mistral', 'mixtral', 'phi', 'vicuna',
             'wizardlm', 'orca', 'falcon', 'alpaca', 'wizard-coder', 'starcoder',
             'codellama', 'phind-codellama', 'deepseek-coder', 'magicoder',
             'qwen', 'qwen2'
         ]
         if any(pattern in model_lower for pattern in ollama_patterns):
             return "ollama"
         # Default fallback
         return "openai"
Comment on lines +480 to +523

Contributor

Tighten _detect_provider: these checks scan the whole model id for family names, so routed ids that merely contain a family name (for example, Llama or Qwen models served through another provider) get misclassified as Ollama. Prefer prefix-based matching plus explicit base_url checks.

🎯 Safer provider detection

     def _detect_provider(self) -> str:
         if not self.model:
             return "default"
         model_lower = self.model.lower()
+        provider_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else None
+
+        if provider_prefix == "ollama":
+            return "ollama"
+        if provider_prefix in {"anthropic", "claude"}:
+            return "anthropic"
+        if provider_prefix in {"gemini", "google"} and "gemini" in model_lower:
+            return "gemini"
-        # Check base_url for provider hints
-        if self.base_url and "ollama" in self.base_url.lower():
+        base_urls = [self.base_url, os.getenv("OPENAI_BASE_URL", ""), os.getenv("OPENAI_API_BASE", "")]
+        if any(url and ("ollama" in url.lower() or ":11434" in url) for url in base_urls):
             return "ollama"
-        # Check model name patterns
-        if any(prefix in model_lower for prefix in ['claude', 'anthropic/']):
+        if model_lower.startswith("claude"):
             return "anthropic"
-        if any(prefix in model_lower for prefix in ['gemini', 'gemini/', 'google/gemini']):
+        if model_lower.startswith("gemini") or model_lower.startswith("google/gemini"):
             return "gemini"
-
-        # Check for Ollama models without prefix
-        ...
         return "openai"
     def _initialize_provider_adapter(self):
         """Initialize provider adapter based on detected provider."""
         provider = self._detect_provider()

         # Import here to avoid circular imports
         from .adapters import get_provider_adapter
         adapter = get_provider_adapter(provider)

         logging.debug(f"[ADAPTER] Initialized {provider} adapter: {adapter.__class__.__name__}")
         return adapter

     def _apply_ollama_defaults(self):
         """Apply Ollama-specific defaults for tool calling reliability."""
-        if self._is_ollama_provider():
+        if self._detect_provider() == "ollama":
             # Apply defaults only if not explicitly set by user
             if not self._max_tool_repairs_explicit:
                 self.max_tool_repairs = 2
@@ -753,14 +810,17 @@ def _supports_web_fetch(self) -> bool:

     def _supports_prompt_caching(self) -> bool:
         """
-        Check if the current model supports prompt caching via LiteLLM.
+        Check if the current model supports prompt caching.

-        Prompt caching allows caching parts of prompts to reduce costs and latency.
-        Supported by OpenAI, Anthropic, Bedrock, and Deepseek.
+        Uses provider adapter to eliminate scattered provider logic.

         Returns:
             bool: True if the model supports prompt caching, False otherwise
         """
+        if hasattr(self, '_provider_adapter') and self._provider_adapter:
+            return self._provider_adapter.supports_prompt_caching()
Contributor

This refactoring introduces a regression for OpenAI and Deepseek models. The adapter check returns before the model_capabilities lookup below ever runs, so models whose caching support was previously reported from that table can now be reported incorrectly; keep the capability-table check or have the adapters delegate to it.
         # Fallback to model capabilities if adapter not initialized
         from .model_capabilities import supports_prompt_caching
         return supports_prompt_caching(self.model)
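One way to address the regression noted above, sketched with hypothetical names and a stand-in capability table (the real lookup lives in model_capabilities), is to have the adapter delegate to the per-model table rather than answer from a fixed default; this assumes the adapter is given access to the model name:

    from typing import Optional

    def supports_prompt_caching(model: str) -> bool:
        # Stand-in for the model_capabilities lookup; the real table is not shown in this diff.
        return model.startswith(("gpt-", "claude-", "deepseek"))

    class DefaultAdapter:
        def __init__(self, model: Optional[str] = None) -> None:
            self.model = model

        def supports_prompt_caching(self) -> bool:
            # Delegate to the per-model capability table instead of a hard-coded answer,
            # so models that previously reported caching support keep doing so.
            return bool(self.model) and supports_prompt_caching(self.model)

    print(DefaultAdapter("gpt-4o").supports_prompt_caching())  # True
    print(DefaultAdapter("llama3").supports_prompt_caching())  # False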
@@ -1270,12 +1330,19 @@ def _handle_ollama_sequential_logic(self, iteration_count: int, accumulated_tool
         - final_response_text: Text to use as final response (None if continuing)
         - iteration_count: Updated iteration count
         """
-        if not (self._is_ollama_provider() and iteration_count >= self.OLLAMA_SUMMARY_ITERATION_THRESHOLD):
+        # Use provider adapter to determine if tool summarization should occur
+        if hasattr(self, '_provider_adapter') and self._provider_adapter:
+            should_summarize = self._provider_adapter.should_summarize_tools(iteration_count)
+        else:
+            # Fallback to original Ollama logic if adapter not initialized
+            should_summarize = self._detect_provider() == "ollama" and iteration_count >= self.OLLAMA_SUMMARY_ITERATION_THRESHOLD
+
+        if not should_summarize:
             return False, None, iteration_count

-        # For Ollama: if we have meaningful tool results, generate summary immediately
-        # Don't wait for more iterations as Ollama tends to repeat tool calls
-        if accumulated_tool_results and iteration_count >= self.OLLAMA_SUMMARY_ITERATION_THRESHOLD:
+        # If we should summarize: if we have meaningful tool results, generate summary immediately
+        # Don't wait for more iterations as some providers tend to repeat tool calls
+        if accumulated_tool_results and should_summarize:
             # Generate summary from tool results
             tool_summary = self._generate_ollama_tool_summary(accumulated_tool_results, response_text)
             if tool_summary:
@@ -1301,42 +1368,19 @@ def _supports_streaming_tools(self) -> bool:
         """
         Check if the current provider supports streaming with tools.

-        Most providers that support tool calling also support streaming with tools,
-        but some providers (like Ollama and certain local models) require non-streaming
-        calls when tools are involved.
+        Uses provider adapter to eliminate scattered provider logic.

         Returns:
             bool: True if provider supports streaming with tools, False otherwise
         """
-        if not self.model:
-            return False
-
-        # Ollama doesn't reliably support streaming with tools
-        if self._is_ollama_provider():
-            return False
-
-        # Import the capability check function
-        from .model_capabilities import supports_streaming_with_tools
-
-        # Check if this model supports streaming with tools
-        if supports_streaming_with_tools(self.model):
-            return True
-
-        # Anthropic Claude models support streaming with tools
-        if self.model.startswith("claude-"):
-            return True
-
-        # Google Gemini models support streaming with tools
-        if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
-            return True
-
-        # Models with XML tool format support streaming with tools
-        if self._supports_xml_tool_format():
-            return True
+        # Use provider adapter for streaming with tools support
+        if hasattr(self, '_provider_adapter') and self._provider_adapter:
+            return self._provider_adapter.supports_streaming_with_tools()

-        # For other providers, default to False to be safe
-        # This ensures we make a single non-streaming call rather than risk
-        # missing tool calls or making duplicate calls
+        # Fallback to conservative default if adapter not initialized
         return False

     def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_json=None, output_pydantic=None, tools=None):
@@ -11,9 +11,9 @@
 @runtime_checkable
-class LLMProviderProtocol(Protocol):
+class LLMClientProtocol(Protocol):
     """
-    Protocol defining the interface that LLM providers must implement.
+    Protocol defining the interface that LLM clients must implement.
Contributor

Preserve the existing LLMProviderProtocol name for the public, client-facing protocol. Renaming it to LLMClientProtocol is a breaking change for any code that imports the old name; keep the original name here and give the new adapter hook protocol a distinct name instead.

♻️ Proposed naming split

-@runtime_checkable
-class LLMClientProtocol(Protocol):
+@runtime_checkable
+class LLMProviderProtocol(Protocol):
     ...

# outside this hunk, rename the adapter hook protocol as well:
@runtime_checkable
class LLMProviderAdapterProtocol(Protocol):
    ...
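An aside on the @runtime_checkable decorator used on these protocols: it enables isinstance() checks against the structural type, which is how a user-supplied client or adapter can be validated at runtime without inheriting from anything. A self-contained illustration (names hypothetical):

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class SupportsStreaming(Protocol):
        def supports_streaming(self) -> bool: ...

    class MyClient:
        def supports_streaming(self) -> bool:
            return True

    print(isinstance(MyClient(), SupportsStreaming))  # True: structural match, no inheritance needed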
     This enables switching between different LLM backends (litellm, openai,
     anthropic, local models, etc.) without modifying core agent code.

@@ -306,6 +306,22 @@ def supports_structured_output(self) -> bool:
     def supports_streaming(self) -> bool:
         """Check if provider supports streaming responses."""
         ...

+    def supports_streaming_with_tools(self) -> bool:
+        """Check if provider supports streaming with tools enabled."""
+        ...
+
+    def get_max_iteration_threshold(self) -> int:
+        """Get provider-specific maximum iteration count."""
+        ...
+
+    def format_tool_result_message(self, function_name: str, tool_result: Any) -> Dict[str, Any]:
Contributor

The protocol signature here has the same issue flagged on the adapter: format_tool_result_message should take an explicit tool_call_id argument rather than expecting the id to be recoverable from tool_result.
+        """Format tool result message for this provider's requirements."""
+        ...
+
+    def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
+        """Handle provider-specific empty response with tools logic."""
+        ...

 @runtime_checkable
The implementation of format_tool_result_message should not rely on getattr(tool_result, 'tool_call_id', ...), as tool_result is typically the raw output from the tool execution. The tool_call_id should be passed as an explicit argument to the method to ensure correct message formatting for providers that follow the OpenAI tool calling standard.
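The failure mode is easy to reproduce in isolation; the fallback id below is hypothetical:

    tool_result = 42  # raw output of some tool function: no tool_call_id attribute
    print(getattr(tool_result, 'tool_call_id', "call_get_answer"))
    # call_get_answer  (a fabricated id that will never match the id the model actually issued)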