diff --git a/.claude/settings.local.json b/.claude/settings.local.json index df8a1f11a..345e6e1b1 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -22,6 +22,8 @@ "Bash(gh release view:*)", "Bash(gh run list:*)", "Bash(gh run view:*)", + "Bash(git add:*)", + "Bash(git commit:*)", "Bash(git log:*)", "Bash(git ls-tree:*)", "Bash(git show:*)", @@ -63,6 +65,8 @@ "mcp__memory__read_graph", "mcp__memory__search_nodes", "mcp__perplexity-ask__perplexity_ask", + "mcp__plugin_context7_context7__query-docs", + "mcp__plugin_context7_context7__resolve-library-id", "mcp__sequential-thinking__sequentialthinking", "Skill(research-enforcer)", "WebFetch", diff --git a/.github/workflows/test_gaia_cli_linux.yml b/.github/workflows/test_gaia_cli_linux.yml index c8b25641d..f291881a6 100644 --- a/.github/workflows/test_gaia_cli_linux.yml +++ b/.github/workflows/test_gaia_cli_linux.yml @@ -269,7 +269,7 @@ jobs: from gaia.cli import main from gaia.version import version from gaia.logger import get_logger - from gaia.llm.llm_client import LLMClient + from gaia.llm import LLMClient print('✅ All core imports successful on Linux') except ImportError as e: print(f'❌ Import error: {e}') diff --git a/docs/spec/llm-client.mdx b/docs/spec/llm-client.mdx index 208d7224c..fb486a525 100644 --- a/docs/spec/llm-client.mdx +++ b/docs/spec/llm-client.mdx @@ -5,28 +5,50 @@ icon: "brain" --- - **Source Code:** [`src/gaia/llm/llm_client.py`](https://github.com/amd/gaia/blob/main/src/gaia/llm/llm_client.py) + **Source Code:** + - [`src/gaia/llm/__init__.py`](https://github.com/amd/gaia/blob/main/src/gaia/llm/__init__.py) - Package exports + - [`src/gaia/llm/base_client.py`](https://github.com/amd/gaia/blob/main/src/gaia/llm/base_client.py) - Abstract interface + - [`src/gaia/llm/factory.py`](https://github.com/amd/gaia/blob/main/src/gaia/llm/factory.py) - Client factory + - [`src/gaia/llm/providers/`](https://github.com/amd/gaia/blob/main/src/gaia/llm/providers/) - Provider implementations -**Component:** LLMClient -**Module:** `gaia.llm.llm_client` -**Import:** `from gaia.llm import LLMClient` +**Primary API:** `create_client()` factory function +**Module:** `gaia.llm` +**Imports:** +- `from gaia.llm import create_client` (preferred) +- `from gaia.llm import LLMClient, NotSupportedError` --- ## Overview -LLMClient provides a unified interface for generating text from multiple LLM backends (local Lemonade server, Claude API, OpenAI API). It handles connection management, retry logic, streaming responses, and performance monitoring with automatic endpoint selection and base URL normalization. +The LLM client package provides a unified interface for generating text from multiple LLM backends using a **provider pattern**. Each provider implements the abstract `LLMClient` interface, with optional methods raising `NotSupportedError` when unavailable. **Key Features:** -- Multi-backend support (local, Claude, OpenAI) -- Automatic retry with exponential backoff -- Streaming and non-streaming generation -- Performance statistics tracking -- Generation halting/interruption -- Context manager for resource cleanup +- **Factory-based client creation** with `create_client()` +- **Three providers**: Lemonade (local AMD-optimized), OpenAI, Claude +- **Abstract base class** for type safety and extensibility +- **Graceful handling** of unsupported features via `NotSupportedError` +- **Streaming and non-streaming** generation +- **Backward-compatible** `use_claude`/`use_openai` flags + +**Provider Capabilities:** + +| Method | Lemonade | OpenAI | Claude | +|--------|:--------:|:------:|:------:| +| `generate()` | ✓ | ✓ | ✓ | +| `chat()` | ✓ | ✓ | ✓ | +| `embed()` | ✓ | ✓ | ✗ | +| `vision()` | ✓ | ✗ | ✓ | +| `get_performance_stats()` | ✓ | ✗ | ✗ | +| `load_model()` | ✓ | ✗ | ✗ | +| `unload_model()` | ✓ | ✗ | ✗ | + + +Methods marked with ✗ raise `NotSupportedError` when called on that provider. + --- @@ -34,652 +56,707 @@ LLMClient provides a unified interface for generating text from multiple LLM bac ### Functional Requirements -1. **Multi-Backend Support** - - Local LLM via Lemonade server (default) - - Anthropic Claude API - - OpenAI ChatGPT API - - Automatic base URL normalization - -2. **Generation Interface** - - `generate()` - Generate text with prompt - - Streaming and non-streaming modes - - System prompt support - - Temperature and other parameters - - Messages array support for chat - -3. **Connection Management** - - Configurable timeouts (connect, read, write, pool) - - Connection pooling - - Retry logic with exponential backoff - - Connection error handling - -4. **Performance Monitoring** - - `get_performance_stats()` - Token counts, timing - - `is_generating()` - Check generation status - - `halt_generation()` - Stop current generation - -5. **Error Handling** - - Network error detection and retry - - Timeout handling - - API endpoint validation - - Clear error messages with fix suggestions +1. **Factory Pattern** + - `create_client()` factory function for client creation + - Explicit provider selection via `provider` parameter ("lemonade", "openai", "claude") + - Backward-compatible `use_claude`/`use_openai` flags + - Auto-detection of provider from flags when `provider` not specified + - Default to Lemonade provider when no flags set + +2. **Abstract Interface** + - `LLMClient` ABC defines unified interface + - `provider_name` property returns provider name + - **Required methods** (all providers must implement): + - `generate()` - Text completion + - `chat()` - Chat completion with message history + - **Optional methods** (raise `NotSupportedError` if not implemented): + - `embed()` - Generate embeddings + - `vision()` - Vision/image understanding + - `get_performance_stats()` - Performance statistics + - `load_model()` - Load a model + - `unload_model()` - Unload current model + +3. **Provider Implementations** + - **LemonadeProvider**: Full support for all methods, connects to local Lemonade server + - **OpenAIProvider**: `generate`, `chat`, `embed` only + - **ClaudeProvider**: `generate`, `chat`, `vision` only + - All providers support streaming and non-streaming modes + +4. **Error Handling** + - `NotSupportedError` raised for unsupported methods + - Clear error messages indicating provider and unsupported method + - Connection errors handled by underlying provider implementations ### Non-Functional Requirements 1. **Performance** - - Fast connection establishment (15s timeout) - - Streaming with 120s read timeout - - Efficient token counting - - Minimal overhead + - Lazy provider loading via `importlib` (load only when needed) + - Minimal overhead from abstraction layer + - Streaming support across all providers + - Default temperature of 0.1 for deterministic responses 2. **Reliability** - - Automatic retry on transient failures - - Exponential backoff (base: 1s, max: 60s) - - Configurable max retries (default: 3) - - Connection pool management + - Type safety through ABC pattern + - Graceful handling of unsupported features + - Clear error messages for provider capabilities + - Provider-specific connection management 3. **Usability** - - Simple initialization - - Sensible defaults - - Clear documentation - - Helpful error messages + - Simple factory function interface + - Backward compatibility with existing code + - Consistent API across all providers + - Clear documentation with examples --- ## API Specification -### File Location +### Package Structure ``` -src/gaia/llm/llm_client.py +src/gaia/llm/ +├── __init__.py # Package exports +├── base_client.py # Abstract LLMClient interface +├── factory.py # create_client() factory function +├── exceptions.py # NotSupportedError +├── lemonade_client.py # Low-level REST client for Lemonade +└── providers/ + ├── lemonade.py # LemonadeProvider + ├── openai_provider.py # OpenAIProvider + └── claude.py # ClaudeProvider ``` -### Public Interface +### Package Exports (`__init__.py`) ```python -from typing import Any, Dict, Iterator, List, Literal, Optional, Union -import httpx -from openai import OpenAI +from gaia.llm import create_client, LLMClient, NotSupportedError +``` + +--- -class LLMClient: +### Factory Function (`factory.py`) + +```python +def create_client( + provider: Optional[str] = None, + use_claude: bool = False, + use_openai: bool = False, + **kwargs, +) -> LLMClient: """ - Unified LLM client for local, Claude, and OpenAI backends. + Create an LLM client, auto-detecting provider from parameters. + + Args: + provider: Explicit provider name ("lemonade", "openai", or "claude"). + If not specified, auto-detected from use_claude/use_openai flags. + use_claude: If True, use Claude provider (ignored if provider is specified) + use_openai: If True, use OpenAI provider (ignored if provider is specified) + **kwargs: Provider-specific arguments (base_url, model, api_key, etc.) + + Returns: + LLMClient instance for the specified or detected provider - Usage: - # Local LLM (default) - client = LLMClient() - response = client.generate("Hello world") + Raises: + ValueError: If provider is unknown or both use_claude and use_openai are True - # Claude API - client = LLMClient(use_claude=True) - response = client.generate("Hello world") + Examples: + # Default Lemonade provider + client = create_client() - # OpenAI API - client = LLMClient(use_openai=True) - response = client.generate("Hello world") + # Explicit provider selection + client = create_client(provider="lemonade", model="Qwen2.5-0.5B-Instruct-CPU") + client = create_client(provider="openai", api_key="sk-...") + client = create_client(provider="claude", api_key="sk-ant-...") - # With custom base URL - client = LLMClient(base_url="http://remote-server:8000") + # Backward-compatible flags + client = create_client(use_claude=True, api_key="sk-ant-...") + client = create_client(use_openai=True, api_key="sk-...") - # With streaming - for chunk in client.generate("Hello", stream=True): - print(chunk, end="") + Note: + Provider defaults to "lemonade" when no flags are set. + The design maintains backward compatibility while allowing explicit provider selection. """ +``` - def __init__( - self, - use_claude: bool = False, - use_openai: bool = False, - system_prompt: Optional[str] = None, - base_url: Optional[str] = None, - claude_model: str = "claude-sonnet-4-20250514", - max_retries: int = 3, - retry_base_delay: float = 1.0, - ): - """ - Initialize the LLM client. +--- - Args: - use_claude: If True, uses Anthropic Claude API. - use_openai: If True, uses OpenAI ChatGPT API. - system_prompt: Default system prompt to use for all generation requests. - base_url: Base URL for local LLM server (defaults to LEMONADE_BASE_URL env var). - Automatically normalized to include /api/v1 suffix if needed. - claude_model: Claude model to use (e.g., "claude-sonnet-4-20250514"). - max_retries: Maximum number of retry attempts on connection errors. - retry_base_delay: Base delay in seconds for exponential backoff. +### Abstract Base Class (`base_client.py`) - Note: - - Uses local LLM server by default unless use_claude or use_openai is True. - - Context size is configured when starting the Lemonade server. - - Base URL normalization: "http://localhost:8000" -> "http://localhost:8000/api/v1" +```python +from abc import ABC, abstractmethod +from typing import Iterator, Union - Environment Variables: - LEMONADE_BASE_URL: Default base URL for local LLM server - OPENAI_API_KEY: Required when use_openai=True - """ - pass +class LLMClient(ABC): + """ + Unified LLM client interface. + Methods raise NotSupportedError if not available for this provider. + """ + + @property + @abstractmethod + def provider_name(self) -> str: + """Return the provider name for error messages.""" + ... + + @abstractmethod def generate( self, prompt: str, - model: Optional[str] = None, - endpoint: Optional[Literal["completions", "chat", "claude", "openai"]] = None, - system_prompt: Optional[str] = None, + model: str | None = None, stream: bool = False, - messages: Optional[List[Dict[str, str]]] = None, - **kwargs: Any, + **kwargs, ) -> Union[str, Iterator[str]]: """ - Generate a response from the LLM. + Generate text completion. Args: - prompt: The user prompt/query to send to the LLM. For chat endpoint, - if messages is not provided, this is treated as a pre-formatted - prompt string that already contains the full conversation. - model: The model to use (defaults to endpoint-appropriate model) - endpoint: Override the endpoint to use (completions, chat, claude, or openai) - system_prompt: System prompt to use for this specific request (overrides default) + prompt: The user prompt/query to send to the LLM + model: The model to use (defaults to provider's default model) stream: If True, returns a generator that yields chunks of the response - messages: Optional list of message dicts with 'role' and 'content' keys. - If provided, these are used directly for chat completions instead of prompt. - **kwargs: Additional parameters to pass to the API (temperature, max_tokens, etc.) + **kwargs: Additional parameters (temperature, max_tokens, etc.) Returns: If stream=False: The complete generated text as a string If stream=True: A generator yielding chunks of the response - Raises: - ConnectionError: Network or server connection issues - ValueError: Invalid parameters or configuration - Example: - # Non-streaming response = client.generate("Write a hello world program") - print(response) + """ + ... + + @abstractmethod + def chat( + self, + messages: list[dict], + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + """ + Chat completion with message history. - # Streaming - for chunk in client.generate("Write a story", stream=True): - print(chunk, end="", flush=True) + Args: + messages: List of message dicts with 'role' and 'content' keys + model: The model to use (defaults to provider's default model) + stream: If True, returns a generator that yields chunks of the response + **kwargs: Additional parameters (temperature, max_tokens, etc.) - # With messages array (proper chat history) + Returns: + If stream=False: The complete generated text as a string + If stream=True: A generator yielding chunks of the response + + Example: messages = [ - {"role": "system", "content": "You are a helpful assistant"}, {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, - {"role": "user", "content": "Tell me a joke"} + {"role": "user", "content": "How are you?"} ] - response = client.generate("", messages=messages) + response = client.chat(messages) """ - pass + ... - def get_performance_stats(self) -> Dict[str, Any]: + # Optional methods - default raises NotSupportedError + def embed(self, texts: list[str], **kwargs) -> list[list[float]]: """ - Get performance statistics from the last LLM request. + Generate embeddings for texts. + + Args: + texts: List of text strings to embed + **kwargs: Additional parameters (e.g., model="text-embedding-3-small" for OpenAI) Returns: - Dictionary containing performance statistics: - - time_to_first_token: Time in seconds until first token is generated - - tokens_per_second: Rate of token generation - - input_tokens: Number of tokens in the input - - output_tokens: Number of tokens in the output + List of embedding vectors (list of floats) - Note: - Only available for local LLM server. Returns empty dict for API backends. + Raises: + NotSupportedError: If provider doesn't support embeddings - Example: - >>> response = client.generate("Hello") - >>> stats = client.get_performance_stats() - >>> print(f"Speed: {stats['tokens_per_second']:.1f} tokens/sec") - Speed: 45.3 tokens/sec + Note: + Supported by: Lemonade, OpenAI (default: "text-embedding-3-small") + Not supported by: Claude """ - pass + raise NotSupportedError(self.provider_name, "embed") - def is_generating(self) -> bool: + def vision(self, images: list[bytes], prompt: str, **kwargs) -> str: """ - Check if the local LLM is currently generating. + Vision/image understanding. + + Args: + images: List of image data as bytes + prompt: Text prompt describing what to analyze + **kwargs: Additional parameters Returns: - True if generating, False otherwise + Text response describing the image - Note: - Only available when using local LLM (use_local=True). - Returns False for OpenAI/Claude API usage. + Raises: + NotSupportedError: If provider doesn't support vision - Example: - >>> client.is_generating() - False - >>> # Start generation in background thread - >>> client.is_generating() - True + Note: + Supported by: Lemonade, Claude + Not supported by: OpenAI """ - pass + raise NotSupportedError(self.provider_name, "vision") - def halt_generation(self) -> bool: + def get_performance_stats(self) -> dict: """ - Halt current generation on the local LLM server. + Get performance statistics from the last LLM request. Returns: - True if halt was successful, False otherwise + Dictionary containing performance statistics - Note: - Only available when using local LLM (use_local=True). - Does nothing for OpenAI/Claude API usage. + Raises: + NotSupportedError: If provider doesn't support performance stats - Example: - >>> if client.is_generating(): - ... client.halt_generation() - ... print("Generation stopped") - Generation stopped + Note: + Only supported by: Lemonade """ - pass + raise NotSupportedError(self.provider_name, "get_performance_stats") - def _retry_with_exponential_backoff( - self, - func: Callable[..., T], - *args, - **kwargs, - ) -> T: + def load_model(self, model_name: str, **kwargs) -> None: """ - Execute a function with exponential backoff retry on connection errors. + Load a specific model. Args: - func: The function to execute - *args: Positional arguments for the function - **kwargs: Keyword arguments for the function - - Returns: - The result of the function call + model_name: Name of the model to load + **kwargs: Additional parameters Raises: - The last exception if all retries are exhausted + NotSupportedError: If provider doesn't support model loading Note: - - Base delay: 1.0 seconds (configurable) - - Exponential base: 2.0 - - Max delay: 60.0 seconds - - Retries on: ConnectionError, httpx errors, requests errors + Only supported by: Lemonade """ - pass + raise NotSupportedError(self.provider_name, "load_model") - def _clean_claude_response(self, response: str) -> str: + def unload_model(self) -> None: """ - Extract valid JSON from Claude responses that may contain extra content. - - Args: - response: The raw response from Claude API + Unload the current model. - Returns: - Cleaned response with only the JSON portion (if JSON detected) + Raises: + NotSupportedError: If provider doesn't support model unloading Note: - Claude sometimes returns valid JSON followed by additional text. - This method extracts just the JSON part by matching braces. + Only supported by: Lemonade """ - pass + raise NotSupportedError(self.provider_name, "unload_model") ``` --- -## Implementation Details +### NotSupportedError (`exceptions.py`) -### Connection Configuration - -**Local LLM (Lemonade Server):** ```python -self.client = OpenAI( - base_url=base_url, # Default: http://localhost:8000/api/v1 - api_key="None", # Not needed for local server - timeout=httpx.Timeout( - connect=15.0, # 15 seconds to establish connection - read=120.0, # 120 seconds between data chunks (matches Lemonade) - write=15.0, # 15 seconds to send request - pool=15.0, # 15 seconds to acquire connection from pool - ), - max_retries=0, # Disable built-in retries (use custom retry logic) -) -``` +class NotSupportedError(Exception): + """Raised when a provider doesn't support a method.""" -**Claude API:** -```python -from gaia.eval.claude import ClaudeClient -self.claude_client = ClaudeClient(model=claude_model) + def __init__(self, provider: str, method: str): + super().__init__(f"{provider} does not support {method}") ``` -**OpenAI API:** +--- + +### Provider Implementations + +#### LemonadeProvider (`providers/lemonade.py`) + +**Full feature support** - implements all methods. + ```python -self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +class LemonadeProvider(LLMClient): + """Lemonade provider - local AMD-optimized inference.""" + + def __init__( + self, + model: Optional[str] = None, + base_url: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + system_prompt: Optional[str] = None, + **kwargs, + ): + """ + Initialize Lemonade provider. + + Args: + model: Model name (defaults to "Qwen2.5-0.5B-Instruct-CPU") + base_url: Base URL for Lemonade server (overrides LEMONADE_BASE_URL env var) + host: Server host (alternative to base_url) + port: Server port (alternative to base_url) + system_prompt: Default system prompt for chat + **kwargs: Additional arguments passed to LemonadeClient + + Environment: + LEMONADE_BASE_URL: Default base URL (http://localhost:8000/api/v1) + LEMONADE_MODEL: Default model name if not specified + + Note: + Default model is "Qwen2.5-0.5B-Instruct-CPU" for CPU-only inference. + All methods use temperature=0.1 by default for deterministic responses. + """ + + # Supports all methods: generate, chat, embed, vision, + # get_performance_stats, load_model, unload_model ``` -### Base URL Normalization +#### OpenAIProvider (`providers/openai_provider.py`) + +**Partial support** - `generate`, `chat`, `embed` only. ```python -# Normalize base_url to ensure it has the /api/v1 suffix -if base_url and not base_url.endswith("/api/v1"): - base_url = base_url.rstrip("/") - from urllib.parse import urlparse - parsed = urlparse(base_url) - # Only add /api/v1 if path is empty or just "/" - if not parsed.path or parsed.path == "/": - base_url = f"{base_url}/api/v1" +class OpenAIProvider(LLMClient): + """OpenAI (OpenAI API) provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-4o", + system_prompt: Optional[str] = None, + **_kwargs, + ): + """ + Initialize OpenAI provider. + + Args: + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + model: Model name (default: "gpt-4o") + system_prompt: Default system prompt for chat + + Environment: + OPENAI_API_KEY: API key for OpenAI + """ + + # Supports: generate, chat, embed + # Raises NotSupportedError: vision, get_performance_stats, load_model, unload_model ``` -### Retry Logic +#### ClaudeProvider (`providers/claude.py`) + +**Partial support** - `generate`, `chat`, `vision` only. ```python -def _retry_with_exponential_backoff(self, func, *args, **kwargs): - delay = self.retry_base_delay # 1.0 seconds - max_delay = 60.0 - exponential_base = 2.0 - - for attempt in range(self.max_retries + 1): - try: - return func(*args, **kwargs) - except (ConnectionError, httpx.ConnectError, httpx.TimeoutException, - httpx.NetworkError, requests.exceptions.ConnectionError, - requests.exceptions.Timeout) as e: - if attempt == self.max_retries: - raise - - wait_time = min(delay, max_delay) - logger.warning( - f"Connection error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. " - f"Retrying in {wait_time:.1f}s..." - ) - time.sleep(wait_time) - delay *= exponential_base +class ClaudeProvider(LLMClient): + """Claude (Anthropic) provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "claude-3-5-sonnet-20241022", + system_prompt: Optional[str] = None, + **_kwargs, + ): + """ + Initialize Claude provider. + + Args: + api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) + model: Model name (default: "claude-3-5-sonnet-20241022") + system_prompt: Default system prompt for chat + + Environment: + ANTHROPIC_API_KEY: API key for Anthropic Claude + + Raises: + ImportError: If anthropic package not installed + """ + + # Supports: generate, chat, vision + # Raises NotSupportedError: embed, get_performance_stats, load_model, unload_model ``` -### Endpoint Selection +--- + +## Implementation Details + +### Provider Selection Logic + +The factory function auto-detects the provider based on parameters: ```python -# Completions endpoint (pre-formatted prompts, ChatSDK compatibility) -if endpoint_to_use == "completions": - response = self.client.completions.create( - model=model, - prompt=prompt, # Full formatted conversation - temperature=0.1, - stream=stream, - **kwargs, - ) - -# Chat endpoint (proper message history) -elif endpoint_to_use == "chat": - chat_messages = messages or [{"role": "user", "content": prompt}] - if effective_system_prompt: - chat_messages.insert(0, {"role": "system", "content": effective_system_prompt}) - - response = self.client.chat.completions.create( - model=model, - messages=chat_messages, - temperature=0.1, - stream=stream, - **kwargs, - ) +# From factory.py +def create_client(provider=None, use_claude=False, use_openai=False, **kwargs): + # Auto-detect provider from flags if not explicitly specified + if provider is None: + if use_claude and use_openai: + raise ValueError("Cannot specify both use_claude and use_openai") + elif use_claude: + provider = "claude" + elif use_openai: + provider = "openai" + else: + provider = "lemonade" # Default + + # Validate provider + if provider.lower() not in _PROVIDERS: + available = ", ".join(_PROVIDERS.keys()) + raise ValueError(f"Unknown provider: {provider}. Available: {available}") + + # Load provider class dynamically... ``` -### Error Handling +### Lazy Provider Loading + +Providers are loaded dynamically using `importlib` to avoid importing unnecessary dependencies: ```python -try: - response = self._retry_with_exponential_backoff( - self.client.completions.create, - model=model, - prompt=prompt, - temperature=0.1, - stream=stream, - **kwargs, - ) -except httpx.ConnectError as e: - error_msg = f"LLM Server Connection Error: {str(e)}" - raise ConnectionError(error_msg) from e -except Exception as e: - error_str = str(e) - if "404" in error_str: - if "endpoint" in error_str.lower() or "not found" in error_str.lower(): - raise ConnectionError( - f"API endpoint error: {error_str}\n\n" - f"This may indicate:\n" - f" 1. Lemonade Server version mismatch (try updating to {LEMONADE_VERSION})\n" - f" 2. Model not properly loaded or corrupted\n\n" - f"To fix model issues, try:\n" - f" lemonade model remove \n" - f" lemonade model download \n" - ) from e - raise +_PROVIDERS = { + "lemonade": "gaia.llm.providers.lemonade.LemonadeProvider", + "openai": "gaia.llm.providers.openai_provider.OpenAIProvider", + "claude": "gaia.llm.providers.claude.ClaudeProvider", +} + +# Lazy import - only load when needed +module_path, class_name = _PROVIDERS[provider_lower].rsplit(".", 1) +module = importlib.import_module(module_path) +provider_class = getattr(module, class_name) + +return provider_class(**kwargs) ``` ---- +### NotSupportedError Pattern -## Testing Requirements +Optional methods raise `NotSupportedError` by default in the ABC: -### Unit Tests +```python +# In base_client.py +class LLMClient(ABC): + def embed(self, texts: list[str], **kwargs) -> list[list[float]]: + raise NotSupportedError(self.provider_name, "embed") + + def vision(self, images: list[bytes], prompt: str, **kwargs) -> str: + raise NotSupportedError(self.provider_name, "vision") + # etc... +``` -**File:** `tests/llm/test_llm_client.py` +Providers override only the methods they support: ```python -import pytest -from unittest.mock import Mock, patch -from gaia.llm import LLMClient - -def test_llm_client_can_be_imported(): - """Verify LLMClient can be imported.""" - from gaia.llm import LLMClient - assert LLMClient is not None - -def test_initialize_local_llm(): - """Test local LLM initialization.""" - client = LLMClient() - assert client.use_claude is False - assert client.use_openai is False - assert client.base_url.endswith("/api/v1") - assert client.endpoint == "completions" - -def test_initialize_with_custom_base_url(): - """Test base URL normalization.""" - # Without /api/v1 - client = LLMClient(base_url="http://localhost:8000") - assert client.base_url == "http://localhost:8000/api/v1" - - # With /api/v1 - client = LLMClient(base_url="http://localhost:8000/api/v1") - assert client.base_url == "http://localhost:8000/api/v1" - - # With trailing slash - client = LLMClient(base_url="http://localhost:8000/") - assert client.base_url == "http://localhost:8000/api/v1" - -def test_initialize_claude(): - """Test Claude API initialization.""" - with patch('gaia.llm.llm_client.CLAUDE_AVAILABLE', True): - with patch('gaia.llm.llm_client.AnthropicClaudeClient'): - client = LLMClient(use_claude=True) - assert client.use_claude is True - assert client.endpoint == "claude" - assert client.default_model.startswith("claude-") - -def test_initialize_openai(): - """Test OpenAI API initialization.""" - with patch.dict('os.environ', {'OPENAI_API_KEY': 'test-key'}): - client = LLMClient(use_openai=True) - assert client.use_openai is True - assert client.endpoint == "openai" - assert client.default_model == "gpt-4o" - -def test_generate_non_streaming(): - """Test non-streaming generation.""" - client = LLMClient() - - # Mock the OpenAI client - mock_response = Mock() - mock_response.choices = [Mock(text="Hello world")] - client.client.completions.create = Mock(return_value=mock_response) - - response = client.generate("Test prompt") - assert response == "Hello world" - assert client.client.completions.create.called - -def test_generate_streaming(): - """Test streaming generation.""" - client = LLMClient() - - # Mock streaming response - def mock_stream(): - for chunk in ["Hello", " ", "world"]: - mock_chunk = Mock() - mock_chunk.choices = [Mock(text=chunk)] - yield mock_chunk - - client.client.completions.create = Mock(return_value=mock_stream()) - - result = list(client.generate("Test prompt", stream=True)) - assert result == ["Hello", " ", "world"] - -def test_generate_with_messages(): - """Test generation with messages array.""" - client = LLMClient() - - mock_response = Mock() - mock_response.choices = [Mock(message=Mock(content="Response"))] - client.client.chat.completions.create = Mock(return_value=mock_response) - - messages = [ - {"role": "user", "content": "Hello"} - ] - response = client.generate("", endpoint="chat", messages=messages) - assert response == "Response" - -def test_retry_logic(): - """Test exponential backoff retry.""" - client = LLMClient(max_retries=2, retry_base_delay=0.1) - - # Mock function that fails twice then succeeds - mock_func = Mock(side_effect=[ - ConnectionError("Failed"), - ConnectionError("Failed"), - "Success" - ]) - - result = client._retry_with_exponential_backoff(mock_func) - assert result == "Success" - assert mock_func.call_count == 3 - -def test_retry_exhausted(): - """Test retry exhaustion.""" - client = LLMClient(max_retries=1, retry_base_delay=0.1) - - mock_func = Mock(side_effect=ConnectionError("Always fails")) - - with pytest.raises(ConnectionError): - client._retry_with_exponential_backoff(mock_func) - - assert mock_func.call_count == 2 # Initial + 1 retry - -def test_get_performance_stats(): - """Test performance stats retrieval.""" - client = LLMClient() - - with patch('requests.get') as mock_get: - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = { - "time_to_first_token": 0.5, - "tokens_per_second": 45.3, - "input_tokens": 10, - "output_tokens": 20 - } - - stats = client.get_performance_stats() - assert stats["time_to_first_token"] == 0.5 - assert stats["tokens_per_second"] == 45.3 +# OpenAIProvider overrides embed but not vision +class OpenAIProvider(LLMClient): + def embed(self, texts: list[str], **kwargs): + # Implementation for OpenAI embeddings + response = self._client.embeddings.create(...) + return [item.embedding for item in response.data] + + # vision() inherited - raises NotSupportedError +``` -def test_is_generating(): - """Test generation status check.""" - client = LLMClient() +### Temperature Defaults - with patch('requests.get') as mock_get: - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {"is_generating": True} +All providers default to `temperature=0.1` for deterministic responses: - assert client.is_generating() is True +```python +# In LemonadeProvider +kwargs.setdefault("temperature", 0.1) +response = self._backend.completions(model=model, prompt=prompt, **kwargs) +``` -def test_halt_generation(): - """Test generation halting.""" - client = LLMClient() +### Provider-Specific Implementation - with patch('requests.get') as mock_get: - mock_get.return_value.status_code = 200 +**LemonadeProvider** wraps the low-level `LemonadeClient`: - assert client.halt_generation() is True +```python +class LemonadeProvider(LLMClient): + def __init__(self, model=None, base_url=None, **kwargs): + self._backend = LemonadeClient(model=model, base_url=base_url, **kwargs) -def test_clean_claude_response(): - """Test Claude response cleaning.""" - client = LLMClient() + def generate(self, prompt, model=None, stream=False, **kwargs): + return self._backend.completions(prompt=prompt, stream=stream, **kwargs) +``` - # Valid JSON with extra text - response = '{"result": "success"} Some extra text after' - cleaned = client._clean_claude_response(response) - assert cleaned == '{"result": "success"}' +**OpenAIProvider** uses the OpenAI SDK directly: - # Plain text (no JSON) - response = "Just plain text" - cleaned = client._clean_claude_response(response) - assert cleaned == "Just plain text" +```python +class OpenAIProvider(LLMClient): + def __init__(self, api_key=None, model="gpt-4o", **kwargs): + import openai + self._client = openai.OpenAI(api_key=api_key) + self._model = model +``` -def test_system_prompt(): - """Test system prompt handling.""" - system_prompt = "You are a helpful assistant." - client = LLMClient(system_prompt=system_prompt) - assert client.system_prompt == system_prompt +**ClaudeProvider** uses the Anthropic SDK: - # Override in generate() - mock_response = Mock() - mock_response.choices = [Mock(text="Response")] - client.client.completions.create = Mock(return_value=mock_response) +```python +class ClaudeProvider(LLMClient): + def __init__(self, api_key=None, model="claude-3-5-sonnet-20241022", **kwargs): + import anthropic + self._client = anthropic.Anthropic(api_key=api_key) + self._model = model +``` - client.generate("Test", system_prompt="Different prompt") - # Verify different prompt was used (would need more sophisticated mocking) +--- -def test_error_handling_404(): - """Test 404 error handling with helpful message.""" - client = LLMClient() +## Testing Requirements - client.client.completions.create = Mock( - side_effect=Exception("404 endpoint not found") - ) +### Unit Tests - with pytest.raises(ConnectionError) as exc_info: - client.generate("Test") +**File:** `tests/unit/test_llm_client_factory.py` - assert "Lemonade Server version mismatch" in str(exc_info.value) - assert "lemonade model remove" in str(exc_info.value) +```python +import pytest +from unittest.mock import patch, Mock +from gaia.llm import create_client, LLMClient, NotSupportedError + +class TestImports: + def test_can_import_create_client(self): + """Verify create_client can be imported.""" + from gaia.llm import create_client + assert callable(create_client) + + def test_can_import_llm_client_abc(self): + """Verify LLMClient ABC can be imported.""" + from abc import ABC + from gaia.llm import LLMClient + assert issubclass(LLMClient, ABC) + + def test_can_import_not_supported_error(self): + """Verify NotSupportedError can be imported.""" + from gaia.llm import NotSupportedError + assert issubclass(NotSupportedError, Exception) + +class TestFactory: + def test_default_creates_lemonade_provider(self): + """Test factory creates Lemonade by default.""" + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + client = create_client() + assert client.provider_name == "Lemonade" + + def test_explicit_provider_selection(self): + """Test explicit provider parameter.""" + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + client = create_client(provider="lemonade") + assert client.provider_name == "Lemonade" + + def test_use_claude_flag(self): + """Test backward-compatible use_claude flag.""" + with patch("gaia.llm.providers.claude.anthropic"): + client = create_client(use_claude=True, api_key="test") + assert client.provider_name == "Claude" + + def test_use_openai_flag(self): + """Test backward-compatible use_openai flag.""" + with patch("openai.OpenAI"): + client = create_client(use_openai=True, api_key="test") + assert client.provider_name == "OpenAI" + + def test_invalid_provider_raises_error(self): + """Test unknown provider raises ValueError.""" + with pytest.raises(ValueError, match="Unknown provider"): + create_client(provider="invalid") + + def test_both_flags_raises_error(self): + """Test both flags raise ValueError.""" + with pytest.raises(ValueError, match="Cannot specify both"): + create_client(use_claude=True, use_openai=True) + + def test_case_insensitive_provider(self): + """Test provider names are case-insensitive.""" + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + client = create_client(provider="LEMONADE") + assert client.provider_name == "Lemonade" + +class TestNotSupportedError: + def test_error_message_format(self): + """Test NotSupportedError message format.""" + error = NotSupportedError("TestProvider", "test_method") + assert "TestProvider" in str(error) + assert "test_method" in str(error) + + def test_claude_embed_raises_not_supported(self): + """Test Claude provider raises NotSupportedError for embed.""" + with patch("gaia.llm.providers.claude.anthropic"): + client = create_client(provider="claude", api_key="test") + with pytest.raises(NotSupportedError) as exc: + client.embed(["text"]) + assert "Claude" in str(exc.value) + assert "embed" in str(exc.value) + + def test_openai_vision_raises_not_supported(self): + """Test OpenAI provider raises NotSupportedError for vision.""" + with patch("openai.OpenAI"): + client = create_client(provider="openai", api_key="test") + with pytest.raises(NotSupportedError) as exc: + client.vision([b"image"], "describe") + assert "OpenAI" in str(exc.value) + assert "vision" in str(exc.value) + +class TestProviderMethods: + def test_lemonade_generate(self): + """Test LemonadeProvider.generate().""" + with patch("gaia.llm.providers.lemonade.LemonadeClient") as MockClient: + mock_backend = Mock() + mock_backend.completions.return_value = { + "choices": [{"text": "Hello"}] + } + MockClient.return_value = mock_backend + + client = create_client() + response = client.generate("Test") + assert response == "Hello" + + def test_lemonade_chat(self): + """Test LemonadeProvider.chat().""" + with patch("gaia.llm.providers.lemonade.LemonadeClient") as MockClient: + mock_backend = Mock() + mock_backend.chat_completions.return_value = { + "choices": [{"message": {"content": "Hi"}}] + } + MockClient.return_value = mock_backend + + client = create_client() + response = client.chat([{"role": "user", "content": "Hello"}]) + assert response == "Hi" + + def test_streaming_returns_iterator(self): + """Test streaming returns an iterator.""" + with patch("gaia.llm.providers.lemonade.LemonadeClient") as MockClient: + mock_backend = Mock() + + def mock_stream(): + yield {"choices": [{"delta": {"content": "Hi"}}]} + yield {"choices": [{"delta": {"content": " there"}}]} + + mock_backend.chat_completions.return_value = mock_stream() + MockClient.return_value = mock_backend + + client = create_client() + result = client.chat([{"role": "user", "content": "Hello"}], stream=True) + chunks = list(result) + assert len(chunks) == 2 + assert "".join(chunks) == "Hi there" ``` ### Integration Tests -```python -def test_integration_local_llm(): - """Test integration with local Lemonade server.""" - client = LLMClient() +**File:** `tests/integration/test_llm_providers.py` +```python +def test_integration_lemonade_generate(): + """Test Lemonade provider with live server.""" try: + client = create_client() response = client.generate("Say hello") assert isinstance(response, str) assert len(response) > 0 except ConnectionError: pytest.skip("Lemonade server not running") -def test_integration_streaming(): - """Test streaming integration.""" - client = LLMClient() - +def test_integration_lemonade_streaming(): + """Test Lemonade streaming.""" try: - chunks = [] - for chunk in client.generate("Count to 3", stream=True): - chunks.append(chunk) - + client = create_client() + chunks = list(client.generate("Count to 3", stream=True)) assert len(chunks) > 0 - full_response = "".join(chunks) - assert len(full_response) > 0 + except ConnectionError: + pytest.skip("Lemonade server not running") + +def test_integration_lemonade_performance_stats(): + """Test performance stats (Lemonade only).""" + try: + client = create_client() + client.generate("Test") + stats = client.get_performance_stats() + assert isinstance(stats, dict) except ConnectionError: pytest.skip("Lemonade server not running") ``` @@ -706,151 +783,218 @@ claude = ["anthropic>=0.18.0"] # Claude API support ### Import Dependencies +**Factory (`factory.py`):** +```python +import importlib +from typing import Optional +from .base_client import LLMClient +``` + +**Base Client (`base_client.py`):** +```python +from abc import ABC, abstractmethod +from typing import Iterator, Union +from .exceptions import NotSupportedError +``` + +**LemonadeProvider (`providers/lemonade.py`):** ```python -import logging -import os -import time -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, TypeVar, Union +from typing import Iterator, Optional, Union +from ..base_client import LLMClient +from ..lemonade_client import LemonadeClient, DEFAULT_MODEL_NAME +# DEFAULT_MODEL_NAME = "Qwen2.5-0.5B-Instruct-CPU" +``` -import httpx -import requests -from dotenv import load_dotenv -from openai import OpenAI +**OpenAIProvider (`providers/openai_provider.py`):** +```python +from typing import Iterator, Optional, Union +import openai # Requires: pip install openai +from ..base_client import LLMClient +``` -# Conditional Claude import +**ClaudeProvider (`providers/claude.py`):** +```python +from typing import Iterator, Optional, Union try: - from gaia.eval.claude import ClaudeClient as AnthropicClaudeClient - CLAUDE_AVAILABLE = True + import anthropic # Requires: pip install anthropic except ImportError: - CLAUDE_AVAILABLE = False + anthropic = None +from ..base_client import LLMClient ``` --- ## Usage Examples -### Example 1: Basic Local LLM +### Example 1: Basic Usage with Factory ```python -from gaia.llm import LLMClient +from gaia.llm import create_client -# Initialize with local Lemonade server -client = LLMClient() +# Default Lemonade provider +client = create_client() # Non-streaming generation response = client.generate("Write a hello world program in Python") print(response) -# Get performance stats +# Get performance stats (Lemonade only) stats = client.get_performance_stats() -print(f"Speed: {stats['tokens_per_second']:.1f} tokens/sec") +print(f"Speed: {stats.get('tokens_per_second', 'N/A')} tokens/sec") ``` -### Example 2: Streaming Responses +### Example 2: Explicit Provider Selection ```python -from gaia.llm import LLMClient +from gaia.llm import create_client -client = LLMClient() +# Lemonade (local) +lemonade = create_client(provider="lemonade", model="Qwen2.5-0.5B-Instruct-CPU") -# Streaming generation -print("AI: ", end="", flush=True) -for chunk in client.generate("Tell me a short story", stream=True): - print(chunk, end="", flush=True) -print() +# OpenAI +openai_client = create_client(provider="openai", api_key="sk-...") + +# Claude +claude = create_client(provider="claude", api_key="sk-ant-...") + +# Backward-compatible flags +legacy_claude = create_client(use_claude=True, api_key="sk-ant-...") ``` -### Example 3: Using Claude API +### Example 3: Handling Unsupported Features ```python -from gaia.llm import LLMClient +from gaia.llm import create_client, NotSupportedError -# Initialize with Claude -client = LLMClient( - use_claude=True, - claude_model="claude-sonnet-4-20250514", - system_prompt="You are a helpful coding assistant." -) +# Create OpenAI client +client = create_client(provider="openai", api_key="sk-...") -# Generate code -response = client.generate("Write a binary search function") -print(response) +# This works (OpenAI supports embed) +embeddings = client.embed(["Hello world", "How are you?"]) + +# This raises NotSupportedError (OpenAI doesn't support vision) +try: + result = client.vision([image_bytes], "Describe this image") +except NotSupportedError as e: + print(f"Feature not available: {e}") + # Output: "OpenAI does not support vision" ``` ### Example 4: Chat with Message History ```python -from gaia.llm import LLMClient +from gaia.llm import create_client -client = LLMClient() +client = create_client() # Build conversation history messages = [ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's 2+2?"}, {"role": "assistant", "content": "2+2 equals 4."}, {"role": "user", "content": "What about 3+3?"} ] -# Generate with full context -response = client.generate("", endpoint="chat", messages=messages) +# Use chat() method +response = client.chat(messages) print(response) # "3+3 equals 6." ``` -### Example 5: Halting Generation +### Example 5: Streaming Responses ```python -from gaia.llm import LLMClient -import threading -import time +from gaia.llm import create_client + +client = create_client() + +# Streaming with generate() +print("AI: ", end="", flush=True) +for chunk in client.generate("Tell me a short story", stream=True): + print(chunk, end="", flush=True) +print() + +# Streaming with chat() +for chunk in client.chat([{"role": "user", "content": "Hello"}], stream=True): + print(chunk, end="", flush=True) +``` -client = LLMClient() +### Example 6: Embeddings (Lemonade and OpenAI only) -def generate_long_text(): - """Generate in background thread.""" - response = client.generate("Write a very long essay about AI") - print(response) +```python +from gaia.llm import create_client + +# With Lemonade +lemonade = create_client() +embeddings = lemonade.embed(["Hello world", "How are you?"]) +print(f"Embedding dimensions: {len(embeddings[0])}") -# Start generation in background -thread = threading.Thread(target=generate_long_text) -thread.start() +# With OpenAI +openai_client = create_client(provider="openai", api_key="sk-...") +embeddings = openai_client.embed(["Text to embed"]) +``` -# Wait a bit, then halt -time.sleep(2) -if client.is_generating(): - client.halt_generation() - print("Generation stopped!") +### Example 7: Vision (Lemonade and Claude only) -thread.join() +```python +from gaia.llm import create_client + +# With Claude +claude = create_client(provider="claude", api_key="sk-ant-...") +with open("image.jpg", "rb") as f: + image_data = f.read() +description = claude.vision([image_data], "Describe what you see") +print(description) ``` -### Example 6: Custom Retry Configuration +### Example 8: Model Management (Lemonade only) ```python -from gaia.llm import LLMClient +from gaia.llm import create_client -# Configure aggressive retry -client = LLMClient( - max_retries=5, - retry_base_delay=0.5, # Start with 0.5s delay -) +client = create_client() + +# Load a specific model +client.load_model("Qwen2.5-0.5B-Instruct-CPU") -# Will retry up to 5 times with exponential backoff +# Generate response = client.generate("Hello") + +# Get performance stats +stats = client.get_performance_stats() +print(f"Speed: {stats.get('tokens_per_second', 'N/A')} tokens/sec") + +# Unload model +client.unload_model() ``` -### Example 7: Remote Lemonade Server +### Example 9: Remote Lemonade Server ```python -from gaia.llm import LLMClient +from gaia.llm import create_client # Connect to remote server -client = LLMClient(base_url="http://192.168.1.100:8000") +client = create_client(base_url="http://192.168.1.100:8000") response = client.generate("Hello from remote server") print(response) ``` +### Example 10: System Prompts + +```python +from gaia.llm import create_client + +# Set default system prompt +client = create_client( + system_prompt="You are a helpful coding assistant." +) + +# System prompt automatically prepended to chat messages +response = client.chat([ + {"role": "user", "content": "Write a binary search function"} +]) +print(response) +``` + --- ## Third-Party LLM Integration @@ -995,16 +1139,16 @@ Your LLM service must implement at least one of these OpenAI-compatible endpoint ```python - from gaia.llm import LLMClient + from gaia.llm import create_client - client = LLMClient(base_url="http://your-llm-service:8080") + client = create_client(base_url="http://your-llm-service:8080") response = client.generate("Hello world") ``` -**URL Normalization:** GAIA automatically appends `/api/v1` if not present: +**URL Normalization:** LemonadeClient automatically appends `/api/v1` if not present: - `http://localhost:8080` → `http://localhost:8080/api/v1` - If your service uses `/v1` instead, provide the full path: `http://localhost:8080/v1` @@ -1015,10 +1159,10 @@ Your LLM service must implement at least one of these OpenAI-compatible endpoint ```python Basic Connection -from gaia.llm import LLMClient +from gaia.llm import create_client # Connect to your third-party LLM service -client = LLMClient(base_url="http://your-service:8080/v1") +client = create_client(base_url="http://your-service:8080/v1") # Test connection response = client.generate("Hello, are you working?") @@ -1026,9 +1170,9 @@ print(response) ``` ```python Streaming -from gaia.llm import LLMClient +from gaia.llm import create_client -client = LLMClient(base_url="http://your-service:8080/v1") +client = create_client(base_url="http://your-service:8080/v1") # Stream response chunks for chunk in client.generate("Tell me a story", stream=True): @@ -1036,7 +1180,7 @@ for chunk in client.generate("Tell me a story", stream=True): ``` ```python With Agent -from gaia.llm import LLMClient +from gaia.llm import create_client from gaia.agents import Agent class CustomAgent(Agent): @@ -1076,10 +1220,10 @@ result = agent.process_task("Analyze this code...") - The following features are specific to Lemonade Server and will not work with third-party services: - - `get_performance_stats()` - Returns empty dict `{}` - - `is_generating()` - Returns `False` - - `halt_generation()` - Returns `False` + The following features are specific to Lemonade provider and raise `NotSupportedError` with third-party services: + - `get_performance_stats()` - Performance statistics + - `load_model()` - Model loading + - `unload_model()` - Model unloading @@ -1098,9 +1242,9 @@ result = agent.process_task("Analyze this code...") ``` 2. Check firewall settings 3. Ensure correct base URL format - 4. Test with explicit endpoint: + 4. Test with explicit base URL: ```python - client = LLMClient(base_url="http://localhost:8080/v1") + client = create_client(base_url="http://localhost:8080/v1") ``` @@ -1111,9 +1255,9 @@ result = agent.process_task("Analyze this code...") 1. Check if service uses `/v1/completions` (OpenAI standard) 2. Verify API path structure: `/v1` vs `/api/v1` 3. Consult service documentation for correct endpoint paths - 4. Use explicit endpoint override: + 4. Use chat method explicitly if needed: ```python - client.generate("Test", endpoint="chat") # Force chat endpoint + client.chat([{"role": "user", "content": "Test"}]) ``` @@ -1159,25 +1303,26 @@ result = agent.process_task("Analyze this code...") Add to LLM Section: ```markdown -### LLMClient +### LLM Client -**Import:** `from gaia.llm import LLMClient` +**Import:** `from gaia.llm import create_client, LLMClient, NotSupportedError` -**Purpose:** Unified interface for LLM generation across local, Claude, and OpenAI backends. +**Purpose:** Provider-based LLM client with factory pattern for local and cloud backends. **Features:** -- Multi-backend support (local Lemonade, Claude, OpenAI) +- Factory-based client creation with `create_client()` +- Three providers: Lemonade (local), OpenAI, Claude +- Abstract base class for type safety +- `NotSupportedError` for unsupported features - Streaming and non-streaming generation -- Automatic retry with exponential backoff -- Performance monitoring -- Generation control (halt/resume) +- Backward-compatible flags **Quick Start:** ```python -from gaia.llm import LLMClient +from gaia.llm import create_client -# Local LLM -client = LLMClient() +# Local LLM (default) +client = create_client() response = client.generate("Hello world") # Streaming @@ -1185,8 +1330,11 @@ for chunk in client.generate("Tell me a story", stream=True): print(chunk, end="") # Claude API -client = LLMClient(use_claude=True) -response = client.generate("Explain Python decorators") +claude = create_client(provider="claude", api_key="sk-ant-...") +response = claude.generate("Explain Python decorators") + +# Backward-compatible +client = create_client(use_claude=True) ``` --- diff --git a/setup.py b/setup.py index a1a18e544..f958a240b 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ packages=[ "gaia", "gaia.llm", + "gaia.llm.providers", "gaia.audio", "gaia.chat", "gaia.database", diff --git a/src/gaia/agents/blender/agent_simple.py b/src/gaia/agents/blender/agent_simple.py index c377b541f..79941affd 100644 --- a/src/gaia/agents/blender/agent_simple.py +++ b/src/gaia/agents/blender/agent_simple.py @@ -3,7 +3,8 @@ from typing import Any, Dict, Optional, Tuple -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client +from gaia.llm.base_client import LLMClient from gaia.mcp.blender_mcp_client import MCPClient @@ -39,7 +40,6 @@ def __init__( self, llm: Optional[LLMClient] = None, mcp: Optional[MCPClient] = None, - use_local: bool = True, base_url: Optional[str] = "http://localhost:8000/api/v1", ): """ @@ -48,15 +48,13 @@ def __init__( Args: llm: An optional pre-configured LLM client, otherwise a new one will be created mcp: An optional pre-configured MCP client, otherwise a new one will be created - use_local: Whether to use a local LLM (True) or a remote API (False) - base_url: Base URL for the local LLM API if using local LLM. If None and use_local=True, - defaults to "http://localhost:8000/api/v1" + base_url: Base URL for the Lemonade LLM server """ self.llm = ( llm if llm - else LLMClient( - use_local=use_local, system_prompt=self.SYSTEM_PROMPT, base_url=base_url + else create_client( + "lemonade", base_url=base_url, system_prompt=self.SYSTEM_PROMPT ) ) self.mcp = mcp if mcp else MCPClient() diff --git a/src/gaia/agents/blender/tests/test_agent.py b/src/gaia/agents/blender/tests/test_agent.py index 5bb4a150b..f5d996f7c 100644 --- a/src/gaia/agents/blender/tests/test_agent.py +++ b/src/gaia/agents/blender/tests/test_agent.py @@ -10,7 +10,7 @@ from gaia.agents.base.console import AgentConsole from gaia.agents.blender.agent import BlenderAgent -from gaia.llm.llm_client import LLMClient +from gaia.llm.base_client import LLMClient from gaia.mcp.blender_mcp_client import MCPClient # Set up logging diff --git a/src/gaia/agents/blender/tests/test_agent_simple.py b/src/gaia/agents/blender/tests/test_agent_simple.py index 75211f937..cb8295290 100644 --- a/src/gaia/agents/blender/tests/test_agent_simple.py +++ b/src/gaia/agents/blender/tests/test_agent_simple.py @@ -7,8 +7,9 @@ import pytest from gaia.agents.blender.agent_simple import BlenderAgentSimple -from gaia.agents.blender.mcp.mcp_client import MCPClient -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client +from gaia.llm.base_client import LLMClient +from gaia.mcp.blender_mcp_client import MCPClient # Set up logging logging.basicConfig(level=logging.DEBUG) @@ -34,7 +35,7 @@ def llm_client(): - For any other request, respond with: CYLINDER,0,2,0,0.5,0.5,3 """ # Using local LLM for faster testing - return LLMClient(system_prompt=system_prompt) + return create_client("lemonade", system_prompt=system_prompt) @pytest.fixture @@ -166,8 +167,8 @@ def test_agent_initialization(): assert isinstance(agent.mcp, MCPClient) # Verify default system prompt is set - logger.debug(f"System prompt: {agent.llm.system_prompt}") - assert agent.llm.system_prompt == agent.SYSTEM_PROMPT + logger.debug(f"System prompt: {agent.llm._system_prompt}") + assert agent.llm._system_prompt == agent.SYSTEM_PROMPT if __name__ == "__main__": diff --git a/src/gaia/agents/routing/agent.py b/src/gaia/agents/routing/agent.py index 10323b683..7b8a1c1e7 100644 --- a/src/gaia/agents/routing/agent.py +++ b/src/gaia/agents/routing/agent.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from gaia.agents.base.agent import Agent -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client from gaia.logger import get_logger from .system_prompt import ROUTING_ANALYSIS_PROMPT @@ -58,13 +58,10 @@ def __init__( # Read from environment if not provided base_url = os.getenv("LEMONADE_BASE_URL", "http://localhost:8000/api/v1") - llm_kwargs = { - "use_claude": use_claude, - "use_openai": use_chatgpt, - "base_url": base_url, - } - - self.llm_client = LLMClient(**llm_kwargs) + # Initialize LLM client - factory auto-detects provider from flags + self.llm_client = create_client( + use_claude=use_claude, use_openai=use_chatgpt, base_url=base_url + ) self.agent_kwargs = agent_kwargs # Store for passing to created agents # Model to use for routing analysis (configurable via env var) diff --git a/src/gaia/apps/llm/app.py b/src/gaia/apps/llm/app.py index d57d1b56f..b088344ea 100644 --- a/src/gaia/apps/llm/app.py +++ b/src/gaia/apps/llm/app.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT """ @@ -11,7 +11,7 @@ import sys from typing import Iterator, Optional, Union -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client from gaia.logger import get_logger @@ -28,7 +28,11 @@ def __init__( base_url: Base URL for local LLM server (defaults to LEMONADE_BASE_URL env var) """ self.log = get_logger(__name__) - self.client = LLMClient(system_prompt=system_prompt, base_url=base_url) + self.client = create_client( + "lemonade", + base_url=base_url, + system_prompt=system_prompt, + ) self.log.debug("LLM app initialized") def query( diff --git a/src/gaia/audio/audio_client.py b/src/gaia/audio/audio_client.py index a19c8f284..47a0a7b2f 100644 --- a/src/gaia/audio/audio_client.py +++ b/src/gaia/audio/audio_client.py @@ -6,7 +6,7 @@ import threading import time -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client from gaia.logger import get_logger @@ -40,10 +40,10 @@ def __init__( self.transcription_queue = queue.Queue() self.tts = None - # Initialize LLM client (base_url handled automatically) - self.llm_client = LLMClient( + # Initialize LLM client - factory auto-detects provider from flags + self.llm_client = create_client( use_claude=use_claude, - use_openai=use_chatgpt, # LLMClient uses use_openai, not use_chatgpt + use_openai=use_chatgpt, system_prompt=system_prompt, ) diff --git a/src/gaia/chat/sdk.py b/src/gaia/chat/sdk.py index 4493361d5..1e99c8a72 100644 --- a/src/gaia/chat/sdk.py +++ b/src/gaia/chat/sdk.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT """ @@ -13,8 +13,8 @@ from typing import Any, Dict, List, Optional from gaia.chat.prompts import Prompts +from gaia.llm import create_client from gaia.llm.lemonade_client import DEFAULT_MODEL_NAME -from gaia.llm.llm_client import LLMClient from gaia.logger import get_logger @@ -87,19 +87,17 @@ def __init__(self, config: Optional[ChatConfig] = None): self.log = get_logger(__name__) self.log.setLevel(getattr(logging, self.config.logging_level)) - # Validate that both providers aren't specified - if self.config.use_claude and self.config.use_chatgpt: - raise ValueError( - "Cannot specify both use_claude and use_chatgpt. Please choose one." - ) - - # Initialize LLM client - it will compute use_local automatically - self.llm_client = LLMClient( + # Initialize LLM client - factory auto-detects provider and validates + self.llm_client = create_client( use_claude=self.config.use_claude, use_openai=self.config.use_chatgpt, - claude_model=self.config.claude_model, + model=( + self.config.claude_model + if self.config.use_claude + else self.config.model + ), base_url=self.config.base_url, - system_prompt=None, # We handle system prompts through Prompts class + system_prompt=self.config.system_prompt, ) # Store conversation history diff --git a/src/gaia/cli.py b/src/gaia/cli.py index ac1e44d33..32b9020d3 100644 --- a/src/gaia/cli.py +++ b/src/gaia/cli.py @@ -8,10 +8,10 @@ import sys import time from pathlib import Path -from typing import Optional from dotenv import load_dotenv +from gaia.llm import create_client from gaia.llm.lemonade_client import ( DEFAULT_HOST, DEFAULT_LEMONADE_URL, @@ -21,7 +21,6 @@ LemonadeClientError, _get_lemonade_config, ) -from gaia.llm.llm_client import LLMClient from gaia.logger import get_logger from gaia.perf_analysis import run_perf_visualization from gaia.version import version @@ -113,8 +112,8 @@ def initialize_lemonade_for_agent( skip_if_external: bool = False, use_claude: bool = False, use_chatgpt: bool = False, - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, ): """ Initialize Lemonade Server for a specific GAIA agent. @@ -394,7 +393,7 @@ def __init__( self.show_stats = show_stats # Initialize LLM client for local inference - self.llm_client = LLMClient() + self.llm_client = create_client("lemonade", model=model) self.log.debug("Gaia CLI client initialized.") self.log.debug(f"model: {self.model}\n max_tokens: {self.max_tokens}") @@ -3180,16 +3179,6 @@ def download_progress_callback(event_type: str, data: dict) -> None: LemonadeManager.print_server_error() else: print(f"❌ Error: {str(e)}") - # Check for model loading related errors - if ( - "404" in error_msg - or "not found" in error_msg - or "not loaded" in error_msg - ): - print( - "\nMake sure that the model is loaded. You can load it using:" - ) - print(f" gaia pull {DEFAULT_MODEL_NAME}") return # Handle groundtruth generation diff --git a/src/gaia/llm/__init__.py b/src/gaia/llm/__init__.py index 53bd49073..d87b9e6a3 100644 --- a/src/gaia/llm/__init__.py +++ b/src/gaia/llm/__init__.py @@ -1,2 +1,9 @@ # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +"""LLM client package.""" + +from .base_client import LLMClient +from .exceptions import NotSupportedError +from .factory import create_client + +__all__ = ["create_client", "LLMClient", "NotSupportedError"] diff --git a/src/gaia/llm/base_client.py b/src/gaia/llm/base_client.py new file mode 100644 index 000000000..7a2e99412 --- /dev/null +++ b/src/gaia/llm/base_client.py @@ -0,0 +1,60 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Base LLM client interface.""" + +from abc import ABC, abstractmethod +from typing import Iterator, Union + +from .exceptions import NotSupportedError + + +class LLMClient(ABC): + """ + Unified LLM client interface. + + Methods raise NotSupportedError if not available for this provider. + """ + + @property + @abstractmethod + def provider_name(self) -> str: + """Return the provider name for error messages.""" + ... + + @abstractmethod + def generate( + self, + prompt: str, + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + """Generate text completion.""" + ... + + @abstractmethod + def chat( + self, + messages: list[dict], + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + """Chat completion.""" + ... + + # Optional - default raises NotSupportedError + def embed(self, texts: list[str], **kwargs) -> list[list[float]]: + raise NotSupportedError(self.provider_name, "embed") + + def vision(self, images: list[bytes], prompt: str, **kwargs) -> str: + raise NotSupportedError(self.provider_name, "vision") + + def get_performance_stats(self) -> dict: + raise NotSupportedError(self.provider_name, "get_performance_stats") + + def load_model(self, model_name: str, **kwargs) -> None: + raise NotSupportedError(self.provider_name, "load_model") + + def unload_model(self) -> None: + raise NotSupportedError(self.provider_name, "unload_model") diff --git a/src/gaia/llm/exceptions.py b/src/gaia/llm/exceptions.py new file mode 100644 index 000000000..20002075a --- /dev/null +++ b/src/gaia/llm/exceptions.py @@ -0,0 +1,12 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""LLM client exceptions.""" + + +class NotSupportedError(Exception): + """Raised when a provider doesn't support a method.""" + + def __init__(self, provider: str, method: str): + self.provider = provider + self.method = method + super().__init__(f"{provider} does not support {method}") diff --git a/src/gaia/llm/factory.py b/src/gaia/llm/factory.py new file mode 100644 index 000000000..ede92b023 --- /dev/null +++ b/src/gaia/llm/factory.py @@ -0,0 +1,70 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""LLM client factory.""" + +from typing import Optional + +from .base_client import LLMClient + +_PROVIDERS: dict[str, str] = { + "lemonade": "gaia.llm.providers.lemonade.LemonadeProvider", + "openai": "gaia.llm.providers.openai_provider.OpenAIProvider", + "claude": "gaia.llm.providers.claude.ClaudeProvider", +} + + +def create_client( + provider: Optional[str] = None, + use_claude: bool = False, + use_openai: bool = False, + **kwargs, +) -> LLMClient: + """ + Create an LLM client, auto-detecting provider from parameters. + + Args: + provider: Explicit provider name ("lemonade", "openai", or "claude"). + If not specified, auto-detected from use_claude/use_openai flags. + use_claude: If True, use Claude provider (ignored if provider is specified) + use_openai: If True, use OpenAI provider (ignored if provider is specified) + **kwargs: Provider-specific arguments (base_url, model, api_key, etc.) + + Note: + The design using these flags maintains backward compatibility + while allowing explicit provider selection. If both use_claude and + use_openai are False and provider is not specified, the default + provider "lemonade" is used. This was deemed better than updating all + existing callers with conditional logic and multiple `create_client` calls. + + Returns: + LLMClient instance for the specified or detected provider + + Raises: + ValueError: If provider is not recognized or both use_claude and use_openai are True + """ + # Auto-detect provider from flags if not explicitly specified + if provider is None: + if use_claude and use_openai: + raise ValueError( + "Cannot specify both use_claude and use_openai. Please choose one." + ) + elif use_claude: + provider = "claude" + elif use_openai: + provider = "openai" + else: + provider = "lemonade" + + provider_lower = provider.lower() + + if provider_lower not in _PROVIDERS: + available = ", ".join(_PROVIDERS.keys()) + raise ValueError(f"Unknown provider: {provider}. Available: {available}") + + import importlib + + module_path, class_name = _PROVIDERS[provider_lower].rsplit(".", 1) + module = importlib.import_module(module_path) + provider_class = getattr(module, class_name) + + return provider_class(**kwargs) diff --git a/src/gaia/llm/lemonade_client.py b/src/gaia/llm/lemonade_client.py index 12d1425f0..a956c8d88 100644 --- a/src/gaia/llm/lemonade_client.py +++ b/src/gaia/llm/lemonade_client.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT """ Lemonade Server Client for GAIA. @@ -505,6 +505,7 @@ def __init__( model: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None, + base_url: Optional[str] = None, verbose: bool = True, keep_alive: bool = False, ): @@ -515,18 +516,35 @@ def __init__( model: Name of the model to load (optional) host: Host address of the Lemonade server (defaults to LEMONADE_BASE_URL env var) port: Port number of the Lemonade server (defaults to LEMONADE_BASE_URL env var) + base_url: Base URL for the Lemonade server (defaults to LEMONADE_BASE_URL env var) verbose: If False, reduce logging verbosity during initialization keep_alive: If True, don't terminate server in __del__ """ + from urllib.parse import urlparse + # Use provided host/port, or get from env var, or use defaults env_host, env_port, env_base_url = _get_lemonade_config() - self.host = host if host is not None else env_host - self.port = port if port is not None else env_port - # If host/port explicitly provided, construct URL; otherwise use env URL directly + + # Determine base_url with priority: explicit params > base_url param > env if host is not None or port is not None: + # Explicit host/port provided - construct URL from them + self.host = host if host is not None else env_host + self.port = port if port is not None else env_port self.base_url = f"http://{self.host}:{self.port}/api/{LEMONADE_API_VERSION}" + elif base_url is not None: + # base_url parameter provided - normalize and use it + if not base_url.rstrip("/").endswith(f"/api/{LEMONADE_API_VERSION}"): + base_url = f"{base_url.rstrip('/')}/api/{LEMONADE_API_VERSION}" + self.base_url = base_url + # Parse for backwards compatibility with code accessing self.host/self.port + parsed = urlparse(base_url) + self.host = parsed.hostname or DEFAULT_HOST + self.port = parsed.port or DEFAULT_PORT else: + # Use environment config self.base_url = env_base_url + self.host = env_host + self.port = env_port self.model = model self.server_process = None self.log = get_logger(__name__) @@ -1196,6 +1214,9 @@ def _stream_chat_completions_with_openai( }] } """ + # Proactively ensure model is loaded before making request + self._ensure_model_loaded(model, auto_download) + # Create a client just for this request client = OpenAI( base_url=self.base_url, @@ -1267,67 +1288,6 @@ def _stream_chat_completions_with_openai( except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: error_type = e.__class__.__name__ error_msg = str(e) - - # Check if this is a model loading error and auto_download is enabled - if auto_download and self._is_model_error(e): - self.log.info( - f"{_emoji('📥', '[AUTO-DOWNLOAD]')} Model '{model}' not loaded, " - f"attempting auto-download and load..." - ) - try: - # Load model with auto-download (may take hours for large models) - self.load_model(model, timeout=60, auto_download=True) - - # Retry streaming - self.log.info( - f"{_emoji('🔄', '[RETRY]')} Retrying streaming chat completion " - f"with model: {model}" - ) - stream = client.chat.completions.create(**request_params) - - tokens_generated = 0 - for chunk in stream: - tokens_generated += 1 - yield { - "id": chunk.id, - "object": "chat.completion.chunk", - "created": chunk.created, - "model": chunk.model, - "choices": [ - { - "index": choice.index, - "delta": { - "role": ( - choice.delta.role - if hasattr(choice.delta, "role") - and choice.delta.role - else None - ), - "content": ( - choice.delta.content - if hasattr(choice.delta, "content") - and choice.delta.content - else None - ), - }, - "finish_reason": choice.finish_reason, - } - for choice in chunk.choices - ], - } - - self.log.debug( - f"Completed streaming chat completion. Generated {tokens_generated} tokens." - ) - return - - except Exception as load_error: - self.log.error(f"Auto-download/load failed: {load_error}") - raise LemonadeClientError( - f"Model '{model}' not loaded and auto-load failed: {load_error}" - ) - - # Re-raise original error self.log.error(f"OpenAI {error_type}: {error_msg}") raise LemonadeClientError(f"OpenAI {error_type}: {error_msg}") except Exception as e: @@ -1458,6 +1418,7 @@ def _stream_completions_with_openai( echo: bool = False, timeout: int = DEFAULT_REQUEST_TIMEOUT, logprobs: Optional[bool] = None, + auto_download: bool = True, **kwargs, ) -> Generator[Dict[str, Any], None, None]: """ @@ -1476,6 +1437,9 @@ def _stream_completions_with_openai( }] } """ + # Proactively ensure model is loaded before making request + self._ensure_model_loaded(model, auto_download) + client = OpenAI( base_url=self.base_url, api_key="lemonade", # required, but unused @@ -2079,6 +2043,43 @@ def _wait_for_model_download( ) return False + def _ensure_model_loaded(self, model: str, auto_download: bool = True) -> None: + """Ensure a model is loaded on the server before making requests. + + This method proactively checks if the model is loaded and loads it if not, + preventing 404 errors when making completions requests. Downloads are + automatic without user prompts when auto_download is enabled. + + Args: + model: Model name to ensure is loaded + auto_download: If True, download the model if not present (without prompting) + + Note: + This method is called at the start of streaming methods to ensure + the model is ready before making API requests. When a model is explicitly + requested via CLI flags, it downloads automatically without user confirmation. + """ + if not auto_download: + return # Skip if auto_download disabled + + try: + # Check current server state + status = self.get_status() + loaded_models = [m.get("id", "") for m in status.loaded_models] + + # If model already loaded, nothing to do + if model in loaded_models: + self.log.debug(f"Model '{model}' already loaded") + return + + # Model not loaded - load it (will download if needed without prompting) + self.log.info(f"Model '{model}' not loaded, loading...") + self.load_model(model, auto_download=True, prompt=False) + + except Exception as e: + # Log but don't fail - let the actual request fail with proper error + self.log.debug(f"Could not pre-check model status: {e}") + def load_model( self, model_name: str, @@ -2086,12 +2087,13 @@ def load_model( auto_download: bool = False, download_timeout: int = 7200, llamacpp_args: Optional[str] = None, + prompt: bool = True, ) -> Dict[str, Any]: """ Load a model on the server. If auto_download is enabled and the model is not available: - 1. Prompts user for confirmation (with size and ETA) + 1. Prompts user for confirmation (with size and ETA) - unless prompt=False 2. Validates disk space 3. Downloads model with cancellation support 4. Retries loading @@ -2104,6 +2106,8 @@ def load_model( Large models can be 100GB+ and take hours to download llamacpp_args: Optional llama.cpp arguments (e.g., "--ubatch-size 2048"). Used to configure model loading parameters like batch sizes. + prompt: If True, prompt user before downloading (default: True). + Set to False to download automatically without user confirmation. Returns: Dict containing the status of the load operation @@ -2181,10 +2185,21 @@ def load_model( size_gb = model_info["size_gb"] estimated_minutes = self._estimate_download_time(size_gb) - # Prompt user for confirmation - if not _prompt_user_for_download(model_name, size_gb, estimated_minutes): - raise ModelDownloadCancelledError( - f"User declined download of {model_name}" + # Prompt user for confirmation (if prompt=True) + if prompt: + if not _prompt_user_for_download( + model_name, size_gb, estimated_minutes + ): + raise ModelDownloadCancelledError( + f"User declined download of {model_name}" + ) + else: + # Log the download info without prompting + self.log.info( + f" {_emoji('📦', '[SIZE]')} Model size: {size_gb:.1f} GB" + ) + self.log.info( + f" {_emoji('⏱️', '[ETA]')} Estimated time: ~{estimated_minutes} minutes" ) # Validate disk space diff --git a/src/gaia/llm/llm_client.py b/src/gaia/llm/llm_client.py deleted file mode 100644 index 951e2b08e..000000000 --- a/src/gaia/llm/llm_client.py +++ /dev/null @@ -1,723 +0,0 @@ -# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: MIT - -# Standard library imports -import logging -import os -import time -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Literal, - Optional, - TypeVar, - Union, -) - -import httpx - -# Third-party imports -import requests -from dotenv import load_dotenv -from openai import OpenAI - -from ..version import LEMONADE_VERSION - -# Local imports -from .lemonade_client import DEFAULT_MODEL_NAME - -# Default Lemonade server URL (can be overridden via LEMONADE_BASE_URL env var) -DEFAULT_LEMONADE_URL = "http://localhost:8000/api/v1" - -# Type variable for retry decorator -T = TypeVar("T") - -# Conditional import for Claude -try: - from ..eval.claude import ClaudeClient as AnthropicClaudeClient - - CLAUDE_AVAILABLE = True -except ImportError: - CLAUDE_AVAILABLE = False - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) # Explicitly set module logger level - -# Load environment variables from .env file -load_dotenv() - - -class LLMClient: - def __init__( - self, - use_claude: bool = False, - use_openai: bool = False, - system_prompt: Optional[str] = None, - base_url: Optional[str] = None, - claude_model: str = "claude-sonnet-4-20250514", - max_retries: int = 3, - retry_base_delay: float = 1.0, - ): - """ - Initialize the LLM client. - - Args: - use_claude: If True, uses Anthropic Claude API. - use_openai: If True, uses OpenAI ChatGPT API. - system_prompt: Default system prompt to use for all generation requests. - base_url: Base URL for local LLM server (defaults to LEMONADE_BASE_URL env var). - claude_model: Claude model to use (e.g., "claude-sonnet-4-20250514"). - max_retries: Maximum number of retry attempts on connection errors. - retry_base_delay: Base delay in seconds for exponential backoff. - - Note: Uses local LLM server by default unless use_claude or use_openai is True. - Context size is configured when starting the Lemonade server with --ctx-size parameter. - """ - # Use provided base_url, fall back to env var, then default - if base_url is None: - base_url = os.getenv("LEMONADE_BASE_URL", DEFAULT_LEMONADE_URL) - - # Normalize base_url to ensure it has the /api/v1 suffix for Lemonade server - # This allows users to specify just "http://localhost:8000" for convenience - if base_url and not base_url.endswith("/api/v1"): - # Remove trailing slash if present - base_url = base_url.rstrip("/") - # Add /api/v1 if the URL looks like a Lemonade server (localhost or IP with port) - # but doesn't already have a path beyond the port - from urllib.parse import urlparse - - parsed = urlparse(base_url) - # Only add /api/v1 if path is empty or just "/" - if not parsed.path or parsed.path == "/": - base_url = f"{base_url}/api/v1" - logger.debug(f"Normalized base_url to: {base_url}") - - # Compute use_local: True if neither claude nor openai is selected - use_local = not (use_claude or use_openai) - - logger.debug( - f"Initializing LLMClient with use_local={use_local}, use_claude={use_claude}, use_openai={use_openai}, base_url={base_url}" - ) - - self.use_claude = use_claude - self.use_openai = use_openai - self.base_url = base_url - self.system_prompt = system_prompt - self.max_retries = max_retries - self.retry_base_delay = retry_base_delay - - if use_local: - # Configure timeout for local LLM server - # For streaming: timeout between chunks (read timeout) - # For non-streaming: total timeout for the entire response - self.client = OpenAI( - base_url=base_url, - api_key="None", - timeout=httpx.Timeout( - connect=15.0, # 15 seconds to establish connection - read=120.0, # 120 seconds between data chunks (matches Lemonade DEFAULT_REQUEST_TIMEOUT) - write=15.0, # 15 seconds to send request - pool=15.0, # 15 seconds to acquire connection from pool - ), - max_retries=0, # Disable retries to fail fast on connection issues - ) - # Use completions endpoint for pre-formatted prompts (ChatSDK compatibility) - # Use chat endpoint when messages array is explicitly provided - self.endpoint = "completions" - logger.debug("Using Lemonade completions endpoint") - self.default_model = DEFAULT_MODEL_NAME - self.claude_client = None - logger.debug(f"Using local LLM with model={self.default_model}") - elif use_claude and CLAUDE_AVAILABLE: - # Use Claude API - self.claude_client = AnthropicClaudeClient(model=claude_model) - self.client = None - self.endpoint = "claude" - self.default_model = claude_model - logger.debug(f"Using Claude API with model={self.default_model}") - elif use_claude and not CLAUDE_AVAILABLE: - raise ValueError( - "Claude support requested but anthropic library not available. Install with: uv pip install anthropic" - ) - elif use_openai: - # Use OpenAI API - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError( - "OPENAI_API_KEY not found in environment variables. Please add it to your .env file." - ) - self.client = OpenAI(api_key=api_key) - self.claude_client = None - self.endpoint = "openai" - self.default_model = "gpt-4o" # Updated to latest model - logger.debug(f"Using OpenAI API with model={self.default_model}") - else: - # This should not happen with the new logic, but keep as fallback - raise ValueError("Invalid LLM provider configuration") - if system_prompt: - logger.debug(f"System prompt set: {system_prompt[:100]}...") - - def _retry_with_exponential_backoff( - self, - func: Callable[..., T], - *args, - **kwargs, - ) -> T: - """ - Execute a function with exponential backoff retry on connection errors. - - Args: - func: The function to execute - *args: Positional arguments for the function - **kwargs: Keyword arguments for the function - - Returns: - The result of the function call - - Raises: - The last exception if all retries are exhausted - """ - delay = self.retry_base_delay - max_delay = 60.0 - exponential_base = 2.0 - - for attempt in range(self.max_retries + 1): - try: - return func(*args, **kwargs) - except ( - ConnectionError, - httpx.ConnectError, - httpx.TimeoutException, - httpx.NetworkError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - ) as e: - if attempt == self.max_retries: - logger.error( - f"Max retries ({self.max_retries}) reached for {func.__name__}. " - f"Last error: {str(e)}" - ) - raise - - # Calculate next delay with exponential backoff - wait_time = min(delay, max_delay) - logger.warning( - f"Connection error in {func.__name__} (attempt {attempt + 1}/{self.max_retries + 1}): {str(e)}. " - f"Retrying in {wait_time:.1f}s..." - ) - - time.sleep(wait_time) - delay *= exponential_base - - def generate( - self, - prompt: str, - model: Optional[str] = None, - endpoint: Optional[Literal["completions", "chat", "claude", "openai"]] = None, - system_prompt: Optional[str] = None, - stream: bool = False, - messages: Optional[List[Dict[str, str]]] = None, - **kwargs: Any, - ) -> Union[str, Iterator[str]]: - """ - Generate a response from the LLM. - - Args: - prompt: The user prompt/query to send to the LLM. For chat endpoint, - if messages is not provided, this is treated as a pre-formatted - prompt string that already contains the full conversation. - model: The model to use (defaults to endpoint-appropriate model) - endpoint: Override the endpoint to use (completions, chat, claude, or openai) - system_prompt: System prompt to use for this specific request (overrides default) - stream: If True, returns a generator that yields chunks of the response as they become available - messages: Optional list of message dicts with 'role' and 'content' keys. - If provided, these are used directly for chat completions instead of prompt. - **kwargs: Additional parameters to pass to the API - - Returns: - If stream=False: The complete generated text as a string - If stream=True: A generator yielding chunks of the response as they become available - """ - model = model or self.default_model - endpoint_to_use = endpoint or self.endpoint - logger.debug( - f"LLMClient.generate() called with model={model}, endpoint={endpoint_to_use}, stream={stream}" - ) - - # Use provided system_prompt, fall back to instance default if not provided - effective_system_prompt = ( - system_prompt if system_prompt is not None else self.system_prompt - ) - logger.debug( - f"Using system prompt: {effective_system_prompt[:100] if effective_system_prompt else 'None'}..." - ) - - if endpoint_to_use == "claude": - # For Claude API, construct the prompt appropriately - if effective_system_prompt: - # Claude handles system prompts differently in messages format - full_prompt = f"System: {effective_system_prompt}\n\nHuman: {prompt}" - else: - full_prompt = prompt - - logger.debug(f"Using Claude API with prompt: {full_prompt[:200]}...") - - try: - if stream: - logger.warning( - "Streaming not yet implemented for Claude API, falling back to non-streaming" - ) - - # Use Claude client with retry logic - logger.debug("Making request to Claude API") - - # Use retry logic for the API call - result = self._retry_with_exponential_backoff( - self.claude_client.get_completion, full_prompt - ) - - # Claude returns a list of content blocks, extract text - if isinstance(result, list) and len(result) > 0: - # Each content block has a 'text' attribute - text_parts = [] - for content_block in result: - if hasattr(content_block, "text"): - text_parts.append(content_block.text) - else: - text_parts.append(str(content_block)) - result = "".join(text_parts) - elif isinstance(result, str): - pass # result is already a string - else: - result = str(result) - - # Check for empty responses - if not result or not result.strip(): - logger.warning("Empty response from Claude API") - - # Debug: log the response structure for troubleshooting - logger.debug(f"Claude response length: {len(result)}") - logger.debug(f"Claude response preview: {result[:300]}...") - - # Claude sometimes returns valid JSON followed by additional text - # Try to extract just the JSON part if it exists - result = self._clean_claude_response(result) - - return result - except Exception as e: - logger.error(f"Error generating response from Claude API: {str(e)}") - raise - elif endpoint_to_use == "completions": - # For local LLM with pre-formatted prompts (ChatSDK uses this) - # The prompt already contains the full formatted conversation - logger.debug( - f"Using completions endpoint: prompt_length={len(prompt)} chars" - ) - - try: - # Use retry logic for the API call - response = self._retry_with_exponential_backoff( - self.client.completions.create, - model=model, - prompt=prompt, - temperature=0.1, - stream=stream, - **kwargs, - ) - - if stream: - # Return a generator that yields chunks - def stream_generator(): - for chunk in response: - if ( - hasattr(chunk.choices[0], "text") - and chunk.choices[0].text - ): - yield chunk.choices[0].text - - return stream_generator() - else: - # Return the complete response - result = response.choices[0].text - if not result or not result.strip(): - logger.warning("Empty response from local LLM") - return result - except ( - httpx.ConnectError, - httpx.TimeoutException, - httpx.NetworkError, - ) as e: - logger.error(f"Network error connecting to local LLM server: {str(e)}") - error_msg = f"LLM Server Connection Error: {str(e)}" - raise ConnectionError(error_msg) from e - except Exception as e: - error_str = str(e) - logger.error(f"Error generating response from local LLM: {error_str}") - - if "404" in error_str: - if ( - "endpoint" in error_str.lower() - or "not found" in error_str.lower() - ): - raise ConnectionError( - f"API endpoint error: {error_str}\n\n" - f"This may indicate:\n" - f" 1. Lemonade Server version mismatch (try updating to {LEMONADE_VERSION})\n" - f" 2. Model not properly loaded or corrupted\n" - ) from e - - if "network" in error_str.lower() or "connection" in error_str.lower(): - raise ConnectionError(f"LLM Server Error: {error_str}") from e - raise - elif endpoint_to_use == "chat": - # For local LLM using chat completions format (Lemonade v9+) - if messages: - # Use provided messages directly (proper chat history support) - chat_messages = list(messages) - # Prepend system prompt if provided and not already in messages - if effective_system_prompt and ( - not chat_messages or chat_messages[0].get("role") != "system" - ): - chat_messages.insert( - 0, {"role": "system", "content": effective_system_prompt} - ) - else: - # Treat prompt as pre-formatted string (legacy ChatSDK support) - # Pass as single user message - the prompt already contains formatted history - chat_messages = [] - if effective_system_prompt: - chat_messages.append( - {"role": "system", "content": effective_system_prompt} - ) - chat_messages.append({"role": "user", "content": prompt}) - logger.debug( - f"Using chat completions for local LLM: {len(chat_messages)} messages" - ) - - try: - # Use retry logic for the API call - response = self._retry_with_exponential_backoff( - self.client.chat.completions.create, - model=model, - messages=chat_messages, - temperature=0.1, - stream=stream, - **kwargs, - ) - - if stream: - # Return a generator that yields chunks - def stream_generator(): - for chunk in response: - if ( - hasattr(chunk.choices[0].delta, "content") - and chunk.choices[0].delta.content - ): - yield chunk.choices[0].delta.content - - return stream_generator() - else: - # Return the complete response - result = response.choices[0].message.content - if not result or not result.strip(): - logger.warning("Empty response from local LLM") - return result - except ( - httpx.ConnectError, - httpx.TimeoutException, - httpx.NetworkError, - ) as e: - logger.error(f"Network error connecting to local LLM server: {str(e)}") - error_msg = f"LLM Server Connection Error: {str(e)}" - raise ConnectionError(error_msg) from e - except Exception as e: - error_str = str(e) - logger.error(f"Error generating response from local LLM: {error_str}") - - # Check for 404 errors which might indicate endpoint or model issues - if "404" in error_str: - if ( - "endpoint" in error_str.lower() - or "not found" in error_str.lower() - ): - raise ConnectionError( - f"API endpoint error: {error_str}\n\n" - f"This may indicate:\n" - f" 1. Lemonade Server version mismatch (try updating to {LEMONADE_VERSION})\n" - f" 2. Model not properly loaded or corrupted\n" - ) from e - - if "network" in error_str.lower() or "connection" in error_str.lower(): - raise ConnectionError(f"LLM Server Error: {error_str}") from e - raise - elif endpoint_to_use == "openai": - # For OpenAI API, use the messages format - messages = [] - if effective_system_prompt: - messages.append({"role": "system", "content": effective_system_prompt}) - messages.append({"role": "user", "content": prompt}) - logger.debug(f"OpenAI API messages: {messages}") - - try: - # Use retry logic for the API call - response = self._retry_with_exponential_backoff( - self.client.chat.completions.create, - model=model, - messages=messages, - stream=stream, - **kwargs, - ) - - if stream: - # Return a generator that yields chunks - def stream_generator(): - for chunk in response: - if ( - hasattr(chunk.choices[0].delta, "content") - and chunk.choices[0].delta.content - ): - yield chunk.choices[0].delta.content - - return stream_generator() - else: - # Return the complete response as before - result = response.choices[0].message.content - logger.debug(f"OpenAI API response: {result[:200]}...") - return result - except Exception as e: - logger.error(f"Error generating response from OpenAI API: {str(e)}") - raise - else: - raise ValueError( - f"Unsupported endpoint: {endpoint_to_use}. Supported endpoints: 'completions', 'chat', 'claude', 'openai'." - ) - - def get_performance_stats(self) -> Dict[str, Any]: - """ - Get performance statistics from the last LLM request. - - Returns: - Dictionary containing performance statistics like: - - time_to_first_token: Time in seconds until first token is generated - - tokens_per_second: Rate of token generation - - input_tokens: Number of tokens in the input - - output_tokens: Number of tokens in the output - """ - if not self.base_url: - # Return empty stats if not using local LLM - return { - "time_to_first_token": None, - "tokens_per_second": None, - "input_tokens": None, - "output_tokens": None, - } - - try: - # Use the Lemonade API v1 stats endpoint - # This returns both timing stats and token counts - stats_url = f"{self.base_url}/stats" - response = requests.get(stats_url) - - if response.status_code == 200: - stats = response.json() - # Remove decode_token_times as it's too verbose - if "decode_token_times" in stats: - del stats["decode_token_times"] - return stats - else: - logger.warning( - f"Failed to get stats: {response.status_code} - {response.text}" - ) - return {} - except Exception as e: - logger.warning(f"Error fetching performance stats: {str(e)}") - return {} - - def is_generating(self) -> bool: - """ - Check if the local LLM is currently generating. - - Returns: - bool: True if generating, False otherwise - - Note: - Only available when using local LLM (use_local=True). - Returns False for OpenAI API usage. - """ - if not self.base_url: - logger.debug("is_generating(): Not using local LLM, returning False") - return False - - try: - # Check the generating endpoint - # Remove /api/v1 suffix to access root-level endpoints - base = self.base_url.replace("/api/v1", "") - generating_url = f"{base}/generating" - response = requests.get(generating_url) - if response.status_code == 200: - response_data = response.json() - is_gen = response_data.get("is_generating", False) - logger.debug(f"Generation status check: {is_gen}") - return is_gen - else: - logger.warning( - f"Failed to check generation status: {response.status_code} - {response.text}" - ) - return False - except Exception as e: - logger.warning(f"Error checking generation status: {str(e)}") - return False - - def halt_generation(self) -> bool: - """ - Halt current generation on the local LLM server. - - Returns: - bool: True if halt was successful, False otherwise - - Note: - Only available when using local LLM (use_local=True). - Does nothing for OpenAI API usage. - """ - if not self.base_url: - logger.debug("halt_generation(): Not using local LLM, nothing to halt") - return False - - try: - # Send halt request - # Remove /api/v1 suffix to access root-level endpoints - base = self.base_url.replace("/api/v1", "") - halt_url = f"{base}/halt" - response = requests.get(halt_url) - if response.status_code == 200: - logger.debug("Successfully halted current generation") - return True - else: - logger.warning( - f"Failed to halt generation: {response.status_code} - {response.text}" - ) - return False - except Exception as e: - logger.warning(f"Error halting generation: {str(e)}") - return False - - def _clean_claude_response(self, response: str) -> str: - """ - Extract valid JSON from Claude responses that may contain extra content after the JSON. - - Args: - response: The raw response from Claude API - - Returns: - Cleaned response with only the JSON portion - """ - import json - - if not response or not response.strip(): - return response - - # Try to parse as-is first - try: - json.loads(response.strip()) - return response.strip() - except json.JSONDecodeError: - pass - - # Look for JSON object patterns - # Find the first { and try to extract a complete JSON object - start_idx = response.find("{") - if start_idx == -1: - # No JSON object found, return as-is - return response - - # Find the matching closing brace by counting braces - brace_count = 0 - end_idx = -1 - - for i in range(start_idx, len(response)): - char = response[i] - if char == "{": - brace_count += 1 - elif char == "}": - brace_count -= 1 - if brace_count == 0: - end_idx = i - break - - if end_idx == -1: - # No complete JSON object found - return response - - # Extract the JSON portion - json_portion = response[start_idx : end_idx + 1] - - # Validate that it's valid JSON - try: - json.loads(json_portion) - logger.debug( - f"Extracted JSON from Claude response: {len(json_portion)} chars vs original {len(response)} chars" - ) - return json_portion - except json.JSONDecodeError: - # If extracted portion is not valid JSON, return original - logger.debug( - "Could not extract valid JSON from Claude response, returning original" - ) - return response - - -def main(): - # Example usage with local LLM - system_prompt = "You are a creative assistant who specializes in short stories." - - local_llm = LLMClient(system_prompt=system_prompt) - - # Non-streaming example - result = local_llm.generate("Write a one-sentence bedtime story about a unicorn.") - print(f"Local LLM response:\n{result}") - print(f"Local LLM stats:\n{local_llm.get_performance_stats()}") - - # Halt functionality demo (only for local LLM) - print(f"\nHalt functionality available: {local_llm.is_generating()}") - - # Streaming example - print("\nLocal LLM streaming response:") - for chunk in local_llm.generate( - "Write a one-sentence bedtime story about a dragon.", stream=True - ): - print(chunk, end="", flush=True) - print("\n") - - # Example usage with Claude API - if CLAUDE_AVAILABLE: - claude_llm = LLMClient(use_claude=True, system_prompt=system_prompt) - - # Non-streaming example - result = claude_llm.generate( - "Write a one-sentence bedtime story about a unicorn." - ) - print(f"\nClaude API response:\n{result}") - - # Example usage with OpenAI API - openai_llm = LLMClient(use_openai=True, system_prompt=system_prompt) - - # Non-streaming example - result = openai_llm.generate("Write a one-sentence bedtime story about a unicorn.") - print(f"\nOpenAI API response:\n{result}") - - # Streaming example - print("\nOpenAI API streaming response:") - for chunk in openai_llm.generate( - "Write a one-sentence bedtime story about a dragon.", stream=True - ): - print(chunk, end="", flush=True) - print("\n") - - -if __name__ == "__main__": - main() diff --git a/src/gaia/llm/providers/__init__.py b/src/gaia/llm/providers/__init__.py new file mode 100644 index 000000000..12a5ee3eb --- /dev/null +++ b/src/gaia/llm/providers/__init__.py @@ -0,0 +1,9 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""LLM provider implementations.""" + +from .claude import ClaudeProvider +from .lemonade import LemonadeProvider +from .openai_provider import OpenAIProvider + +__all__ = ["ClaudeProvider", "LemonadeProvider", "OpenAIProvider"] diff --git a/src/gaia/llm/providers/claude.py b/src/gaia/llm/providers/claude.py new file mode 100644 index 000000000..874116218 --- /dev/null +++ b/src/gaia/llm/providers/claude.py @@ -0,0 +1,108 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Claude provider - no embeddings support.""" + +from typing import Iterator, Optional, Union + +try: + import anthropic +except ImportError: + anthropic = None # type: ignore + +from ..base_client import LLMClient + + +class ClaudeProvider(LLMClient): + """Claude (Anthropic) provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "claude-3-5-sonnet-20241022", + system_prompt: Optional[str] = None, + **_kwargs, + ): + if anthropic is None: + raise ImportError( + "anthropic package is required for ClaudeProvider. " + "Install it with: pip install anthropic" + ) + + self._client = anthropic.Anthropic(api_key=api_key) + self._model = model + self._system_prompt = system_prompt + + @property + def provider_name(self) -> str: + return "Claude" + + def generate( + self, + prompt: str, + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + return self.chat( + [{"role": "user", "content": prompt}], + model=model, + stream=stream, + **kwargs, + ) + + def chat( + self, + messages: list[dict], + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + # Build parameters for Anthropic messages.create + params = { + "model": model or self._model, + "messages": messages, + "stream": stream, + **kwargs, + } + # Claude API requires system prompt as separate parameter, not in messages + if self._system_prompt: + params["system"] = self._system_prompt + + response = self._client.messages.create(**params) + if stream: + return self._handle_stream(response) + return response.content[0].text + + # embed() inherited from ABC - raises NotSupportedError + + def vision(self, images: list[bytes], prompt: str, **kwargs) -> str: + import base64 + + # Claude supports vision via messages + image_b64 = base64.b64encode(images[0]).decode() + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ] + return self.chat(messages, **kwargs) + + # get_performance_stats() inherited from ABC - raises NotSupportedError + # load_model() inherited from ABC - raises NotSupportedError + # unload_model() inherited from ABC - raises NotSupportedError + + def _handle_stream(self, response) -> Iterator[str]: + for chunk in response: + if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): + yield chunk.delta.text diff --git a/src/gaia/llm/providers/lemonade.py b/src/gaia/llm/providers/lemonade.py new file mode 100644 index 000000000..4bd2d6af0 --- /dev/null +++ b/src/gaia/llm/providers/lemonade.py @@ -0,0 +1,120 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Lemonade provider - supports ALL methods.""" + +from typing import Iterator, Optional, Union + +from ..base_client import LLMClient +from ..lemonade_client import DEFAULT_MODEL_NAME, LemonadeClient + + +class LemonadeProvider(LLMClient): + """Lemonade provider - local AMD-optimized inference.""" + + def __init__( + self, + model: Optional[str] = None, + base_url: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + system_prompt: Optional[str] = None, + **kwargs, + ): + # Build kwargs for LemonadeClient, only including non-None values + backend_kwargs = {} + if model is not None: + backend_kwargs["model"] = model + if base_url is not None: + backend_kwargs["base_url"] = base_url + if host is not None: + backend_kwargs["host"] = host + if port is not None: + backend_kwargs["port"] = port + backend_kwargs.update(kwargs) + + self._backend = LemonadeClient(**backend_kwargs) + self._model = model + self._system_prompt = system_prompt + + @property + def provider_name(self) -> str: + return "Lemonade" + + def generate( + self, + prompt: str, + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + # Use provided model, instance model, or default CPU model + effective_model = model or self._model or DEFAULT_MODEL_NAME + + # Default to low temperature for deterministic responses (matches old LLMClient behavior) + kwargs.setdefault("temperature", 0.1) + + response = self._backend.completions( + model=effective_model, prompt=prompt, stream=stream, **kwargs + ) + if stream: + return self._handle_stream(response) + return self._extract_text(response) + + def chat( + self, + messages: list[dict], + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + # Use provided model, instance model, or default CPU model + effective_model = model or self._model or DEFAULT_MODEL_NAME + + # Prepend system prompt if set + if self._system_prompt: + messages = [{"role": "system", "content": self._system_prompt}] + list( + messages + ) + + # Default to low temperature for deterministic responses (matches old LLMClient behavior) + kwargs.setdefault("temperature", 0.1) + + response = self._backend.chat_completions( + model=effective_model, messages=messages, stream=stream, **kwargs + ) + if stream: + return self._handle_stream(response) + return response["choices"][0]["message"]["content"] + + def embed(self, texts: list[str], **kwargs) -> list[list[float]]: + response = self._backend.embeddings(texts, **kwargs) + return [item["embedding"] for item in response["data"]] + + def vision(self, images: list[bytes], prompt: str, **kwargs) -> str: + # Delegate to VLMClient + from ..vlm_client import VLMClient + + vlm = VLMClient(base_url=self._backend.base_url) + return vlm.extract_from_image(images[0], prompt=prompt) + + def get_performance_stats(self) -> dict: + return self._backend.get_stats() or {} + + def load_model(self, model_name: str, **kwargs) -> None: + self._backend.load_model(model_name, **kwargs) + self._model = model_name + + def unload_model(self) -> None: + self._backend.unload_model() + + def _extract_text(self, response: dict) -> str: + return response["choices"][0]["text"] + + def _handle_stream(self, response) -> Iterator[str]: + for chunk in response: + if "choices" in chunk and chunk["choices"]: + delta = chunk["choices"][0].get("delta", {}) + if "content" in delta: + yield delta["content"] + elif "text" in chunk["choices"][0]: + yield chunk["choices"][0]["text"] diff --git a/src/gaia/llm/providers/openai_provider.py b/src/gaia/llm/providers/openai_provider.py new file mode 100644 index 000000000..ab204153a --- /dev/null +++ b/src/gaia/llm/providers/openai_provider.py @@ -0,0 +1,79 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""OpenAI provider - no vision support.""" + +from typing import Iterator, Optional, Union + +from ..base_client import LLMClient + + +class OpenAIProvider(LLMClient): + """OpenAI (OpenAI API) provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-4o", + system_prompt: Optional[str] = None, + **_kwargs, + ): + import openai + + self._client = openai.OpenAI(api_key=api_key) + self._model = model + self._system_prompt = system_prompt + + @property + def provider_name(self) -> str: + return "OpenAI" + + def generate( + self, + prompt: str, + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + # OpenAI doesn't have a separate completions endpoint for chat models + return self.chat( + [{"role": "user", "content": prompt}], + model=model, + stream=stream, + **kwargs, + ) + + def chat( + self, + messages: list[dict], + model: str | None = None, + stream: bool = False, + **kwargs, + ) -> Union[str, Iterator[str]]: + # Prepend system prompt if set + if self._system_prompt: + messages = [{"role": "system", "content": self._system_prompt}] + list( + messages + ) + + response = self._client.chat.completions.create( + model=model or self._model, messages=messages, stream=stream, **kwargs + ) + if stream: + return self._handle_stream(response) + return response.choices[0].message.content + + def embed( + self, texts: list[str], model: str = "text-embedding-3-small", **kwargs + ) -> list[list[float]]: + response = self._client.embeddings.create(model=model, input=texts, **kwargs) + return [item.embedding for item in response.data] + + # vision() inherited from ABC - raises NotSupportedError + # get_performance_stats() inherited from ABC - raises NotSupportedError + # load_model() inherited from ABC - raises NotSupportedError + # unload_model() inherited from ABC - raises NotSupportedError + + def _handle_stream(self, response) -> Iterator[str]: + for chunk in response: + if chunk.choices and chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content diff --git a/src/gaia/llm/vlm_client.py b/src/gaia/llm/vlm_client.py index 064e728ec..9f1f5a127 100644 --- a/src/gaia/llm/vlm_client.py +++ b/src/gaia/llm/vlm_client.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT """ diff --git a/src/gaia/mcp/mcp_bridge.py b/src/gaia/mcp/mcp_bridge.py index 05d98c808..492cf838d 100644 --- a/src/gaia/mcp/mcp_bridge.py +++ b/src/gaia/mcp/mcp_bridge.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT """ @@ -21,7 +21,7 @@ ) from gaia.agents.blender.agent import BlenderAgent -from gaia.llm.llm_client import LLMClient +from gaia.llm import create_client from gaia.logger import get_logger logger = get_logger(__name__) @@ -221,7 +221,7 @@ def _execute_jira(self, args: Dict[str, Any]) -> Dict[str, Any]: def _execute_query(self, args: Dict[str, Any]) -> Dict[str, Any]: """Execute LLM query.""" if not self.llm_client: - self.llm_client = LLMClient(base_url=self.base_url) + self.llm_client = create_client("lemonade", base_url=self.base_url) response = self.llm_client.generate( prompt=args.get("query", ""), diff --git a/tests/test_rag.py b/tests/test_rag.py index 587432a7c..f6e5c18bd 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -399,7 +399,7 @@ def mock_chat_dependencies(self): with ( patch("gaia.llm.vlm_client.VLMClient") as mock_vlm_class, patch("gaia.llm.lemonade_client.LemonadeClient") as mock_lemonade, - patch("gaia.chat.sdk.LLMClient") as mock_llm, + patch("gaia.chat.sdk.create_client") as mock_create_client, patch("gaia.rag.sdk.RAGSDK") as mock_rag_class, ): @@ -416,9 +416,9 @@ def mock_chat_dependencies(self): } mock_lemonade.return_value = mock_lemonade_instance - # Mock LLM client + # Mock LLM client factory - create_client() returns mock instance mock_llm_instance = Mock() - mock_llm.return_value = mock_llm_instance + mock_create_client.return_value = mock_llm_instance # Mock RAG SDK mock_rag = Mock() diff --git a/tests/test_sdk.py b/tests/test_sdk.py index bf6ab936a..1200e87d3 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -410,10 +410,10 @@ def test_quick_rag_exists(self): class TestLLMClient: """Test LLMClient interface.""" - @patch("gaia.llm.llm_client.LLMClient.__init__") + @patch("gaia.llm.LLMClient.__init__") def test_llm_client_can_be_imported(self, mock_init): """Verify LLMClient can be imported.""" - from gaia.llm.llm_client import LLMClient + from gaia.llm import LLMClient mock_init.return_value = None client = LLMClient.__new__(LLMClient) @@ -421,13 +421,18 @@ def test_llm_client_can_be_imported(self, mock_init): def test_llm_client_interface_methods(self): """Verify LLMClient has required methods.""" - from gaia.llm.llm_client import LLMClient + from gaia.llm import LLMClient - # Check methods exist + # Check abstract methods exist assert hasattr(LLMClient, "generate") - assert hasattr(LLMClient, "chat_completions") - assert hasattr(LLMClient, "get_available_models") - assert hasattr(LLMClient, "estimate_tokens") + assert hasattr(LLMClient, "chat") + assert hasattr(LLMClient, "provider_name") + # Check optional methods exist + assert hasattr(LLMClient, "embed") + assert hasattr(LLMClient, "vision") + assert hasattr(LLMClient, "get_performance_stats") + assert hasattr(LLMClient, "load_model") + assert hasattr(LLMClient, "unload_model") def test_lemonade_constants_exist(self): """Verify Lemonade client constants.""" @@ -853,7 +858,7 @@ def test_all_imports_in_sdk_are_valid(self): # LLM try: - from gaia.llm.llm_client import LLMClient # noqa: F401 + from gaia.llm import LLMClient # noqa: F401 from gaia.llm.vlm_client import VLMClient # noqa: F401 except ImportError as e: pytest.fail(f"LLM import failed: {e}") @@ -1502,7 +1507,7 @@ def test_llm_app_exists(self): assert LlmApp is not None except ImportError: # LLM app may be in different location - from gaia.llm.llm_client import LLMClient + from gaia.llm import LLMClient assert LLMClient is not None diff --git a/tests/unit/test_lemonade_model_loading.py b/tests/unit/test_lemonade_model_loading.py new file mode 100644 index 000000000..0d166ecec --- /dev/null +++ b/tests/unit/test_lemonade_model_loading.py @@ -0,0 +1,278 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for LemonadeClient model loading functionality.""" + +from unittest.mock import MagicMock, Mock, patch + +from gaia.llm.lemonade_client import LemonadeClient, LemonadeStatus + + +class TestEnsureModelLoaded: + """Test _ensure_model_loaded helper method.""" + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + def test_calls_load_when_model_not_loaded(self, mock_load, mock_status): + """Verify load_model is called when model not in loaded_models list.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_status.return_value = LemonadeStatus( + url="http://localhost:8000", + running=True, + loaded_models=[{"id": "model-a"}], + ) + + # Execute + client._ensure_model_loaded("model-b", auto_download=True) + + # Verify - should call with prompt=False to skip user confirmation + mock_load.assert_called_once_with("model-b", auto_download=True, prompt=False) + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + def test_skips_load_when_model_already_loaded(self, mock_load, mock_status): + """Verify no load_model call when model already in loaded_models list.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_status.return_value = LemonadeStatus( + url="http://localhost:8000", + running=True, + loaded_models=[{"id": "model-a"}], + ) + + # Execute + client._ensure_model_loaded("model-a", auto_download=True) + + # Verify - should NOT call load_model + mock_load.assert_not_called() + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + def test_skips_check_when_auto_download_disabled(self, mock_load, mock_status): + """Verify method returns early when auto_download=False.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + + # Execute + client._ensure_model_loaded("model-a", auto_download=False) + + # Verify - should NOT call get_status or load_model + mock_status.assert_not_called() + mock_load.assert_not_called() + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + def test_handles_status_check_error_gracefully(self, mock_load, mock_status): + """Verify errors during status check are logged but don't fail.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_status.side_effect = Exception("Connection failed") + + # Execute - should not raise + client._ensure_model_loaded("model-a", auto_download=True) + + # Verify - load_model should not be called due to error + mock_load.assert_not_called() + + +class TestStreamCompletionsModelLoading: + """Test that _stream_completions_with_openai calls _ensure_model_loaded.""" + + @patch.object(LemonadeClient, "_ensure_model_loaded") + @patch("gaia.llm.lemonade_client.OpenAI") + def test_calls_ensure_model_loaded_before_request( + self, mock_openai_class, mock_ensure + ): + """Verify _ensure_model_loaded is called before making the API request.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_openai_instance = MagicMock() + mock_openai_class.return_value = mock_openai_instance + + # Mock the streaming response + mock_chunk = Mock() + mock_chunk.model_dump.return_value = { + "id": "test", + "object": "text_completion", + "created": 12345, + "model": "test-model", + "choices": [{"index": 0, "text": "Hello", "finish_reason": None}], + } + mock_openai_instance.completions.create.return_value = iter([mock_chunk]) + + # Execute - consume the generator + list( + client._stream_completions_with_openai( + model="test-model", + prompt="test prompt", + auto_download=True, + ) + ) + + # Verify _ensure_model_loaded was called with correct arguments + mock_ensure.assert_called_once_with("test-model", True) + + # Verify it was called BEFORE the API request + assert mock_ensure.call_count == 1 + assert mock_openai_instance.completions.create.called + + +class TestStreamChatCompletionsModelLoading: + """Test that _stream_chat_completions_with_openai calls _ensure_model_loaded.""" + + @patch.object(LemonadeClient, "_ensure_model_loaded") + @patch("gaia.llm.lemonade_client.OpenAI") + def test_calls_ensure_model_loaded_before_request( + self, mock_openai_class, mock_ensure + ): + """Verify _ensure_model_loaded is called before making the API request.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_openai_instance = MagicMock() + mock_openai_class.return_value = mock_openai_instance + + # Mock the streaming response + mock_chunk = Mock() + mock_chunk.id = "test-id" + mock_chunk.object = "chat.completion.chunk" + mock_chunk.created = 12345 + mock_chunk.model = "test-model" + + mock_choice = Mock() + mock_choice.index = 0 + mock_choice.delta = Mock() + mock_choice.delta.role = "assistant" + mock_choice.delta.content = "Hello" + mock_choice.finish_reason = None + mock_chunk.choices = [mock_choice] + + mock_openai_instance.chat.completions.create.return_value = iter([mock_chunk]) + + # Execute - consume the generator + list( + client._stream_chat_completions_with_openai( + model="test-model", + messages=[{"role": "user", "content": "test"}], + auto_download=True, + ) + ) + + # Verify _ensure_model_loaded was called with correct arguments + mock_ensure.assert_called_once_with("test-model", True) + + # Verify it was called BEFORE the API request + assert mock_ensure.call_count == 1 + assert mock_openai_instance.chat.completions.create.called + + +class TestNoPromptBehavior: + """Test that model downloads happen without prompting.""" + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + def test_ensure_model_loaded_passes_prompt_false(self, mock_load, mock_status): + """Verify _ensure_model_loaded passes prompt=False to avoid user prompts.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + mock_status.return_value = LemonadeStatus( + url="http://localhost:8000", + running=True, + loaded_models=[], # No models loaded + ) + + # Execute + client._ensure_model_loaded("new-model", auto_download=True) + + # Verify prompt=False is passed to skip user confirmation + assert mock_load.called + call_kwargs = mock_load.call_args.kwargs + assert "prompt" in call_kwargs + assert call_kwargs["prompt"] is False + + +class TestModelLoadingIntegration: + """Integration-style tests for model loading behavior.""" + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + @patch("gaia.llm.lemonade_client.OpenAI") + def test_model_loaded_when_not_present( + self, mock_openai_class, mock_load, mock_status + ): + """Integration test: model is loaded when not in loaded_models list.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + + # Mock status to show model NOT loaded + mock_status.return_value = LemonadeStatus( + url="http://localhost:8000", + running=True, + loaded_models=[{"id": "different-model"}], + ) + + # Mock OpenAI client + mock_openai_instance = MagicMock() + mock_openai_class.return_value = mock_openai_instance + mock_chunk = Mock() + mock_chunk.model_dump.return_value = { + "id": "test", + "object": "text_completion", + "created": 12345, + "model": "new-model", + "choices": [{"index": 0, "text": "Response", "finish_reason": None}], + } + mock_openai_instance.completions.create.return_value = iter([mock_chunk]) + + # Execute - consume the generator + list( + client._stream_completions_with_openai( + model="new-model", + prompt="test", + auto_download=True, + ) + ) + + # Verify load_model was called to download/load the model WITHOUT prompting + mock_load.assert_called_once_with("new-model", auto_download=True, prompt=False) + + @patch.object(LemonadeClient, "get_status") + @patch.object(LemonadeClient, "load_model") + @patch("gaia.llm.lemonade_client.OpenAI") + def test_model_not_loaded_when_already_present( + self, mock_openai_class, mock_load, mock_status + ): + """Integration test: no load when model already in loaded_models list.""" + # Setup + client = LemonadeClient(host="localhost", port=8000) + + # Mock status to show model IS loaded + mock_status.return_value = LemonadeStatus( + url="http://localhost:8000", + running=True, + loaded_models=[{"id": "existing-model"}], + ) + + # Mock OpenAI client + mock_openai_instance = MagicMock() + mock_openai_class.return_value = mock_openai_instance + mock_chunk = Mock() + mock_chunk.model_dump.return_value = { + "id": "test", + "object": "text_completion", + "created": 12345, + "model": "existing-model", + "choices": [{"index": 0, "text": "Response", "finish_reason": None}], + } + mock_openai_instance.completions.create.return_value = iter([mock_chunk]) + + # Execute - consume the generator + list( + client._stream_completions_with_openai( + model="existing-model", + prompt="test", + auto_download=True, + ) + ) + + # Verify load_model was NOT called (model already loaded) + mock_load.assert_not_called() diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index cc294357b..6b3429a88 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -14,60 +14,6 @@ # from gaia.logger import get_logger -class TestBaseUrlNormalization(unittest.TestCase): - """Test base_url normalization in LLMClient (fast unit tests, no server needed).""" - - def _run_normalization_test(self, input_url, expected_url): - """Run a base_url normalization test in a subprocess to avoid bytecode cache issues.""" - result = subprocess.run( - [ - sys.executable, - "-c", - f""" -import sys -sys.path.insert(0, "src") -from unittest.mock import patch, MagicMock -from gaia.llm.llm_client import LLMClient - -with patch("gaia.llm.llm_client.OpenAI", MagicMock()): - client = LLMClient(base_url="{input_url}") - print(client.base_url) -""", - ], - capture_output=True, - text=True, - timeout=30, - ) - actual = result.stdout.strip() - self.assertEqual( - actual, - expected_url, - f"Expected {expected_url}, got {actual}. stderr: {result.stderr}", - ) - - def test_base_url_normalization_adds_api_v1(self): - """Test that base_url without /api/v1 gets it appended.""" - # Test: URL without path should get /api/v1 appended - self._run_normalization_test( - "http://localhost:8000", "http://localhost:8000/api/v1" - ) - - # Test: URL with trailing slash should get /api/v1 appended - self._run_normalization_test( - "http://localhost:8000/", "http://localhost:8000/api/v1" - ) - - # Test: URL already with /api/v1 should remain unchanged - self._run_normalization_test( - "http://localhost:8000/api/v1", "http://localhost:8000/api/v1" - ) - - # Test: Custom port should work - self._run_normalization_test( - "http://192.168.1.100:9000", "http://192.168.1.100:9000/api/v1" - ) - - class TestLlmCli(unittest.TestCase): def setUp(self): # self.log = get_logger(__name__) diff --git a/tests/unit/test_llm_client_factory.py b/tests/unit/test_llm_client_factory.py new file mode 100644 index 000000000..0a7ac676b --- /dev/null +++ b/tests/unit/test_llm_client_factory.py @@ -0,0 +1,178 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""TDD tests for LLM client factory - write BEFORE implementation.""" + +from unittest.mock import patch + +import pytest + +# ============================================================================= +# Import Tests (will fail until modules exist) +# ============================================================================= + + +class TestImports: + def test_can_import_create_client(self): + from gaia.llm import create_client + + assert callable(create_client) + + def test_can_import_llm_client_abc(self): + from abc import ABC + + from gaia.llm import LLMClient + + assert issubclass(LLMClient, ABC) + + def test_can_import_not_supported_error(self): + from gaia.llm import NotSupportedError + + assert issubclass(NotSupportedError, Exception) + + +# ============================================================================= +# Factory Tests +# ============================================================================= + + +class TestCreateClientFactory: + def test_create_client_returns_lemonade_provider(self): + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + from gaia.llm import create_client + + client = create_client("lemonade") + assert client.provider_name == "Lemonade" + + def test_create_client_invalid_provider_raises_valueerror(self): + from gaia.llm import create_client + + with pytest.raises(ValueError, match="Unknown provider"): + create_client("invalid_provider") + + def test_create_client_case_insensitive(self): + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + from gaia.llm import create_client + + client = create_client("LEMONADE") + assert client.provider_name == "Lemonade" + + def test_create_client_passes_kwargs(self): + with patch("gaia.llm.providers.lemonade.LemonadeClient") as mock: + from gaia.llm import create_client + + create_client("lemonade", base_url="http://custom:9000", model="test") + mock.assert_called_with(base_url="http://custom:9000", model="test") + + +# ============================================================================= +# NotSupportedError Tests +# ============================================================================= + + +class TestNotSupportedError: + def test_error_includes_provider_name(self): + from gaia.llm import NotSupportedError + + error = NotSupportedError("Claude", "embed") + assert "Claude" in str(error) + + def test_error_includes_method_name(self): + from gaia.llm import NotSupportedError + + error = NotSupportedError("Claude", "embed") + assert "embed" in str(error) + + +class TestClaudeNotSupported: + def test_claude_embed_raises_not_supported(self): + with patch("gaia.llm.providers.claude.anthropic"): + from gaia.llm import NotSupportedError, create_client + + client = create_client("claude", api_key="test") + + with pytest.raises(NotSupportedError) as exc: + client.embed(["text"]) + assert "Claude" in str(exc.value) + assert "embed" in str(exc.value) + + def test_claude_load_model_raises_not_supported(self): + with patch("gaia.llm.providers.claude.anthropic"): + from gaia.llm import NotSupportedError, create_client + + client = create_client("claude", api_key="test") + + with pytest.raises(NotSupportedError): + client.load_model("some-model") + + +class TestOpenAINotSupported: + def test_openai_vision_raises_not_supported(self): + with patch("openai.OpenAI"): + from gaia.llm import NotSupportedError, create_client + + client = create_client("openai", api_key="test") + + with pytest.raises(NotSupportedError) as exc: + client.vision([b"image"], "describe this") + assert "OpenAI" in str(exc.value) + + +# ============================================================================= +# Provider Name Tests +# ============================================================================= + + +class TestProviderNames: + def test_lemonade_provider_name(self): + with patch("gaia.llm.providers.lemonade.LemonadeClient"): + from gaia.llm import create_client + + client = create_client("lemonade") + assert client.provider_name == "Lemonade" + + def test_openai_provider_name(self): + with patch("openai.OpenAI"): + from gaia.llm import create_client + + client = create_client("openai", api_key="test") + assert client.provider_name == "OpenAI" + + def test_claude_provider_name(self): + with patch("gaia.llm.providers.claude.anthropic"): + from gaia.llm import create_client + + client = create_client("claude", api_key="test") + assert client.provider_name == "Claude" + + +# ============================================================================= +# ABC Interface Tests +# ============================================================================= + + +class TestLLMClientABC: + def test_cannot_instantiate_abc(self): + from gaia.llm import LLMClient + + with pytest.raises(TypeError): + LLMClient() + + def test_abc_has_generate_method(self): + from gaia.llm import LLMClient + + assert hasattr(LLMClient, "generate") + + def test_abc_has_chat_method(self): + from gaia.llm import LLMClient + + assert hasattr(LLMClient, "chat") + + def test_abc_has_embed_method(self): + from gaia.llm import LLMClient + + assert hasattr(LLMClient, "embed") + + def test_abc_has_vision_method(self): + from gaia.llm import LLMClient + + assert hasattr(LLMClient, "vision") diff --git a/util/lint.py b/util/lint.py index 9dd38cb11..3cf07cdb9 100644 --- a/util/lint.py +++ b/util/lint.py @@ -277,7 +277,7 @@ def check_imports() -> CheckResult: imports = [ ("gaia.cli", "CLI module"), ("gaia.chat.sdk", "Chat SDK"), - ("gaia.llm.llm_client", "LLM client"), + ("gaia.llm", "LLM client"), ("gaia.agents.base.agent", "Base agent"), ]