Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/test_gaia_cli_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ jobs:
from gaia.cli import main
from gaia.version import version
from gaia.logger import get_logger
from gaia.llm.llm_client import LLMClient
from gaia.llm import LLMClient
print('✅ All core imports successful on Linux')
except ImportError as e:
print(f'❌ Import error: {e}')
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
packages=[
"gaia",
"gaia.llm",
"gaia.llm.providers",
"gaia.audio",
"gaia.chat",
"gaia.database",
Expand Down
12 changes: 5 additions & 7 deletions src/gaia/agents/blender/agent_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from typing import Any, Dict, Optional, Tuple

from gaia.llm.llm_client import LLMClient
from gaia.llm import create_client
from gaia.llm.base_client import LLMClient
from gaia.mcp.blender_mcp_client import MCPClient


Expand Down Expand Up @@ -39,7 +40,6 @@ def __init__(
self,
llm: Optional[LLMClient] = None,
mcp: Optional[MCPClient] = None,
use_local: bool = True,
base_url: Optional[str] = "http://localhost:8000/api/v1",
):
"""
Expand All @@ -48,15 +48,13 @@ def __init__(
Args:
llm: An optional pre-configured LLM client, otherwise a new one will be created
mcp: An optional pre-configured MCP client, otherwise a new one will be created
use_local: Whether to use a local LLM (True) or a remote API (False)
base_url: Base URL for the local LLM API if using local LLM. If None and use_local=True,
defaults to "http://localhost:8000/api/v1"
base_url: Base URL for the Lemonade LLM server
"""
self.llm = (
llm
if llm
else LLMClient(
use_local=use_local, system_prompt=self.SYSTEM_PROMPT, base_url=base_url
else create_client(
"lemonade", base_url=base_url, system_prompt=self.SYSTEM_PROMPT
)
Comment thread
itomek marked this conversation as resolved.
)
self.mcp = mcp if mcp else MCPClient()
Expand Down
2 changes: 1 addition & 1 deletion src/gaia/agents/blender/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from gaia.agents.base.console import AgentConsole
from gaia.agents.blender.agent import BlenderAgent
from gaia.llm.llm_client import LLMClient
from gaia.llm.base_client import LLMClient
from gaia.mcp.blender_mcp_client import MCPClient

# Set up logging
Expand Down
11 changes: 6 additions & 5 deletions src/gaia/agents/blender/tests/test_agent_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import pytest

from gaia.agents.blender.agent_simple import BlenderAgentSimple
from gaia.agents.blender.mcp.mcp_client import MCPClient
from gaia.llm.llm_client import LLMClient
from gaia.llm import create_client
from gaia.llm.base_client import LLMClient
from gaia.mcp.blender_mcp_client import MCPClient

# Set up logging
logging.basicConfig(level=logging.DEBUG)
Expand All @@ -34,7 +35,7 @@ def llm_client():
- For any other request, respond with: CYLINDER,0,2,0,0.5,0.5,3
"""
# Using local LLM for faster testing
return LLMClient(system_prompt=system_prompt)
return create_client("lemonade", system_prompt=system_prompt)


@pytest.fixture
Expand Down Expand Up @@ -166,8 +167,8 @@ def test_agent_initialization():
assert isinstance(agent.mcp, MCPClient)

# Verify default system prompt is set
logger.debug(f"System prompt: {agent.llm.system_prompt}")
assert agent.llm.system_prompt == agent.SYSTEM_PROMPT
logger.debug(f"System prompt: {agent.llm._system_prompt}")
assert agent.llm._system_prompt == agent.SYSTEM_PROMPT


if __name__ == "__main__":
Expand Down
13 changes: 5 additions & 8 deletions src/gaia/agents/routing/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional

from gaia.agents.base.agent import Agent
from gaia.llm.llm_client import LLMClient
from gaia.llm import create_client
from gaia.logger import get_logger

from .system_prompt import ROUTING_ANALYSIS_PROMPT
Expand Down Expand Up @@ -58,13 +58,10 @@ def __init__(
# Read from environment if not provided
base_url = os.getenv("LEMONADE_BASE_URL", "http://localhost:8000/api/v1")

llm_kwargs = {
"use_claude": use_claude,
"use_openai": use_chatgpt,
"base_url": base_url,
}

self.llm_client = LLMClient(**llm_kwargs)
# Initialize LLM client - factory auto-detects provider from flags
self.llm_client = create_client(
use_claude=use_claude, use_openai=use_chatgpt, base_url=base_url
)
self.agent_kwargs = agent_kwargs # Store for passing to created agents

# Model to use for routing analysis (configurable via env var)
Expand Down
10 changes: 7 additions & 3 deletions src/gaia/apps/llm/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Expand All @@ -11,7 +11,7 @@
import sys
from typing import Iterator, Optional, Union

from gaia.llm.llm_client import LLMClient
from gaia.llm import create_client
from gaia.logger import get_logger


Expand All @@ -28,7 +28,11 @@ def __init__(
base_url: Base URL for local LLM server (defaults to LEMONADE_BASE_URL env var)
"""
self.log = get_logger(__name__)
self.client = LLMClient(system_prompt=system_prompt, base_url=base_url)
self.client = create_client(
"lemonade",
base_url=base_url,
system_prompt=system_prompt,
)
self.log.debug("LLM app initialized")

def query(
Expand Down
8 changes: 4 additions & 4 deletions src/gaia/audio/audio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
import time

from gaia.llm.llm_client import LLMClient
from gaia.llm import create_client
from gaia.logger import get_logger


Expand Down Expand Up @@ -40,10 +40,10 @@ def __init__(
self.transcription_queue = queue.Queue()
self.tts = None

# Initialize LLM client (base_url handled automatically)
self.llm_client = LLMClient(
# Initialize LLM client - factory auto-detects provider from flags
self.llm_client = create_client(
use_claude=use_claude,
use_openai=use_chatgpt, # LLMClient uses use_openai, not use_chatgpt
use_openai=use_chatgpt,
system_prompt=system_prompt,
)

Expand Down
22 changes: 10 additions & 12 deletions src/gaia/chat/sdk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Expand All @@ -13,8 +13,8 @@
from typing import Any, Dict, List, Optional

from gaia.chat.prompts import Prompts
from gaia.llm import create_client
from gaia.llm.lemonade_client import DEFAULT_MODEL_NAME
from gaia.llm.llm_client import LLMClient
from gaia.logger import get_logger


Expand Down Expand Up @@ -87,19 +87,17 @@ def __init__(self, config: Optional[ChatConfig] = None):
self.log = get_logger(__name__)
self.log.setLevel(getattr(logging, self.config.logging_level))

# Validate that both providers aren't specified
if self.config.use_claude and self.config.use_chatgpt:
raise ValueError(
"Cannot specify both use_claude and use_chatgpt. Please choose one."
)

# Initialize LLM client - it will compute use_local automatically
self.llm_client = LLMClient(
# Initialize LLM client - factory auto-detects provider and validates
self.llm_client = create_client(
use_claude=self.config.use_claude,
use_openai=self.config.use_chatgpt,
claude_model=self.config.claude_model,
model=(
self.config.claude_model
if self.config.use_claude
else self.config.model
),
base_url=self.config.base_url,
Comment thread
itomek marked this conversation as resolved.
system_prompt=None, # We handle system prompts through Prompts class
system_prompt=self.config.system_prompt,
)

# Store conversation history
Expand Down
19 changes: 4 additions & 15 deletions src/gaia/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import sys
import time
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

from gaia.llm import create_client
from gaia.llm.lemonade_client import (
DEFAULT_HOST,
DEFAULT_LEMONADE_URL,
Expand All @@ -21,7 +21,6 @@
LemonadeClientError,
_get_lemonade_config,
)
from gaia.llm.llm_client import LLMClient
from gaia.logger import get_logger
from gaia.version import version

Expand Down Expand Up @@ -120,8 +119,8 @@ def initialize_lemonade_for_agent(
skip_if_external: bool = False,
use_claude: bool = False,
use_chatgpt: bool = False,
host: Optional[str] = None,
port: Optional[int] = None,
host: str | None = None,
port: int | None = None,
):
"""
Initialize Lemonade Server for a specific GAIA agent.
Expand Down Expand Up @@ -401,7 +400,7 @@ def __init__(
self.show_stats = show_stats

# Initialize LLM client for local inference
self.llm_client = LLMClient()
self.llm_client = create_client("lemonade", model=model)

self.log.debug("Gaia CLI client initialized.")
self.log.debug(f"model: {self.model}\n max_tokens: {self.max_tokens}")
Expand Down Expand Up @@ -3138,16 +3137,6 @@ def download_progress_callback(event_type: str, data: dict) -> None:
LemonadeManager.print_server_error()
else:
print(f"❌ Error: {str(e)}")
# Check for model loading related errors
if (
"404" in error_msg
or "not found" in error_msg
or "not loaded" in error_msg
):
print(
"\nMake sure that the model is loaded. You can load it using:"
)
print(f" gaia pull {DEFAULT_MODEL_NAME}")
Comment thread
itomek marked this conversation as resolved.
return

# Handle groundtruth generation
Expand Down
7 changes: 7 additions & 0 deletions src/gaia/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""LLM client package."""

from .base_client import LLMClient
from .exceptions import NotSupportedError
from .factory import create_client

__all__ = ["create_client", "LLMClient", "NotSupportedError"]
60 changes: 60 additions & 0 deletions src/gaia/llm/base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""Base LLM client interface."""

from abc import ABC, abstractmethod
from typing import Iterator, Union

from .exceptions import NotSupportedError


class LLMClient(ABC):
"""
Unified LLM client interface.

Methods raise NotSupportedError if not available for this provider.
"""

@property
@abstractmethod
def provider_name(self) -> str:
"""Return the provider name for error messages."""
...

@abstractmethod
def generate(
self,
prompt: str,
model: str | None = None,
stream: bool = False,
**kwargs,
) -> Union[str, Iterator[str]]:
"""Generate text completion."""
...

@abstractmethod
def chat(
self,
messages: list[dict],
model: str | None = None,
stream: bool = False,
**kwargs,
) -> Union[str, Iterator[str]]:
"""Chat completion."""
...

# Optional - default raises NotSupportedError
def embed(self, texts: list[str], **kwargs) -> list[list[float]]:
raise NotSupportedError(self.provider_name, "embed")

def vision(self, images: list[bytes], prompt: str, **kwargs) -> str:
raise NotSupportedError(self.provider_name, "vision")

def get_performance_stats(self) -> dict:
raise NotSupportedError(self.provider_name, "get_performance_stats")

def load_model(self, model_name: str, **kwargs) -> None:
raise NotSupportedError(self.provider_name, "load_model")

def unload_model(self) -> None:
raise NotSupportedError(self.provider_name, "unload_model")
12 changes: 12 additions & 0 deletions src/gaia/llm/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""LLM client exceptions."""


class NotSupportedError(Exception):
"""Raised when a provider doesn't support a method."""

def __init__(self, provider: str, method: str):
self.provider = provider
self.method = method
super().__init__(f"{provider} does not support {method}")
Loading
Loading