Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@
"telegram": [
"python-telegram-bot>=20.3",
],
"litellm": [
"litellm>=1.35.0,<2.0",
],
"dev": [
"pytest",
"pytest-cov",
Expand Down
1 change: 1 addition & 0 deletions src/gaia/llm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"lemonade": "gaia.llm.providers.lemonade.LemonadeProvider",
"openai": "gaia.llm.providers.openai_provider.OpenAIProvider",
"claude": "gaia.llm.providers.claude.ClaudeProvider",
"litellm": "gaia.llm.providers.litellm.LiteLLMProvider",
}


Expand Down
3 changes: 2 additions & 1 deletion src/gaia/llm/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .claude import ClaudeProvider
from .lemonade import LemonadeProvider
from .litellm import LiteLLMProvider
from .openai_provider import OpenAIProvider

__all__ = ["ClaudeProvider", "LemonadeProvider", "OpenAIProvider"]
__all__ = ["ClaudeProvider", "LemonadeProvider", "LiteLLMProvider", "OpenAIProvider"]
94 changes: 94 additions & 0 deletions src/gaia/llm/providers/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""LiteLLM provider - unified gateway for 100+ LLM providers."""

from typing import Iterator, Optional, Union

from ..base_client import LLMClient


class LiteLLMProvider(LLMClient):
    """LiteLLM AI gateway provider.

    Sends chat, text-generation and embedding requests through the
    ``litellm`` package. ``litellm`` is imported inside the methods that use
    it, so importing this module does not require the optional dependency.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """Store the provider configuration.

        Args:
            api_key: Forwarded as ``api_key`` on every litellm call; left out
                of the call entirely when falsy.
            model: Default model name, used whenever a call does not supply
                its own.
            system_prompt: Optional system message prepended to every chat.
            **kwargs: Extra keyword arguments forwarded to every litellm
                call; per-call keyword arguments take precedence over them.
        """
        import litellm

        self._model = model
        self._system_prompt = system_prompt
        self._api_key = api_key
        self._extra_kwargs = kwargs

        # NOTE: this flips a module-wide litellm setting, so it affects every
        # litellm user in the process, not just this instance.
        litellm.drop_params = True

    @property
    def provider_name(self) -> str:
        """Human-readable provider name."""
        return "LiteLLM"

    def _build_call_kwargs(self, overrides: dict) -> dict:
        """Merge constructor-level and per-call keyword arguments.

        Per-call values win over constructor values. A configured API key
        always overrides any ``api_key`` in the merged arguments.
        ``drop_params`` defaults to ``True`` but is carried inside the
        returned dict, so a caller-supplied value no longer collides with a
        hard-coded keyword argument.
        """
        merged = {**self._extra_kwargs, **overrides}
        if self._api_key:
            merged["api_key"] = self._api_key
        merged.setdefault("drop_params", True)
        return merged

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[str, Iterator[str]]:
        """Generate text for a single prompt.

        The prompt is wrapped as one ``user`` message and handed to
        :meth:`chat`; see that method for the return value.
        """
        return self.chat(
            [{"role": "user", "content": prompt}],
            model=model,
            stream=stream,
            **kwargs,
        )

    def chat(
        self,
        messages: list[dict],
        model: Optional[str] = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[str, Iterator[str]]:
        """Run a chat completion through ``litellm.completion``.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}``
                dicts. The caller's list is never mutated.
            model: Model name for this call; falls back to the default
                model given at construction when falsy.
            stream: When true, return an iterator of text chunks instead of
                the full response text.
            **kwargs: Extra keyword arguments for ``litellm.completion``.

        Returns:
            The content of the first choice, or an iterator over the
            non-empty text deltas when ``stream`` is true.
        """
        import litellm

        if self._system_prompt:
            # Build a new list so the caller's messages stay untouched.
            messages = [
                {"role": "system", "content": self._system_prompt},
                *messages,
            ]

        response = litellm.completion(
            model=model or self._model,
            messages=messages,
            stream=stream,
            **self._build_call_kwargs(kwargs),
        )
        if stream:
            return self._handle_stream(response)
        return response.choices[0].message.content

    def embed(self, texts: list[str], **kwargs) -> list[list[float]]:
        """Embed ``texts`` through ``litellm.embedding``.

        Args:
            texts: Strings to embed.
            **kwargs: Extra keyword arguments for ``litellm.embedding``. A
                ``model`` entry selects the embedding model; otherwise the
                default model given at construction is used.

        Returns:
            One embedding vector per item in the response data.
        """
        import litellm

        call_kwargs = self._build_call_kwargs(kwargs)
        # Take ``model`` out of the merged arguments before forwarding them;
        # leaving it in would pass ``model`` twice and raise TypeError.
        model = call_kwargs.pop("model", None) or self._model

        response = litellm.embedding(
            model=model,
            input=texts,
            **call_kwargs,
        )
        return [item["embedding"] for item in response.data]

    def _handle_stream(self, response) -> Iterator[str]:
        """Yield the non-empty text deltas of a streaming response."""
        for chunk in response:
            # Chunks without choices or without delta content carry no text.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
155 changes: 155 additions & 0 deletions tests/unit/test_litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""Tests for LiteLLM provider."""

import sys
import types
from unittest.mock import MagicMock

import pytest


def _stub_litellm():
    """Register a fake ``litellm`` module in ``sys.modules`` and return it.

    The fake exposes ``completion`` and ``embedding`` as ``MagicMock``
    objects and a ``drop_params`` flag initialised to ``False``, which lets
    the tests run without the real package.
    """
    stub = types.ModuleType("litellm")
    for attr in ("completion", "embedding"):
        setattr(stub, attr, MagicMock(name=f"litellm.{attr}"))
    stub.drop_params = False
    sys.modules[stub.__name__] = stub
    return stub


class TestLiteLLMProviderName:
    """Checks the provider's human-readable name."""

    def test_provider_name(self):
        """The provider reports itself as "LiteLLM"."""
        _stub_litellm()
        try:
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(api_key="test-key", model="gpt-4o")
            assert provider.provider_name == "LiteLLM"
        finally:
            # Remove the stub even when the test fails, so it cannot leak
            # into later tests.
            sys.modules.pop("litellm", None)


class TestLiteLLMFactory:
    """Checks that the client factory can build the LiteLLM provider."""

    def test_create_client_litellm(self):
        """The "litellm" provider key resolves to the LiteLLM provider."""
        _stub_litellm()
        try:
            from gaia.llm import create_client

            client = create_client("litellm", api_key="test-key")
            assert client.provider_name == "LiteLLM"
        finally:
            # Remove the stub even when the test fails, so it cannot leak
            # into later tests.
            sys.modules.pop("litellm", None)

    def test_create_client_litellm_case_insensitive(self):
        """The provider key is accepted in upper case as well."""
        _stub_litellm()
        try:
            from gaia.llm import create_client

            client = create_client("LITELLM", api_key="test-key")
            assert client.provider_name == "LiteLLM"
        finally:
            sys.modules.pop("litellm", None)


class TestLiteLLMChat:
    """Checks how ``chat`` calls ``litellm.completion``."""

    def test_chat_calls_litellm_completion(self):
        """chat forwards model, drop_params and api_key and returns the text."""
        fake = _stub_litellm()
        try:
            fake.completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="Hello!"))]
            )
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(api_key="sk-test", model="gpt-4o")
            result = provider.chat([{"role": "user", "content": "Hi"}])

            assert result == "Hello!"
            fake.completion.assert_called_once()
            call_kwargs = fake.completion.call_args
            assert call_kwargs.kwargs["model"] == "gpt-4o"
            assert call_kwargs.kwargs["drop_params"] is True
            assert call_kwargs.kwargs["api_key"] == "sk-test"
        finally:
            # Remove the stub even when the test fails, so it cannot leak
            # into later tests.
            sys.modules.pop("litellm", None)

    def test_chat_prepends_system_prompt(self):
        """A configured system prompt becomes the first message."""
        fake = _stub_litellm()
        try:
            fake.completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="OK"))]
            )
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(
                api_key="sk-test", model="gpt-4o", system_prompt="You are helpful."
            )
            provider.chat([{"role": "user", "content": "Hi"}])

            messages = fake.completion.call_args.kwargs["messages"]
            assert messages[0]["role"] == "system"
            assert messages[0]["content"] == "You are helpful."
        finally:
            sys.modules.pop("litellm", None)

    def test_chat_omits_api_key_when_not_set(self):
        """No api_key keyword is sent when none was configured."""
        fake = _stub_litellm()
        try:
            fake.completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="OK"))]
            )
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(model="gpt-4o")
            provider.chat([{"role": "user", "content": "Hi"}])

            assert "api_key" not in fake.completion.call_args.kwargs
        finally:
            sys.modules.pop("litellm", None)

    def test_chat_uses_override_model(self):
        """A per-call model overrides the provider's default model."""
        fake = _stub_litellm()
        try:
            fake.completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="OK"))]
            )
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(model="gpt-4o")
            provider.chat(
                [{"role": "user", "content": "Hi"}],
                model="anthropic/claude-sonnet-4-6",
            )

            assert (
                fake.completion.call_args.kwargs["model"]
                == "anthropic/claude-sonnet-4-6"
            )
        finally:
            sys.modules.pop("litellm", None)


