Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"llava_sglang": "LlavaSglang",
"llava_vid": "LlavaVid",
"llava": "Llava",
"litellm": "LiteLLMCompatible",
"longva": "LongVA",
"mantis": "Mantis",
"minicpm_o": "MiniCPM_O",
Expand Down Expand Up @@ -126,6 +127,7 @@
"vllm_generate": "VLLMGenerate",
"sglang": "Sglang",
"huggingface": "Huggingface",
"litellm": "LiteLLMCompatible",
"async_openai": "AsyncOpenAIChat",
"async_hf_model": "AsyncHFModel",
"longvila": "LongVila",
Expand All @@ -137,6 +139,7 @@
"openai": ("openai_compatible", "openai_compatible_chat"),
"async_openai": ("async_openai_compatible_chat", "async_openai_compatible"),
"async_hf_model": ("async_hf",),
"litellm": ("litellm_chat", "litellm_compatible"),
}


Expand Down
43 changes: 43 additions & 0 deletions lmms_eval/models/chat/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Chat-template variant of the LiteLLM backend.

Inherits the richer ``generate_until`` from ``chat/openai.py`` (ThreadPoolExecutor +
adaptive concurrency + prefix-aware queue). The client is swapped for a LiteLLM-backed
shim in ``__init__`` so every call routes through ``litellm.completion``.
"""

from __future__ import annotations

import os
from typing import Any, Optional

from lmms_eval.api.registry import register_model
from lmms_eval.models.chat.openai import OpenAICompatible as OpenAICompatibleChatBase
from lmms_eval.models.simple.litellm import _PLACEHOLDER_API_KEY, _LiteLLMClientShim


@register_model("litellm")
class LiteLLMCompatible(OpenAICompatibleChatBase):
    """LiteLLM-backed chat backend.

    Construction is delegated to the OpenAI-compatible chat base class; afterwards
    ``self.client`` is replaced with a ``_LiteLLMClientShim`` so that calls made
    through ``self.client`` are served by the shim instead of the OpenAI client.
    """

    is_simple = False

    def __init__(
        self,
        model_version: str = "openai/gpt-4o-mini",
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the parent backend, then swap in the LiteLLM shim.

        Args:
            model_version: Model name forwarded to the parent constructor.
            model: Optional model name forwarded to the parent constructor.
            base_url: Explicit API base URL. Empty/None means "not provided".
            api_key: Explicit API key. Empty/None means "not provided".
            **kwargs: Remaining options passed through to the parent constructor.
        """
        # The parent constructor still receives the OpenAI-style env fallbacks (and a
        # placeholder key as a last resort), exactly as before.
        resolved_api_key = api_key or os.getenv("OPENAI_API_KEY")
        resolved_base_url = base_url or os.getenv("OPENAI_API_BASE")

        super().__init__(
            model_version=model_version,
            model=model,
            base_url=resolved_base_url,
            api_key=resolved_api_key or _PLACEHOLDER_API_KEY,
            azure_openai=False,
            **kwargs,
        )

        # Forward ONLY explicitly supplied credentials to the shim. Passing the
        # OPENAI_* environment fallbacks here would attach an OpenAI key / base URL to
        # every request, including ones for other providers (e.g. ``anthropic/...``),
        # pre-empting the provider-specific env vars LiteLLM is expected to resolve
        # when no key is given.
        self.client = _LiteLLMClientShim(api_key=api_key or None, base_url=base_url or None)
96 changes: 96 additions & 0 deletions lmms_eval/models/simple/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""LiteLLM backend — routes requests through ``litellm.completion()`` to 100+ providers
(OpenAI, Anthropic, Bedrock, Vertex, Gemini, Ollama, OpenRouter, Groq, DeepSeek, etc.)
using provider-native API keys.

Reuses the ``OpenAICompatible`` simple-backend implementation end-to-end by swapping
``self.client`` for a thin duck-typed shim that dispatches ``chat.completions.create``
to ``litellm.completion``. All retry/concurrency/media-encoding logic is inherited
unchanged.

