From c2a62c3b22818be6d647a9c3b517aed39d70be31 Mon Sep 17 00:00:00 2001 From: Amit Gawande Date: Sun, 22 Feb 2026 20:12:42 +0530 Subject: [PATCH] feat: add Sarvam AI provider --- .env.sample | 3 + README.md | 2 +- aisuite/providers/sarvam_provider.py | 54 +++++++++++ guides/sarvam.md | 44 +++++++++ pyproject.toml | 1 + tests/providers/test_sarvam_provider.py | 120 ++++++++++++++++++++++++ 6 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 aisuite/providers/sarvam_provider.py create mode 100644 guides/sarvam.md create mode 100644 tests/providers/test_sarvam_provider.py diff --git a/.env.sample b/.env.sample index d4c39bb1..f035af05 100644 --- a/.env.sample +++ b/.env.sample @@ -45,3 +45,6 @@ SAMBANOVA_API_KEY= # Inception Labs INCEPTION_API_KEY= + +# Sarvam AI +SARVAM_API_KEY= diff --git a/README.md b/README.md index 966348c5..455107f5 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) `aisuite` is a lightweight Python library that provides a **unified API for working with multiple Generative AI providers**. -It offers a consistent interface for models from *OpenAI, Anthropic, Google, Hugging Face, AWS, Cohere, Mistral, Ollama*, and others—abstracting away SDK differences, authentication details, and parameter variations. +It offers a consistent interface for models from *OpenAI, Anthropic, Google, Hugging Face, AWS, Cohere, Mistral, Ollama, Sarvam*, and others—abstracting away SDK differences, authentication details, and parameter variations. Its design is modeled after OpenAI’s API style, making it instantly familiar and easy to adopt. `aisuite` lets developers build and **run LLM-based or agentic applications across providers** with minimal setup. diff --git a/aisuite/providers/sarvam_provider.py b/aisuite/providers/sarvam_provider.py new file mode 100644 index 00000000..a70811d8 --- /dev/null +++ b/aisuite/providers/sarvam_provider.py @@ -0,0 +1,54 @@ +import os +from aisuite.provider import Provider, LLMError +from openai import OpenAI +from aisuite.providers.message_converter import OpenAICompliantMessageConverter + + +class SarvamMessageConverter(OpenAICompliantMessageConverter): + """ + Sarvam-specific message converter. + """ + + pass + + +class SarvamProvider(Provider): + """ + Sarvam AI Provider using OpenAI-compatible client. + Sarvam uses a custom auth header (api-subscription-key) instead of Bearer token, + so we inject it via default_headers and use a placeholder api_key for the OpenAI SDK. + """ + + def __init__(self, **config): + """ + Initialize the Sarvam provider with the given configuration. + """ + # Ensure API key is provided either in config or via environment variable + self.api_key = config.get("api_key", os.getenv("SARVAM_API_KEY")) + if not self.api_key: + raise ValueError( + "Sarvam API key is missing. Please provide it in the config or set the SARVAM_API_KEY environment variable." + ) + + self.client = OpenAI( + api_key="placeholder", # Required by OpenAI SDK, not used by Sarvam + base_url="https://api.sarvam.ai/v1", + default_headers={"api-subscription-key": self.api_key}, + ) + self.transformer = SarvamMessageConverter() + + def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to the Sarvam chat completions endpoint using the OpenAI client. + """ + try: + transformed_messages = self.transformer.convert_request(messages) + + response = self.client.chat.completions.create( + model=model, + messages=transformed_messages, + **kwargs, + ) + return self.transformer.convert_response(response.model_dump()) + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/guides/sarvam.md b/guides/sarvam.md new file mode 100644 index 00000000..58d8f713 --- /dev/null +++ b/guides/sarvam.md @@ -0,0 +1,44 @@ +# Sarvam AI + +To use Sarvam AI with `aisuite`, you'll need a [Sarvam Platform](https://dashboard.sarvam.ai) account. After signing up, generate an API key from the platform dashboard. Once you have your key, add it to your environment as follows: + +```shell +export SARVAM_API_KEY="your-sarvam-api-key" +``` + +## Create a Chat Completion + +Install the `openai` Python client: + +Example with pip: +```shell +pip install openai +``` + +Example with poetry: +```shell +poetry add openai +``` + +In your code: +```python +import aisuite as ai +client = ai.Client() + +provider = "sarvam" +model_id = "sarvam-m" + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me about India."}, +] + +response = client.chat.completions.create( + model=f"{provider}:{model_id}", + messages=messages, +) + +print(response.choices[0].message.content) +``` + +Happy coding! If you'd like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). diff --git a/pyproject.toml b/pyproject.toml index b67e6761..b01618d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ huggingface = [] mistral = ["mistralai"] ollama = [] openai = ["openai"] +sarvam = ["openai"] watsonx = ["ibm-watsonx-ai"] mcp = ["mcp", "nest-asyncio"] all = ["anthropic", "boto3", "cerebras_cloud_sdk", "vertexai", "google-cloud-speech", "groq", "mistralai", "openai", "cohere", "ibm-watsonx-ai", "deepgram-sdk", "soundfile", "scipy", "numpy", "mcp", "nest-asyncio"] # To install all providers diff --git a/tests/providers/test_sarvam_provider.py b/tests/providers/test_sarvam_provider.py new file mode 100644 index 00000000..b9da445a --- /dev/null +++ b/tests/providers/test_sarvam_provider.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from aisuite.framework import ChatCompletionResponse +from aisuite.providers.sarvam_provider import SarvamProvider + + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("SARVAM_API_KEY", "test-api-key") + + +def test_sarvam_provider(): + """Test that the provider is initialized and chat completions are requested.""" + + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "sarvam-m" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" + + provider = SarvamProvider() + mock_response = MagicMock() + mock_response.model_dump.return_value = { + "choices": [ + {"message": {"content": response_text_content, "role": "assistant"}} + ] + } + + with patch.object( + provider.client.chat.completions, + "create", + return_value=mock_response, + ) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_create.assert_called_once_with( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + assert isinstance(response, ChatCompletionResponse) + assert response.choices[0].message.content == response_text_content + assert response.usage is None + + +def test_sarvam_provider_with_usage(): + """Tests that usage data is correctly parsed when present in the response.""" + + message_history = [{"role": "user", "content": "Hello!"}] + selected_model = "sarvam-m" + response_text_content = "mocked-text-response-from-model" + + provider = SarvamProvider() + mock_response = MagicMock() + mock_response.model_dump.return_value = { + "choices": [ + {"message": {"content": response_text_content, "role": "assistant"}} + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + } + + with patch.object( + provider.client.chat.completions, + "create", + return_value=mock_response, + ) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + ) + + mock_create.assert_called_once_with( + messages=message_history, + model=selected_model, + ) + + assert isinstance(response, ChatCompletionResponse) + assert response.usage is not None + assert response.usage.prompt_tokens == 10 + assert response.usage.completion_tokens == 20 + assert response.usage.total_tokens == 30 + + +def test_sarvam_provider_missing_api_key(monkeypatch): + """Tests that a missing API key raises a ValueError.""" + monkeypatch.delenv("SARVAM_API_KEY", raising=False) + + with pytest.raises(ValueError, match="Sarvam API key is missing"): + SarvamProvider() + + +def test_sarvam_provider_uses_correct_auth_header(): + """Tests that the provider sets the Sarvam-specific auth header on the client.""" + provider = SarvamProvider() + + assert provider.client.default_headers.get("api-subscription-key") == "test-api-key" + + +def test_sarvam_provider_api_key_from_config(monkeypatch): + """Tests that the API key can be passed via config instead of env var.""" + monkeypatch.delenv("SARVAM_API_KEY", raising=False) + + provider = SarvamProvider(api_key="config-api-key") + + assert provider.api_key == "config-api-key" + assert ( + provider.client.default_headers.get("api-subscription-key") == "config-api-key" + )