From 9d2d9322ca6410e133919bee5d294ff487eafee7 Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso de Vargas <77084039+vargacypher@users.noreply.github.com> Date: Tue, 11 Mar 2025 13:03:58 +0000 Subject: [PATCH 1/4] Remove obsolete interface / Add safety_setting for google provider --- aisuite/framework/provider_interface.py | 26 ------------------------- aisuite/providers/google_provider.py | 8 ++++++-- 2 files changed, 6 insertions(+), 28 deletions(-) delete mode 100644 aisuite/framework/provider_interface.py diff --git a/aisuite/framework/provider_interface.py b/aisuite/framework/provider_interface.py deleted file mode 100644 index d942c1fe..00000000 --- a/aisuite/framework/provider_interface.py +++ /dev/null @@ -1,26 +0,0 @@ -"""The shared interface for model providers.""" - - -# TODO(rohit): Remove this. This interface is obsolete in favor of Provider. -class ProviderInterface: - """Defines the expected behavior for provider-specific interfaces.""" - - def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: - """Create a chat completion using the specified messages, model, and temperature. - - This method must be implemented by subclasses to perform completions. - - Args: - ---- - messages (list): The chat history. - model (str): The identifier of the model to be used in the completion. - temperature (float): The temperature to use in the completion. - - Raises: - ------ - NotImplementedError: If this method has not been implemented by a subclass. - - """ - raise NotImplementedError( - "Provider Interface has not implemented chat_completion_create()" - ) diff --git a/aisuite/providers/google_provider.py b/aisuite/providers/google_provider.py index 7d4b586b..b2f53bf7 100644 --- a/aisuite/providers/google_provider.py +++ b/aisuite/providers/google_provider.py @@ -189,8 +189,8 @@ def convert_response(response) -> ChatCompletionResponse: return openai_response -class GoogleProvider(ProviderInterface): - """Implements the ProviderInterface for interacting with Google's Vertex AI.""" +class GoogleProvider: + """Implements the Provider Interface for interacting with Google's Vertex AI.""" def __init__(self, **config): """Set up the Google AI client with a project ID.""" @@ -229,6 +229,9 @@ def chat_completions_create(self, model, messages, **kwargs): # Set the temperature if provided, otherwise use the default temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE) + # Set safety_settings if provided + safety_settings = kwargs.get("safety_settings") + # Convert messages to Vertex AI format message_history = self.transformer.convert_request(messages) @@ -274,6 +277,7 @@ def chat_completions_create(self, model, messages, **kwargs): model, generation_config=GenerationConfig(temperature=temperature), tools=tools, + safety_settings=safety_settings ) if ENABLE_DEBUG_MESSAGES: From b484e0d456159cb6b930306549259070b98959cf Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso de Vargas <77084039+vargacypher@users.noreply.github.com> Date: Tue, 11 Mar 2025 13:17:53 +0000 Subject: [PATCH 2/4] Remove unused imports --- aisuite/framework/__init__.py | 1 - aisuite/providers/google_provider.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py index bc7d71c4..c61f0113 100644 --- a/aisuite/framework/__init__.py +++ b/aisuite/framework/__init__.py @@ -1,3 +1,2 @@ -from .provider_interface import ProviderInterface from .chat_completion_response import ChatCompletionResponse from .message import Message diff --git a/aisuite/providers/google_provider.py b/aisuite/providers/google_provider.py index b2f53bf7..aa68895e 100644 --- a/aisuite/providers/google_provider.py +++ b/aisuite/providers/google_provider.py @@ -15,7 +15,7 @@ ) import pprint -from aisuite.framework import ProviderInterface, ChatCompletionResponse, Message +from aisuite.framework import ChatCompletionResponse, Message DEFAULT_TEMPERATURE = 0.7 From 4224677063247411850250b13673a1934fafd87f Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso de Vargas <77084039+vargacypher@users.noreply.github.com> Date: Wed, 12 Mar 2025 08:29:40 -0300 Subject: [PATCH 3/4] Update google_provider.py --- aisuite/providers/google_provider.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aisuite/providers/google_provider.py b/aisuite/providers/google_provider.py index aa68895e..01850adf 100644 --- a/aisuite/providers/google_provider.py +++ b/aisuite/providers/google_provider.py @@ -16,7 +16,7 @@ import pprint from aisuite.framework import ChatCompletionResponse, Message - +from aisuite.provider import Provider DEFAULT_TEMPERATURE = 0.7 ENABLE_DEBUG_MESSAGES = False @@ -189,7 +189,7 @@ def convert_response(response) -> ChatCompletionResponse: return openai_response -class GoogleProvider: +class GoogleProvider(Provider): """Implements the Provider Interface for interacting with Google's Vertex AI.""" def __init__(self, **config): From b73198345cf0bda1048f4bed9faf37d5004f2b3b Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso de Vargas <77084039+vargacypher@users.noreply.github.com> Date: Wed, 12 Mar 2025 08:53:58 -0300 Subject: [PATCH 4/4] Update google.md --- guides/google.md | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/guides/google.md b/guides/google.md index e357679e..5731e75e 100644 --- a/guides/google.md +++ b/guides/google.md @@ -89,4 +89,45 @@ response = client.chat.completions.create( print(response.choices[0].message.content) ``` +## Safety Settings + +```python +from aisuite import Client + +client = Client({ + "google":{ + "project_id": "project-id", + "region": "us-central1", + } +}) + +model = "google:gemini-2.0-flash-001" + +messages = [{ + "role": "user", + "content": "I shouldn't use a public swimming pool"}] + +from vertexai.generative_models import ( + HarmCategory, + HarmBlockThreshold, + SafetySetting, +) + +safety_config = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), +] + +response = client.chat.completions.create( safety_settings=safety_config, + model=model, messages=messages) +print(response.choices[0].message.content) + +``` + Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).