class TestLiteLLMGenerate:
    """Checks that ``generate`` is routed through ``chat``."""

    def test_generate_delegates_to_chat(self):
        """generate wraps the prompt as a single user message."""
        fake = _stub_litellm()
        try:
            fake.completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="4"))]
            )
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(model="gpt-4o")
            result = provider.generate("What is 2+2?")

            assert result == "4"
            messages = fake.completion.call_args.kwargs["messages"]
            assert messages[0] == {"role": "user", "content": "What is 2+2?"}
        finally:
            # Remove the stub even when the test fails, so it cannot leak
            # into later tests.
            sys.modules.pop("litellm", None)


class TestLiteLLMNotSupported:
    """Checks the operations the LiteLLM provider is expected to reject."""

    def test_vision_raises_not_supported(self):
        """vision raises NotSupportedError and the message names the provider."""
        _stub_litellm()
        try:
            from gaia.llm import NotSupportedError
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(model="gpt-4o")
            with pytest.raises(NotSupportedError) as exc:
                provider.vision([b"image"], "describe this")
            assert "LiteLLM" in str(exc.value)
        finally:
            # Remove the stub even when the test fails, so it cannot leak
            # into later tests.
            sys.modules.pop("litellm", None)

    def test_load_model_raises_not_supported(self):
        """load_model raises NotSupportedError."""
        _stub_litellm()
        try:
            from gaia.llm import NotSupportedError
            from gaia.llm.providers.litellm import LiteLLMProvider

            provider = LiteLLMProvider(model="gpt-4o")
            with pytest.raises(NotSupportedError):
                provider.load_model("some-model")
        finally:
            sys.modules.pop("litellm", None)
Loading