diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index 5ee75e923..91d68005a 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -60,6 +60,7 @@ "llava_sglang": "LlavaSglang", "llava_vid": "LlavaVid", "llava": "Llava", + "litellm": "LiteLLMCompatible", "longva": "LongVA", "mantis": "Mantis", "minicpm_o": "MiniCPM_O", @@ -126,6 +127,7 @@ "vllm_generate": "VLLMGenerate", "sglang": "Sglang", "huggingface": "Huggingface", + "litellm": "LiteLLMCompatible", "async_openai": "AsyncOpenAIChat", "async_hf_model": "AsyncHFModel", "longvila": "LongVila", @@ -137,6 +139,7 @@ "openai": ("openai_compatible", "openai_compatible_chat"), "async_openai": ("async_openai_compatible_chat", "async_openai_compatible"), "async_hf_model": ("async_hf",), + "litellm": ("litellm_chat", "litellm_compatible"), } diff --git a/lmms_eval/models/chat/litellm.py b/lmms_eval/models/chat/litellm.py new file mode 100644 index 000000000..1d34ff1d2 --- /dev/null +++ b/lmms_eval/models/chat/litellm.py @@ -0,0 +1,43 @@ +"""Chat-template variant of the LiteLLM backend. + +Inherits the richer ``generate_until`` from ``chat/openai.py`` (ThreadPoolExecutor + +adaptive concurrency + prefix-aware queue). The client is swapped for a LiteLLM-backed +shim in ``__init__`` so every call routes through ``litellm.completion``. +""" + +from __future__ import annotations + +import os +from typing import Any, Optional + +from lmms_eval.api.registry import register_model +from lmms_eval.models.chat.openai import OpenAICompatible as OpenAICompatibleChatBase +from lmms_eval.models.simple.litellm import _PLACEHOLDER_API_KEY, _LiteLLMClientShim + + +@register_model("litellm") +class LiteLLMCompatible(OpenAICompatibleChatBase): + """LiteLLM-backed chat backend.""" + + is_simple = False + + def __init__( + self, + model_version: str = "openai/gpt-4o-mini", + model: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + resolved_api_key = api_key or os.getenv("OPENAI_API_KEY") + resolved_base_url = base_url or os.getenv("OPENAI_API_BASE") + + super().__init__( + model_version=model_version, + model=model, + base_url=resolved_base_url, + api_key=resolved_api_key or _PLACEHOLDER_API_KEY, + azure_openai=False, + **kwargs, + ) + self.client = _LiteLLMClientShim(api_key=resolved_api_key, base_url=resolved_base_url) diff --git a/lmms_eval/models/simple/litellm.py b/lmms_eval/models/simple/litellm.py new file mode 100644 index 000000000..3554995fb --- /dev/null +++ b/lmms_eval/models/simple/litellm.py @@ -0,0 +1,96 @@ +"""LiteLLM backend — routes requests through ``litellm.completion()`` to 100+ providers +(OpenAI, Anthropic, Bedrock, Vertex, Gemini, Ollama, OpenRouter, Groq, DeepSeek, etc.) +using provider-native API keys. + +Reuses the ``OpenAICompatible`` simple-backend implementation end-to-end by swapping +``self.client`` for a thin duck-typed shim that dispatches ``chat.completions.create`` +to ``litellm.completion``. All retry/concurrency/media-encoding logic is inherited +unchanged. + +See https://docs.litellm.ai/docs/providers for the full provider list and model-name +prefix convention (e.g. ``anthropic/claude-3-5-sonnet-20241022``). +""" + +from __future__ import annotations + +import os +from typing import Any, Optional + +from lmms_eval.api.registry import register_model +from lmms_eval.models.simple.openai import OpenAICompatible as OpenAICompatibleBase + +# Placeholder key passed to ``openai.OpenAI(api_key=...)`` inside the super().__init__() +# call. openai-python raises at construction if both the argument and OPENAI_API_KEY env +# var are unset. We replace ``self.client`` with a LiteLLM-backed shim immediately after, +# so this placeholder is never used on the wire. +_PLACEHOLDER_API_KEY = "sk-litellm-placeholder" + + +class _LiteLLMChatCompletions: + """Duck-typed ``openai.OpenAI().chat.completions`` surface backed by ``litellm.completion``.""" + + def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None: + self._api_key = api_key + self._base_url = base_url + + def create(self, **kwargs: Any) -> Any: + import litellm # lazy import — ``litellm`` is an optional extra + + if self._api_key is not None: + kwargs.setdefault("api_key", self._api_key) + if self._base_url is not None: + kwargs.setdefault("api_base", self._base_url) + return litellm.completion(**kwargs) + + +class _LiteLLMChat: + def __init__(self, completions: _LiteLLMChatCompletions) -> None: + self.completions = completions + + +class _LiteLLMClientShim: + """Minimal OpenAI-client shape used by lmms-eval's ``generate_until`` call path. + + lmms-eval only calls ``self.client.chat.completions.create(**payload)``; this shim + dispatches that single method to ``litellm.completion``. API key / base URL are + forwarded per-call; when they are None, LiteLLM resolves credentials from + provider-specific env vars (``ANTHROPIC_API_KEY``, ``GEMINI_API_KEY``, ``AWS_*``, ...). + """ + + def __init__(self, api_key: Optional[str], base_url: Optional[str]) -> None: + self.chat = _LiteLLMChat(_LiteLLMChatCompletions(api_key=api_key, base_url=base_url)) + + +@register_model("litellm") +class LiteLLMCompatible(OpenAICompatibleBase): + """LiteLLM-backed backend that inherits OpenAI-compatible batching/retry logic. + + Users select this backend via ``--model litellm --model_args model=``. + Provider API keys come from the user's environment (``ANTHROPIC_API_KEY``, ...) or + can be passed explicitly via ``--model_args api_key=...``. + """ + + def __init__( + self, + model_version: str = "openai/gpt-4o-mini", + model: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + resolved_api_key = api_key or os.getenv("OPENAI_API_KEY") + resolved_base_url = base_url or os.getenv("OPENAI_API_BASE") + + # The parent __init__ builds an openai.OpenAI(...) client and raises if no API + # key is resolvable. Hand it a placeholder so construction succeeds; we swap + # self.client immediately afterward with a LiteLLM-backed shim that never uses + # the OpenAI client. + super().__init__( + model_version=model_version, + model=model, + base_url=resolved_base_url, + api_key=resolved_api_key or _PLACEHOLDER_API_KEY, + azure_openai=False, + **kwargs, + ) + self.client = _LiteLLMClientShim(api_key=resolved_api_key, base_url=resolved_base_url) diff --git a/pyproject.toml b/pyproject.toml index c2edce25c..e6d7aef3c 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,9 @@ metrics = [ gemini = [ "google-generativeai", ] +litellm = [ + "litellm", +] reka = [ "httpx>=0.23.3", "reka-api", diff --git a/test/models/test_litellm.py b/test/models/test_litellm.py new file mode 100644 index 000000000..87cff3e07 --- /dev/null +++ b/test/models/test_litellm.py @@ -0,0 +1,81 @@ +"""Unit tests for the LiteLLM backend (simple + chat variants).""" + +from __future__ import annotations + +import sys +import types +import unittest +from types import SimpleNamespace +from unittest import mock + + +def _install_litellm_stub() -> mock.MagicMock: + """Register a fake ``litellm`` module so ``import litellm`` resolves in tests.""" + fake = types.ModuleType("litellm") + fake.completion = mock.MagicMock(name="litellm.completion") + sys.modules["litellm"] = fake + return fake.completion + + +def _fake_chat_completion(content: str = "hi", prompt_tokens: int = 3, completion_tokens: int = 5) -> SimpleNamespace: + usage = SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, completion_tokens_details=None) + message = SimpleNamespace(content=content) + choice = SimpleNamespace(message=message, finish_reason="stop", index=0) + return SimpleNamespace(choices=[choice], usage=usage, id="cmpl-test", model="test") + + +class TestLiteLLMShim(unittest.TestCase): + def test_shim_forwards_to_litellm_completion(self): + completion = _install_litellm_stub() + completion.return_value = _fake_chat_completion("pong") + + from lmms_eval.models.simple.litellm import _LiteLLMClientShim + + shim = _LiteLLMClientShim(api_key="sk-user", base_url="https://proxy.example/v1") + resp = shim.chat.completions.create( + model="anthropic/claude-3-5-sonnet-20241022", + messages=[{"role": "user", "content": "ping"}], + ) + + self.assertEqual(resp.choices[0].message.content, "pong") + completion.assert_called_once() + kwargs = completion.call_args.kwargs + self.assertEqual(kwargs["model"], "anthropic/claude-3-5-sonnet-20241022") + self.assertEqual(kwargs["api_key"], "sk-user") + self.assertEqual(kwargs["api_base"], "https://proxy.example/v1") + + def test_shim_without_explicit_credentials(self): + """When no api_key/base_url is set, the shim must not inject them; LiteLLM + then falls back to provider-specific env vars (ANTHROPIC_API_KEY, ...).""" + completion = _install_litellm_stub() + completion.return_value = _fake_chat_completion() + + from lmms_eval.models.simple.litellm import _LiteLLMClientShim + + shim = _LiteLLMClientShim(api_key=None, base_url=None) + shim.chat.completions.create(model="openai/gpt-4o-mini", messages=[]) + + kwargs = completion.call_args.kwargs + self.assertNotIn("api_key", kwargs) + self.assertNotIn("api_base", kwargs) + + def test_litellm_registered_as_simple_and_chat(self): + """Confirm the model manifest is discoverable under both backends.""" + from lmms_eval.models import MODEL_REGISTRY_V2 + + manifest = MODEL_REGISTRY_V2.get_manifest("litellm") + self.assertEqual(manifest.model_id, "litellm") + self.assertEqual( + manifest.simple_class_path, + "lmms_eval.models.simple.litellm.LiteLLMCompatible", + ) + self.assertEqual( + manifest.chat_class_path, + "lmms_eval.models.chat.litellm.LiteLLMCompatible", + ) + self.assertIn("litellm_chat", manifest.aliases) + self.assertIn("litellm_compatible", manifest.aliases) + + +if __name__ == "__main__": + unittest.main()