From 3f3331c21d4c99cb4257502c0b12db238c5f58d3 Mon Sep 17 00:00:00 2001 From: jbauthess Date: Thu, 6 Mar 2025 11:12:48 +0100 Subject: [PATCH 1/2] add lmstudio provider --- aisuite/providers/lmstudio_provider.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 aisuite/providers/lmstudio_provider.py diff --git a/aisuite/providers/lmstudio_provider.py b/aisuite/providers/lmstudio_provider.py new file mode 100644 index 00000000..e6280bb2 --- /dev/null +++ b/aisuite/providers/lmstudio_provider.py @@ -0,0 +1,33 @@ +import lmstudio + +from aisuite.framework import ChatCompletionResponse +from aisuite.provider import LLMError, Provider + + +class LmstudioProvider(Provider): + def __init__(self, **config): + self.client = lmstudio.Client(**config) + + def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to the Cerebras chat completions endpoint using the official client. + """ + try: + model = self.client.llm.model(model) + result = model.respond({"messages": messages}) + + # Return the normalized response + normalized_response = self._normalize_response(result) + return normalized_response + + # Wrap all other exceptions in LLMError. + except Exception as e: + raise LLMError("An error occurred.") from e + + def _normalize_response(self, response_data: lmstudio.PredictionResult): + """ + Normalize the lmstudio response to a common format (ChatCompletionResponse). 
+ """ + normalized_response = ChatCompletionResponse() + normalized_response.choices[0].message.content = response_data.content + return normalized_response From 5595eabdf64f9f6cd348a538eb72556c12fea14a Mon Sep 17 00:00:00 2001 From: jbauthess Date: Thu, 6 Mar 2025 17:30:47 +0100 Subject: [PATCH 2/2] add tests to LMstudio provider LMStudio specific optional settings can now be passed --- README.md | 1 + aisuite/providers/lmstudio_provider.py | 21 ++++++++++--- pyproject.toml | 8 +++-- tests/providers/test_lmstudio_provider.py | 38 +++++++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 tests/providers/test_lmstudio_provider.py diff --git a/README.md b/README.md index b3b7c08c..5aabe977 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Currently supported providers are: - Google - Groq - HuggingFace Ollama +- LMStudio - Mistral - OpenAI - Sambanova diff --git a/aisuite/providers/lmstudio_provider.py b/aisuite/providers/lmstudio_provider.py index e6280bb2..aeaff128 100644 --- a/aisuite/providers/lmstudio_provider.py +++ b/aisuite/providers/lmstudio_provider.py @@ -7,16 +7,29 @@ class LmstudioProvider(Provider): def __init__(self, **config): self.client = lmstudio.Client(**config) + self.model: lmstudio.LLM | None = None + + def _chat(self, model: str, messages, **kwargs) -> lmstudio.PredictionResult[str]: + """ + Makes a request to the lmstudio chat completions endpoint using the official client + """ + # get a handle on the specified model (load it if necessary) + model = self.client.llm.model(model) + # send the request to the model + result = model.respond({"messages": messages}, **kwargs) + + return result def chat_completions_create(self, model, messages, **kwargs): """ - Makes a request to the Cerebras chat completions endpoint using the official client.
+ Makes a request to the lmstudio chat completions endpoint using the official client + and converts the output to conform to the OpenAI model response format """ try: - model = self.client.llm.model(model) - result = model.respond({"messages": messages}) + # --- send request to lmstudio endpoint + result = self._chat(model=model, messages=messages, **kwargs) - # Return the normalized response + # --- Return the normalized response normalized_response = self._normalize_response(result) return normalized_response diff --git a/pyproject.toml b/pyproject.toml index af25aa3f..21b58b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Andrew Ng, Rohit P"] readme = "README.md" [tool.poetry.dependencies] -python = "^3.10" +python = "^3.11" anthropic = { version = "^0.30.1", optional = true } boto3 = { version = "^1.34.144", optional = true } cohere = { version = "^5.12.0", optional = true } @@ -17,9 +17,11 @@ openai = { version = "^1.35.8", optional = true } ibm-watsonx-ai = { version = "^1.1.16", optional = true } docstring-parser = { version = "^0.14.0", optional = true } cerebras_cloud_sdk = { version = "^1.19.0", optional = true } +lmstudio = { version = "^1.0.1", optional = true } # Optional dependencies for different providers httpx = "~0.27.0" + [tool.poetry.extras] anthropic = ["anthropic"] aws = ["boto3"] @@ -30,11 +32,12 @@ deepseek = ["openai"] google = ["vertexai"] groq = ["groq"] huggingface = [] +lmstudio = ["lmstudio"] mistral = ["mistralai"] ollama = [] openai = ["openai"] watsonx = ["ibm-watsonx-ai"] -all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers +all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx", "lmstudio"] # To install all providers [tool.poetry.group.dev.dependencies] pre-commit = "^3.7.1" @@ -54,6 +57,7 @@ datasets = "^2.20.0" vertexai = "^1.63.0" ibm-watsonx-ai = "^1.1.16"
cerebras_cloud_sdk = "^1.19.0" +lmstudio = "^1.0.1" [tool.poetry.group.test] optional = true diff --git a/tests/providers/test_lmstudio_provider.py b/tests/providers/test_lmstudio_provider.py new file mode 100644 index 00000000..c1aa576c --- /dev/null +++ b/tests/providers/test_lmstudio_provider.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from unittest.mock import patch + +from aisuite.providers.lmstudio_provider import LmstudioProvider + + +@dataclass +class MockResponse: + content: str + + +def test_lmstudio_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + config = { + "temperature": 0.6, + "maxTokens": 5000, + } + response_text_content = "mocked-text-response-from-model" + + provider = LmstudioProvider() + mock_response = MockResponse(response_text_content) + + with patch.object(provider, "_chat", return_value=mock_response) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + config=config, + ) + + mock_create.assert_called_with( + model=selected_model, messages=message_history, config=config + ) + + assert response.choices[0].message.content == response_text_content