Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ SAMBANOVA_API_KEY=

# Inception Labs
INCEPTION_API_KEY=

# Sarvam AI
SARVAM_API_KEY=
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

`aisuite` is a lightweight Python library that provides a **unified API for working with multiple Generative AI providers**.
It offers a consistent interface for models from *OpenAI, Anthropic, Google, Hugging Face, AWS, Cohere, Mistral, Ollama*, and others—abstracting away SDK differences, authentication details, and parameter variations.
It offers a consistent interface for models from *OpenAI, Anthropic, Google, Hugging Face, AWS, Cohere, Mistral, Ollama, Sarvam*, and others—abstracting away SDK differences, authentication details, and parameter variations.
Its design is modeled after OpenAI’s API style, making it instantly familiar and easy to adopt.

`aisuite` lets developers build and **run LLM-based or agentic applications across providers** with minimal setup.
Expand Down
54 changes: 54 additions & 0 deletions aisuite/providers/sarvam_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os
from aisuite.provider import Provider, LLMError
from openai import OpenAI
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


class SarvamMessageConverter(OpenAICompliantMessageConverter):
"""
Sarvam-specific message converter.
"""

pass


class SarvamProvider(Provider):
"""
Sarvam AI Provider using OpenAI-compatible client.
Sarvam uses a custom auth header (api-subscription-key) instead of Bearer token,
so we inject it via default_headers and use a placeholder api_key for the OpenAI SDK.
"""

def __init__(self, **config):
"""
Initialize the Sarvam provider with the given configuration.
"""
# Ensure API key is provided either in config or via environment variable
self.api_key = config.get("api_key", os.getenv("SARVAM_API_KEY"))
if not self.api_key:
raise ValueError(
"Sarvam API key is missing. Please provide it in the config or set the SARVAM_API_KEY environment variable."
)

self.client = OpenAI(
api_key="placeholder", # Required by OpenAI SDK, not used by Sarvam
base_url="https://api.sarvam.ai/v1",
default_headers={"api-subscription-key": self.api_key},
)
self.transformer = SarvamMessageConverter()

def chat_completions_create(self, model, messages, **kwargs):
"""
Makes a request to the Sarvam chat completions endpoint using the OpenAI client.
"""
try:
transformed_messages = self.transformer.convert_request(messages)

response = self.client.chat.completions.create(
model=model,
messages=transformed_messages,
**kwargs,
)
return self.transformer.convert_response(response.model_dump())
except Exception as e:
raise LLMError(f"An error occurred: {e}")
44 changes: 44 additions & 0 deletions guides/sarvam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Sarvam AI

To use Sarvam AI with `aisuite`, you'll need a [Sarvam Platform](https://dashboard.sarvam.ai) account. After signing up, generate an API key from the platform dashboard. Once you have your key, add it to your environment as follows:

```shell
export SARVAM_API_KEY="your-sarvam-api-key"
```

## Create a Chat Completion

Install the `openai` Python client:

Example with pip:
```shell
pip install openai
```

Example with poetry:
```shell
poetry add openai
```

In your code:
```python
import aisuite as ai
client = ai.Client()

provider = "sarvam"
model_id = "sarvam-m"

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me about India."},
]

response = client.chat.completions.create(
model=f"{provider}:{model_id}",
messages=messages,
)

print(response.choices[0].message.content)
```

Happy coding! If you'd like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ huggingface = []
mistral = ["mistralai"]
ollama = []
openai = ["openai"]
sarvam = ["openai"]
watsonx = ["ibm-watsonx-ai"]
mcp = ["mcp", "nest-asyncio"]
all = ["anthropic", "boto3", "cerebras_cloud_sdk", "vertexai", "google-cloud-speech", "groq", "mistralai", "openai", "cohere", "ibm-watsonx-ai", "deepgram-sdk", "soundfile", "scipy", "numpy", "mcp", "nest-asyncio"] # To install all providers
Expand Down
120 changes: 120 additions & 0 deletions tests/providers/test_sarvam_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.framework import ChatCompletionResponse
from aisuite.providers.sarvam_provider import SarvamProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("SARVAM_API_KEY", "test-api-key")


def test_sarvam_provider():
"""Test that the provider is initialized and chat completions are requested."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "sarvam-m"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = SarvamProvider()
mock_response = MagicMock()
mock_response.model_dump.return_value = {
"choices": [
{"message": {"content": response_text_content, "role": "assistant"}}
]
}

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_once_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert isinstance(response, ChatCompletionResponse)
assert response.choices[0].message.content == response_text_content
assert response.usage is None


def test_sarvam_provider_with_usage():
"""Tests that usage data is correctly parsed when present in the response."""

message_history = [{"role": "user", "content": "Hello!"}]
selected_model = "sarvam-m"
response_text_content = "mocked-text-response-from-model"

provider = SarvamProvider()
mock_response = MagicMock()
mock_response.model_dump.return_value = {
"choices": [
{"message": {"content": response_text_content, "role": "assistant"}}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
},
}

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
)

mock_create.assert_called_once_with(
messages=message_history,
model=selected_model,
)

assert isinstance(response, ChatCompletionResponse)
assert response.usage is not None
assert response.usage.prompt_tokens == 10
assert response.usage.completion_tokens == 20
assert response.usage.total_tokens == 30


def test_sarvam_provider_missing_api_key(monkeypatch):
"""Tests that a missing API key raises a ValueError."""
monkeypatch.delenv("SARVAM_API_KEY", raising=False)

with pytest.raises(ValueError, match="Sarvam API key is missing"):
SarvamProvider()


def test_sarvam_provider_uses_correct_auth_header():
"""Tests that the provider sets the Sarvam-specific auth header on the client."""
provider = SarvamProvider()

assert provider.client.default_headers.get("api-subscription-key") == "test-api-key"


def test_sarvam_provider_api_key_from_config(monkeypatch):
"""Tests that the API key can be passed via config instead of env var."""
monkeypatch.delenv("SARVAM_API_KEY", raising=False)

provider = SarvamProvider(api_key="config-api-key")

assert provider.api_key == "config-api-key"
assert (
provider.client.default_headers.get("api-subscription-key") == "config-api-key"
)