Skip to content

Commit 4caa4a6

Browse files
authored
feat: add LiteLLM as AI gateway backend (#1302)
1 parent b3d9dff commit 4caa4a6

5 files changed

Lines changed: 226 additions & 0 deletions

File tree

lmms_eval/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"llava_sglang": "LlavaSglang",
6161
"llava_vid": "LlavaVid",
6262
"llava": "Llava",
63+
"litellm": "LiteLLMCompatible",
6364
"longva": "LongVA",
6465
"mantis": "Mantis",
6566
"minicpm_o": "MiniCPM_O",
@@ -126,6 +127,7 @@
126127
"vllm_generate": "VLLMGenerate",
127128
"sglang": "Sglang",
128129
"huggingface": "Huggingface",
130+
"litellm": "LiteLLMCompatible",
129131
"async_openai": "AsyncOpenAIChat",
130132
"async_hf_model": "AsyncHFModel",
131133
"longvila": "LongVila",
@@ -137,6 +139,7 @@
137139
"openai": ("openai_compatible", "openai_compatible_chat"),
138140
"async_openai": ("async_openai_compatible_chat", "async_openai_compatible"),
139141
"async_hf_model": ("async_hf",),
142+
"litellm": ("litellm_chat", "litellm_compatible"),
140143
}
141144

142145

lmms_eval/models/chat/litellm.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Chat-template variant of the LiteLLM backend.
2+
3+
Inherits the richer ``generate_until`` from ``chat/openai.py`` (ThreadPoolExecutor +
4+
adaptive concurrency + prefix-aware queue). The client is swapped for a LiteLLM-backed
5+
shim in ``__init__`` so every call routes through ``litellm.completion``.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
from typing import Any, Optional
12+
13+
from lmms_eval.api.registry import register_model
14+
from lmms_eval.models.chat.openai import OpenAICompatible as OpenAICompatibleChatBase
15+
from lmms_eval.models.simple.litellm import _PLACEHOLDER_API_KEY, _LiteLLMClientShim
16+
17+
18+
@register_model("litellm")
19+
class LiteLLMCompatible(OpenAICompatibleChatBase):
20+
"""LiteLLM-backed chat backend."""
21+
22+
is_simple = False
23+
24+
def __init__(
25+
self,
26+
model_version: str = "openai/gpt-4o-mini",
27+
model: Optional[str] = None,
28+
base_url: Optional[str] = None,
29+
api_key: Optional[str] = None,
30+
**kwargs: Any,
31+
) -> None:
32+
resolved_api_key = api_key or os.getenv("OPENAI_API_KEY")
33+
resolved_base_url = base_url or os.getenv("OPENAI_API_BASE")
34+
35+
super().__init__(
36+
model_version=model_version,
37+
model=model,
38+
base_url=resolved_base_url,
39+
api_key=resolved_api_key or _PLACEHOLDER_API_KEY,
40+
azure_openai=False,
41+
**kwargs,
42+
)
43+
self.client = _LiteLLMClientShim(api_key=resolved_api_key, base_url=resolved_base_url)