See https://docs.litellm.ai/docs/providers for the full provider list and model-name
prefix convention (e.g. ``anthropic/claude-3-5-sonnet-20241022``).
"""

from __future__ import annotations

import os
from typing import Any, Optional

from lmms_eval.api.registry import register_model
from lmms_eval.models.simple.openai import OpenAICompatible as OpenAICompatibleBase

# Placeholder key handed to the parent ``__init__`` (which builds
# ``openai.OpenAI(api_key=...)``) when no real key is available. openai-python raises
# at construction if both the argument and the OPENAI_API_KEY env var are unset.
# ``self.client`` is replaced with a LiteLLM-backed shim immediately afterwards, so
# this placeholder is never used on the wire.
_PLACEHOLDER_API_KEY = "sk-litellm-placeholder"


class _LiteLLMChatCompletions:
"""Duck-typed ``openai.OpenAI().chat.completions`` surface backed by ``litellm.completion``."""

def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None:
self._api_key = api_key
self._base_url = base_url

def create(self, **kwargs: Any) -> Any:
import litellm # lazy import — ``litellm`` is an optional extra

if self._api_key is not None:
kwargs.setdefault("api_key", self._api_key)
if self._base_url is not None:
kwargs.setdefault("api_base", self._base_url)
return litellm.completion(**kwargs)


class _LiteLLMChat:
def __init__(self, completions: _LiteLLMChatCompletions) -> None:
self.completions = completions


class _LiteLLMClientShim:
    """Smallest OpenAI-client-shaped object needed by lmms-eval's ``generate_until`` path.

    Per the original design note, lmms-eval only calls
    ``self.client.chat.completions.create(**payload)``, so only that attribute chain
    is provided; the call ends up in ``litellm.completion``. The API key / base URL
    given here are forwarded on each call. When they are ``None`` nothing is
    injected and LiteLLM resolves credentials from provider-specific env vars
    (``ANTHROPIC_API_KEY``, ``GEMINI_API_KEY``, ``AWS_*``, ...).
    """

    def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None:
        completions = _LiteLLMChatCompletions(api_key=api_key, base_url=base_url)
        self.chat = _LiteLLMChat(completions)


@register_model("litellm")
class LiteLLMCompatible(OpenAICompatibleBase):
    """LiteLLM-backed backend that inherits OpenAI-compatible batching/retry logic.

    Users select this backend via ``--model litellm --model_args model=<prefixed_name>``.
    Provider API keys come from the user's environment (``ANTHROPIC_API_KEY``, ...) or
    can be passed explicitly via ``--model_args api_key=...``. Only explicitly passed
    ``api_key`` / ``base_url`` values are attached to outgoing requests; environment
    lookups for the target provider are left to LiteLLM.
    """

    def __init__(
        self,
        model_version: str = "openai/gpt-4o-mini",
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the parent backend, then swap in the LiteLLM shim.

        Args:
            model_version: Model name forwarded to the parent constructor.
            model: Optional model name forwarded to the parent constructor.
            base_url: Explicit API base URL. Empty/None means "not provided".
            api_key: Explicit API key. Empty/None means "not provided".
            **kwargs: Remaining options passed through to the parent constructor.
        """
        resolved_api_key = api_key or os.getenv("OPENAI_API_KEY")
        resolved_base_url = base_url or os.getenv("OPENAI_API_BASE")

        # The parent __init__ builds an openai.OpenAI(...) client and raises if no API
        # key is resolvable. Hand it a placeholder so construction succeeds; we swap
        # self.client immediately afterward with a LiteLLM-backed shim that never uses
        # the OpenAI client.
        super().__init__(
            model_version=model_version,
            model=model,
            base_url=resolved_base_url,
            api_key=resolved_api_key or _PLACEHOLDER_API_KEY,
            azure_openai=False,
            **kwargs,
        )

        # Forward ONLY explicitly supplied credentials to the shim. Passing the
        # OPENAI_* environment fallbacks here would attach an OpenAI key / base URL to
        # every request, including ones for other providers (e.g. ``anthropic/...``),
        # pre-empting the provider-specific env vars LiteLLM is expected to resolve
        # when no key is given.
        self.client = _LiteLLMClientShim(api_key=api_key or None, base_url=base_url or None)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ metrics = [
gemini = [
"google-generativeai",
]
litellm = [
"litellm",
]
reka = [
"httpx>=0.23.3",
"reka-api",
Expand Down
81 changes: 81 additions & 0 deletions test/models/test_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Unit tests for the LiteLLM backend (simple + chat variants)."""

from __future__ import annotations

import sys
import types
import unittest
from types import SimpleNamespace
from unittest import mock


def _install_litellm_stub() -> mock.MagicMock:
"""Register a fake ``litellm`` module so ``import litellm`` resolves in tests."""
fake = types.ModuleType("litellm")
fake.completion = mock.MagicMock(name="litellm.completion")
sys.modules["litellm"] = fake
return fake.completion


def _fake_chat_completion(content: str = "hi", prompt_tokens: int = 3, completion_tokens: int = 5) -> SimpleNamespace:
usage = SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, completion_tokens_details=None)
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop", index=0)
return SimpleNamespace(choices=[choice], usage=usage, id="cmpl-test", model="test")


class TestLiteLLMShim(unittest.TestCase):
    """Tests for the LiteLLM client shim and the ``litellm`` registry manifest."""

    def setUp(self):
        # The tests below replace ``sys.modules["litellm"]`` with a stub. Snapshot the
        # current entry and restore it afterwards so the fake module does not leak into
        # other tests running in the same process.
        missing = object()
        previous = sys.modules.get("litellm", missing)

        def _restore_litellm_module():
            if previous is missing:
                sys.modules.pop("litellm", None)
            else:
                sys.modules["litellm"] = previous

        self.addCleanup(_restore_litellm_module)

    def test_shim_forwards_to_litellm_completion(self):
        """Explicit credentials are forwarded as ``api_key`` / ``api_base``."""
        completion = _install_litellm_stub()
        completion.return_value = _fake_chat_completion("pong")

        from lmms_eval.models.simple.litellm import _LiteLLMClientShim

        shim = _LiteLLMClientShim(api_key="sk-user", base_url="https://proxy.example/v1")
        resp = shim.chat.completions.create(
            model="anthropic/claude-3-5-sonnet-20241022",
            messages=[{"role": "user", "content": "ping"}],
        )

        self.assertEqual(resp.choices[0].message.content, "pong")
        completion.assert_called_once()
        kwargs = completion.call_args.kwargs
        self.assertEqual(kwargs["model"], "anthropic/claude-3-5-sonnet-20241022")
        self.assertEqual(kwargs["api_key"], "sk-user")
        self.assertEqual(kwargs["api_base"], "https://proxy.example/v1")

    def test_shim_without_explicit_credentials(self):
        """When no api_key/base_url is set, the shim must not inject them; LiteLLM
        then falls back to provider-specific env vars (ANTHROPIC_API_KEY, ...)."""
        completion = _install_litellm_stub()
        completion.return_value = _fake_chat_completion()

        from lmms_eval.models.simple.litellm import _LiteLLMClientShim

        shim = _LiteLLMClientShim(api_key=None, base_url=None)
        shim.chat.completions.create(model="openai/gpt-4o-mini", messages=[])

        kwargs = completion.call_args.kwargs
        self.assertNotIn("api_key", kwargs)
        self.assertNotIn("api_base", kwargs)

    def test_litellm_registered_as_simple_and_chat(self):
        """Confirm the model manifest is discoverable under both backends."""
        from lmms_eval.models import MODEL_REGISTRY_V2

        manifest = MODEL_REGISTRY_V2.get_manifest("litellm")
        self.assertEqual(manifest.model_id, "litellm")
        self.assertEqual(
            manifest.simple_class_path,
            "lmms_eval.models.simple.litellm.LiteLLMCompatible",
        )
        self.assertEqual(
            manifest.chat_class_path,
            "lmms_eval.models.chat.litellm.LiteLLMCompatible",
        )
        self.assertIn("litellm_chat", manifest.aliases)
        self.assertIn("litellm_compatible", manifest.aliases)


# Allow running this test module directly as a script, in addition to via a test runner.
if __name__ == "__main__":
    unittest.main()
Loading