Skip to content

Commit 881ebcf

Browse files
RheagalFireitomekTomasz Iniewicz
authored
feat(llm): add LiteLLM as AI gateway provider (#1593)
## Summary Adds LiteLLM as a fourth LLM provider alongside Lemonade, OpenAI, and Claude, giving GAIA users access to 100+ cloud providers (Bedrock, Vertex AI, Groq, DeepSeek, Azure OpenAI, etc.) through a single `create_client("litellm")` call. ## Why GAIA's `LLMClient` abstraction covers local inference (Lemonade) and two cloud providers (OpenAI, Claude). Adding providers individually doesn't scale; LiteLLM is one dependency that covers 100+ providers with `drop_params=True` for cross-provider kwarg compatibility. ## Changes - `src/gaia/llm/providers/litellm.py` -- new `LiteLLMProvider(LLMClient)` with `generate()`, `chat()`, `embed()`, streaming, and `drop_params=True` default - `src/gaia/llm/factory.py` -- registered `"litellm"` in `_PROVIDERS` - `src/gaia/llm/providers/__init__.py` -- export `LiteLLMProvider` - `setup.py` -- added `[litellm]` optional extra (`litellm>=1.35.0,<2.0`) - `tests/unit/test_litellm_provider.py` -- 10 unit tests ## Test plan - [x] `python -m pytest tests/unit/test_litellm_provider.py -v` -- 10/10 pass - [x] `python -m pytest tests/unit/test_llm_client_factory.py tests/unit/test_openai_provider.py -v` -- existing LLM tests still pass (76 passed) - [x] `python util/lint.py --all --fix` -- clean - [x] Live E2E: `create_client("litellm")` -> LiteLLM proxy -> Azure Foundry (Claude Sonnet 4.6): ``` Provider: LiteLLM Generate response: '4' Chat response: 'OK' Stream chunks: 2 chunks, text: 'Hello! ...' === E2E PASSED === ``` ## Checklist - [x] I have linked a GitHub issue above (`Closes #N` / `Fixes #N` / `Refs #N`). - [x] I have described **why** this change is being made, not just what changed. - [x] I have run linting and tests locally (`python util/lint.py --all`, `pytest tests/unit/`). - [ ] I have updated documentation if user-visible behavior changed (see [CONTRIBUTING.md](../CONTRIBUTING.md)). --------- Co-authored-by: Tomasz Iniewicz <itomek@users.noreply.github.com> Co-authored-by: Tomasz Iniewicz <heaters-nays0p@icloud.com>
1 parent 9ab7da1 commit 881ebcf

5 files changed

Lines changed: 255 additions & 1 deletion

File tree

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@
170170
"telegram": [
171171
"python-telegram-bot>=20.3",
172172
],
173+
"litellm": [
174+
"litellm>=1.35.0,<2.0",
175+
],
173176
"dev": [
174177
"pytest",
175178
"pytest-cov",

src/gaia/llm/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"lemonade": "gaia.llm.providers.lemonade.LemonadeProvider",
1111
"openai": "gaia.llm.providers.openai_provider.OpenAIProvider",
1212
"claude": "gaia.llm.providers.claude.ClaudeProvider",
13+
"litellm": "gaia.llm.providers.litellm.LiteLLMProvider",
1314
}
1415

1516

src/gaia/llm/providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .claude import ClaudeProvider
66
from .lemonade import LemonadeProvider
7+
from .litellm import LiteLLMProvider
78
from .openai_provider import OpenAIProvider
89

9-
__all__ = ["ClaudeProvider", "LemonadeProvider", "OpenAIProvider"]
10+
__all__ = ["ClaudeProvider", "LemonadeProvider", "LiteLLMProvider", "OpenAIProvider"]

src/gaia/llm/providers/litellm.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
"""LiteLLM provider - unified gateway for 100+ LLM providers."""
4+
5+
from typing import Iterator, Optional, Union
6+
7+
from ..base_client import LLMClient
8+
9+
10+
class LiteLLMProvider(LLMClient):
11+
"""LiteLLM AI gateway provider."""
12+
13+
def __init__(
14+
self,
15+
api_key: Optional[str] = None,
16+
model: str = "gpt-4o",
17+
system_prompt: Optional[str] = None,
18+
**kwargs,
19+
):
20+
import litellm
21+
22+
self._model = model
23+
self._system_prompt = system_prompt
24+
self._api_key = api_key
25+
self._extra_kwargs = kwargs
26+
27+
litellm.drop_params = True
28+
29+
@property
30+
def provider_name(self) -> str:
31+
return "LiteLLM"
32+
33+
def generate(
34+
self,
35+
prompt: str,
36+
model: str | None = None,
37+
stream: bool = False,
38+
**kwargs,
39+
) -> Union[str, Iterator[str]]:
40+
return self.chat(
41+
[{"role": "user", "content": prompt}],
42+
model=model,
43+
stream=stream,
44+
**kwargs,
45+
)
46+
47+
def chat(
48+
self,
49+
messages: list[dict],
50+
model: str | None = None,
51+
stream: bool = False,
52+
**kwargs,
53+
) -> Union[str, Iterator[str]]:
54+
import litellm
55+
56+
if self._system_prompt:
57+
messages = [{"role": "system", "content": self._system_prompt}] + list(
58+
messages
59+
)
60+
61+
call_kwargs = {**self._extra_kwargs, **kwargs}
62+
if self._api_key:
63+
call_kwargs["api_key"] = self._api_key
64+
65+
response = litellm.completion(
66+
model=model or self._model,
67+
messages=messages,
68+
stream=stream,
69+
drop_params=True,
70+
**call_kwargs,
71+
)
72+
if stream:
73+
return self._handle_stream(response)
74+
return response.choices[0].message.content
75+
76+
def embed(self, texts: list[str], **kwargs) -> list[list[float]]:
77+
import litellm
78+
79+
call_kwargs = {**self._extra_kwargs, **kwargs}
80+
if self._api_key:
81+
call_kwargs["api_key"] = self._api_key
82+
83+
response = litellm.embedding(
84+
model=kwargs.pop("model", self._model),
85+
input=texts,
86+
drop_params=True,
87+
**call_kwargs,
88+
)
89+
return [item["embedding"] for item in response.data]
90+
91+
def _handle_stream(self, response) -> Iterator[str]:
92+
for chunk in response:
93+
if chunk.choices and chunk.choices[0].delta.content:
94+
yield chunk.choices[0].delta.content
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
"""Tests for LiteLLM provider."""
4+
5+
import sys
6+
import types
7+
from unittest.mock import MagicMock
8+
9+
import pytest
10+
11+
12+
def _stub_litellm():
13+
"""Install a stub litellm module so tests run without the real package."""
14+
fake = types.ModuleType("litellm")
15+
fake.completion = MagicMock(name="litellm.completion")
16+
fake.embedding = MagicMock(name="litellm.embedding")
17+
fake.drop_params = False
18+
sys.modules["litellm"] = fake
19+
return fake
20+
21+
22+
class TestLiteLLMProviderName:
23+
def test_provider_name(self):
24+
fake = _stub_litellm()
25+
from gaia.llm.providers.litellm import LiteLLMProvider
26+
27+
provider = LiteLLMProvider(api_key="test-key", model="gpt-4o")
28+
assert provider.provider_name == "LiteLLM"
29+
del sys.modules["litellm"]
30+
31+
32+
class TestLiteLLMFactory:
33+
def test_create_client_litellm(self):
34+
_stub_litellm()
35+
from gaia.llm import create_client
36+
37+
client = create_client("litellm", api_key="test-key")
38+
assert client.provider_name == "LiteLLM"
39+
del sys.modules["litellm"]
40+
41+
def test_create_client_litellm_case_insensitive(self):
42+
_stub_litellm()
43+
from gaia.llm import create_client
44+
45+
client = create_client("LITELLM", api_key="test-key")
46+
assert client.provider_name == "LiteLLM"
47+
del sys.modules["litellm"]
48+
49+
50+
class TestLiteLLMChat:
51+
def test_chat_calls_litellm_completion(self):
52+
fake = _stub_litellm()
53+
fake.completion.return_value = MagicMock(
54+
choices=[MagicMock(message=MagicMock(content="Hello!"))]
55+
)
56+
from gaia.llm.providers.litellm import LiteLLMProvider
57+
58+
provider = LiteLLMProvider(api_key="sk-test", model="gpt-4o")
59+
result = provider.chat([{"role": "user", "content": "Hi"}])
60+
61+
assert result == "Hello!"
62+
fake.completion.assert_called_once()
63+
call_kwargs = fake.completion.call_args
64+
assert call_kwargs.kwargs["model"] == "gpt-4o"
65+
assert call_kwargs.kwargs["drop_params"] is True
66+
assert call_kwargs.kwargs["api_key"] == "sk-test"
67+
del sys.modules["litellm"]
68+
69+
def test_chat_prepends_system_prompt(self):
70+
fake = _stub_litellm()
71+
fake.completion.return_value = MagicMock(
72+
choices=[MagicMock(message=MagicMock(content="OK"))]
73+
)
74+
from gaia.llm.providers.litellm import LiteLLMProvider
75+
76+
provider = LiteLLMProvider(
77+
api_key="sk-test", model="gpt-4o", system_prompt="You are helpful."
78+
)
79+
provider.chat([{"role": "user", "content": "Hi"}])
80+
81+
messages = fake.completion.call_args.kwargs["messages"]
82+
assert messages[0]["role"] == "system"
83+
assert messages[0]["content"] == "You are helpful."
84+
del sys.modules["litellm"]
85+
86+
def test_chat_omits_api_key_when_not_set(self):
87+
fake = _stub_litellm()
88+
fake.completion.return_value = MagicMock(
89+
choices=[MagicMock(message=MagicMock(content="OK"))]
90+
)
91+
from gaia.llm.providers.litellm import LiteLLMProvider
92+
93+
provider = LiteLLMProvider(model="gpt-4o")
94+
provider.chat([{"role": "user", "content": "Hi"}])
95+
96+
assert "api_key" not in fake.completion.call_args.kwargs
97+
del sys.modules["litellm"]
98+
99+
def test_chat_uses_override_model(self):
100+
fake = _stub_litellm()
101+
fake.completion.return_value = MagicMock(
102+
choices=[MagicMock(message=MagicMock(content="OK"))]
103+
)
104+
from gaia.llm.providers.litellm import LiteLLMProvider
105+
106+
provider = LiteLLMProvider(model="gpt-4o")
107+
provider.chat(
108+
[{"role": "user", "content": "Hi"}],
109+
model="anthropic/claude-sonnet-4-6",
110+
)
111+
112+
assert (
113+
fake.completion.call_args.kwargs["model"] == "anthropic/claude-sonnet-4-6"
114+
)
115+
del sys.modules["litellm"]
116+
117+
118+
class TestLiteLLMGenerate:
119+
def test_generate_delegates_to_chat(self):
120+
fake = _stub_litellm()
121+
fake.completion.return_value = MagicMock(
122+
choices=[MagicMock(message=MagicMock(content="4"))]
123+
)
124+
from gaia.llm.providers.litellm import LiteLLMProvider
125+
126+
provider = LiteLLMProvider(model="gpt-4o")
127+
result = provider.generate("What is 2+2?")
128+
129+
assert result == "4"
130+
messages = fake.completion.call_args.kwargs["messages"]
131+
assert messages[0] == {"role": "user", "content": "What is 2+2?"}
132+
del sys.modules["litellm"]
133+
134+
135+
class TestLiteLLMNotSupported:
136+
def test_vision_raises_not_supported(self):
137+
_stub_litellm()
138+
from gaia.llm import NotSupportedError
139+
from gaia.llm.providers.litellm import LiteLLMProvider
140+
141+
provider = LiteLLMProvider(model="gpt-4o")
142+
with pytest.raises(NotSupportedError) as exc:
143+
provider.vision([b"image"], "describe this")
144+
assert "LiteLLM" in str(exc.value)
145+
del sys.modules["litellm"]
146+
147+
def test_load_model_raises_not_supported(self):
148+
_stub_litellm()
149+
from gaia.llm import NotSupportedError
150+
from gaia.llm.providers.litellm import LiteLLMProvider
151+
152+
provider = LiteLLMProvider(model="gpt-4o")
153+
with pytest.raises(NotSupportedError):
154+
provider.load_model("some-model")
155+
del sys.modules["litellm"]

0 commit comments

Comments
 (0)