lmms_eval/models/simple/litellm.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""LiteLLM backend — routes requests through ``litellm.completion()`` to 100+ providers
2+
(OpenAI, Anthropic, Bedrock, Vertex, Gemini, Ollama, OpenRouter, Groq, DeepSeek, etc.)
3+
using provider-native API keys.
4+
5+
Reuses the ``OpenAICompatible`` simple-backend implementation end-to-end by swapping
6+
``self.client`` for a thin duck-typed shim that dispatches ``chat.completions.create``
7+
to ``litellm.completion``. All retry/concurrency/media-encoding logic is inherited
8+
unchanged.
9+
10+
See https://docs.litellm.ai/docs/providers for the full provider list and model-name
11+
prefix convention (e.g. ``anthropic/claude-3-5-sonnet-20241022``).
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import os
17+
from typing import Any, Optional
18+
19+
from lmms_eval.api.registry import register_model
20+
from lmms_eval.models.simple.openai import OpenAICompatible as OpenAICompatibleBase
21+
22+
# Placeholder key passed to ``openai.OpenAI(api_key=...)`` inside the super().__init__()
23+
# call. openai-python raises at construction if both the argument and OPENAI_API_KEY env
24+
# var are unset. We replace ``self.client`` with a LiteLLM-backed shim immediately after,
25+
# so this placeholder is never used on the wire.
26+
_PLACEHOLDER_API_KEY = "sk-litellm-placeholder"
27+
28+
29+
class _LiteLLMChatCompletions:
30+
"""Duck-typed ``openai.OpenAI().chat.completions`` surface backed by ``litellm.completion``."""
31+
32+
def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None:
33+
self._api_key = api_key
34+
self._base_url = base_url
35+
36+
def create(self, **kwargs: Any) -> Any:
37+
import litellm # lazy import — ``litellm`` is an optional extra
38+
39+
if self._api_key is not None:
40+
kwargs.setdefault("api_key", self._api_key)
41+
if self._base_url is not None:
42+
kwargs.setdefault("api_base", self._base_url)
43+
return litellm.completion(**kwargs)
44+
45+
46+
class _LiteLLMChat:
47+
def __init__(self, completions: _LiteLLMChatCompletions) -> None:
48+
self.completions = completions
49+
50+
51+
class _LiteLLMClientShim:
52+
"""Minimal OpenAI-client shape used by lmms-eval's ``generate_until`` call path.
53+
54+
lmms-eval only calls ``self.client.chat.completions.create(**payload)``; this shim
55+
dispatches that single method to ``litellm.completion``. API key / base URL are
56+
forwarded per-call; when they are None, LiteLLM resolves credentials from
57+
provider-specific env vars (``ANTHROPIC_API_KEY``, ``GEMINI_API_KEY``, ``AWS_*``, ...).
58+
"""
59+
60+
def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None:
61+
self.chat = _LiteLLMChat(_LiteLLMChatCompletions(api_key=api_key, base_url=base_url))
62+
63+
64+
@register_model("litellm")
65+
class LiteLLMCompatible(OpenAICompatibleBase):
66+
"""LiteLLM-backed backend that inherits OpenAI-compatible batching/retry logic.
67+
68+
Users select this backend via ``--model litellm --model_args model=<prefixed_name>``.
69+
Provider API keys come from the user's environment (``ANTHROPIC_API_KEY``, ...) or
70+
can be passed explicitly via ``--model_args api_key=...``.
71+
"""
72+
73+
def __init__(
74+
self,
75+
model_version: str = "openai/gpt-4o-mini",
76+
model: Optional[str] = None,
77+
base_url: Optional[str] = None,
78+
api_key: Optional[str] = None,
79+
**kwargs: Any,
80+
) -> None:
81+
resolved_api_key = api_key or os.getenv("OPENAI_API_KEY")
82+
resolved_base_url = base_url or os.getenv("OPENAI_API_BASE")
83+
84+
# The parent __init__ builds an openai.OpenAI(...) client and raises if no API
85+
# key is resolvable. Hand it a placeholder so construction succeeds; we swap
86+
# self.client immediately afterward with a LiteLLM-backed shim that never uses
87+
# the OpenAI client.
88+
super().__init__(
89+
model_version=model_version,
90+
model=model,
91+
base_url=resolved_base_url,
92+
api_key=resolved_api_key or _PLACEHOLDER_API_KEY,
93+
azure_openai=False,
94+
**kwargs,
95+
)
96+
self.client = _LiteLLMClientShim(api_key=resolved_api_key, base_url=resolved_base_url)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ metrics = [
107107
gemini = [
108108
"google-generativeai",
109109
]
110+
litellm = [
111+
"litellm",
112+
]
110113
reka = [
111114
"httpx>=0.23.3",
112115
"reka-api",

test/models/test_litellm.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Unit tests for the LiteLLM backend (simple + chat variants)."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
import types
7+
import unittest
8+
from types import SimpleNamespace
9+
from unittest import mock
10+
11+
12+
def _install_litellm_stub() -> mock.MagicMock:
13+
"""Register a fake ``litellm`` module so ``import litellm`` resolves in tests."""
14+
fake = types.ModuleType("litellm")
15+
fake.completion = mock.MagicMock(name="litellm.completion")
16+
sys.modules["litellm"] = fake
17+
return fake.completion
18+
19+
20+
def _fake_chat_completion(content: str = "hi", prompt_tokens: int = 3, completion_tokens: int = 5) -> SimpleNamespace:
21+
usage = SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, completion_tokens_details=None)
22+
message = SimpleNamespace(content=content)
23+
choice = SimpleNamespace(message=message, finish_reason="stop", index=0)
24+
return SimpleNamespace(choices=[choice], usage=usage, id="cmpl-test", model="test")
25+
26+
27+
class TestLiteLLMShim(unittest.TestCase):
28+
def test_shim_forwards_to_litellm_completion(self):
29+
completion = _install_litellm_stub()
30+
completion.return_value = _fake_chat_completion("pong")
31+
32+
from lmms_eval.models.simple.litellm import _LiteLLMClientShim
33+
34+
shim = _LiteLLMClientShim(api_key="sk-user", base_url="https://proxy.example/v1")
35+
resp = shim.chat.completions.create(
36+
model="anthropic/claude-3-5-sonnet-20241022",
37+
messages=[{"role": "user", "content": "ping"}],
38+
)
39+
40+
self.assertEqual(resp.choices[0].message.content, "pong")
41+
completion.assert_called_once()
42+
kwargs = completion.call_args.kwargs
43+
self.assertEqual(kwargs["model"], "anthropic/claude-3-5-sonnet-20241022")
44+
self.assertEqual(kwargs["api_key"], "sk-user")
45+
self.assertEqual(kwargs["api_base"], "https://proxy.example/v1")
46+
47+
def test_shim_without_explicit_credentials(self):
48+
"""When no api_key/base_url is set, the shim must not inject them; LiteLLM
49+
then falls back to provider-specific env vars (ANTHROPIC_API_KEY, ...)."""
50+
completion = _install_litellm_stub()
51+
completion.return_value = _fake_chat_completion()
52+
53+
from lmms_eval.models.simple.litellm import _LiteLLMClientShim
54+
55+
shim = _LiteLLMClientShim(api_key=None, base_url=None)
56+
shim.chat.completions.create(model="openai/gpt-4o-mini", messages=[])
57+
58+
kwargs = completion.call_args.kwargs
59+
self.assertNotIn("api_key", kwargs)
60+
self.assertNotIn("api_base", kwargs)
61+
62+
def test_litellm_registered_as_simple_and_chat(self):
63+
"""Confirm the model manifest is discoverable under both backends."""
64+
from lmms_eval.models import MODEL_REGISTRY_V2
65+
66+
manifest = MODEL_REGISTRY_V2.get_manifest("litellm")
67+
self.assertEqual(manifest.model_id, "litellm")
68+
self.assertEqual(
69+
manifest.simple_class_path,
70+
"lmms_eval.models.simple.litellm.LiteLLMCompatible",
71+
)
72+
self.assertEqual(
73+
manifest.chat_class_path,
74+
"lmms_eval.models.chat.litellm.LiteLLMCompatible",
75+
)
76+
self.assertIn("litellm_chat", manifest.aliases)
77+
self.assertIn("litellm_compatible", manifest.aliases)
78+
79+
80+
if __name__ == "__main__":
81+
unittest.main()

0 commit comments

Comments
 (0)