From e2ae4c22481adf056ab0fb88b64ae8ff72ea190d Mon Sep 17 00:00:00 2001 From: PYTHON01100100 Date: Mon, 26 Jan 2026 21:33:52 +0300 Subject: [PATCH] test: initial testing branch --- aisuite/providers/eas_pai_provider.py | 156 +++++++++++ aisuite/providers/fireworks_provider.py | 3 + tests/providers/test_eas_pai_provider.py | 321 +++++++++++++++++++++++ 3 files changed, 480 insertions(+) create mode 100644 aisuite/providers/eas_pai_provider.py create mode 100644 tests/providers/test_eas_pai_provider.py diff --git a/aisuite/providers/eas_pai_provider.py b/aisuite/providers/eas_pai_provider.py new file mode 100644 index 00000000..fa15b8e6 --- /dev/null +++ b/aisuite/providers/eas_pai_provider.py @@ -0,0 +1,156 @@ +import urllib.request +import json +import os + +from aisuite.provider import Provider +from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function + +# Alibaba Cloud EAS (Elastic Algorithm Service) PAI Provider +# EAS is part of Alibaba Cloud's PAI (Platform for AI) platform. +# Documentation: https://www.alibabacloud.com/help/en/pai/user-guide/eas-model-serving +# +# EAS provides model inference services with OpenAI-compatible API endpoints. +# The endpoint URL format is typically: +# https://..pai-eas.aliyuncs.com/api/predict/ +# Or for OpenAI-compatible endpoints: +# https:///v1/chat/completions +# +# Authentication is done via Token in the Authorization header. + + +class Eas_paiMessageConverter: + @staticmethod + def convert_request(messages): + """Convert messages to EAS PAI format (OpenAI-compatible).""" + transformed_messages = [] + for message in messages: + if isinstance(message, Message): + transformed_messages.append(message.model_dump(mode="json")) + else: + transformed_messages.append(message) + return transformed_messages + + @staticmethod + def convert_response(resp_json) -> ChatCompletionResponse: + """Normalize the response from EAS PAI API to match OpenAI's response format.""" + completion_response = ChatCompletionResponse() + choice = resp_json["choices"][0] + message = choice["message"] + + # Set basic message content + completion_response.choices[0].message.content = message.get("content") + completion_response.choices[0].message.role = message.get("role", "assistant") + + # Handle tool calls if present + if "tool_calls" in message and message["tool_calls"] is not None: + tool_calls = [] + for tool_call in message["tool_calls"]: + new_tool_call = ChatCompletionMessageToolCall( + id=tool_call["id"], + type=tool_call["type"], + function={ + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + }, + ) + tool_calls.append(new_tool_call) + completion_response.choices[0].message.tool_calls = tool_calls + + return completion_response + + +class Eas_paiProvider(Provider): + """ + Alibaba Cloud EAS PAI Provider for aisuite. + + Configuration options: + - base_url: The EAS endpoint URL (or set EAS_PAI_BASE_URL env var) + - api_key: The EAS service token (or set EAS_PAI_API_KEY env var) + + Usage: + client = aisuite.Client() + client.configure({ + "eas_pai": { + "base_url": "https://your-service.region.pai-eas.aliyuncs.com/api/predict/your-service", + "api_key": "your-eas-token" + } + }) + response = client.chat.completions.create( + model="eas_pai:your-model-name", + messages=[{"role": "user", "content": "Hello!"}] + ) + """ + + def __init__(self, **config): + self.base_url = config.get("base_url") or os.getenv("EAS_PAI_BASE_URL") + self.api_key = config.get("api_key") or os.getenv("EAS_PAI_API_KEY") + + if not self.api_key: + raise ValueError( + "For EAS PAI, api_key is required. " + "Set it via config or EAS_PAI_API_KEY environment variable." + ) + if not self.base_url: + raise ValueError( + "For EAS PAI, base_url is required. " + "Set it via config or EAS_PAI_BASE_URL environment variable. " + "Example: https://..pai-eas.aliyuncs.com/api/predict/" + ) + + self.transformer = Eas_paiMessageConverter() + + def chat_completions_create(self, model, messages, **kwargs): + # Determine the endpoint URL + # If base_url already contains /chat/completions, use it directly + # Otherwise, append /v1/chat/completions for OpenAI-compatible endpoints + if "/chat/completions" in self.base_url: + url = self.base_url + elif self.base_url.endswith("/v1"): + url = f"{self.base_url}/chat/completions" + else: + # For standard EAS endpoints, the URL is used as-is + # The model inference happens at the base URL + url = self.base_url + + # Remove 'stream' from kwargs if present (streaming not supported) + kwargs.pop("stream", None) + + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + + # Prepare the request payload + data = { + "model": model, + "messages": transformed_messages, + } + + # Add tools if provided + if "tools" in kwargs: + data["tools"] = kwargs.pop("tools") + + # Add tool_choice if provided + if "tool_choice" in kwargs: + data["tool_choice"] = kwargs.pop("tool_choice") + + # Add remaining kwargs (temperature, max_tokens, etc.) + data.update(kwargs) + + body = json.dumps(data).encode("utf-8") + headers = { + "Content-Type": "application/json", + "Authorization": self.api_key, + } + + try: + req = urllib.request.Request(url, body, headers) + with urllib.request.urlopen(req) as response: + result = response.read() + resp_json = json.loads(result) + return self.transformer.convert_response(resp_json) + + except urllib.error.HTTPError as error: + error_message = f"EAS PAI request failed with status code: {error.code}\n" + error_message += f"Headers: {error.info()}\n" + error_message += error.read().decode("utf-8", "ignore") + raise Exception(error_message) diff --git a/aisuite/providers/fireworks_provider.py b/aisuite/providers/fireworks_provider.py index 10bea195..e4db9e51 100644 --- a/aisuite/providers/fireworks_provider.py +++ b/aisuite/providers/fireworks_provider.py @@ -139,3 +139,6 @@ def _normalize_response(self, response_data): "message" ]["content"] return normalized_response + + +# test branch diff --git a/tests/providers/test_eas_pai_provider.py b/tests/providers/test_eas_pai_provider.py new file mode 100644 index 00000000..a2cf7ae8 --- /dev/null +++ b/tests/providers/test_eas_pai_provider.py @@ -0,0 +1,321 @@ +"""Tests for the Eas_paiProvider and Eas_paiMessageConverter.""" + +import json +import unittest +from unittest.mock import patch, MagicMock + +from aisuite.providers.eas_pai_provider import Eas_paiProvider, Eas_paiMessageConverter +from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message + + +class TestEasPaiMessageConverter(unittest.TestCase): + """Test suite for the Eas_paiMessageConverter class.""" + + def setUp(self): + """Set up the test case.""" + self.converter = Eas_paiMessageConverter() + + def test_convert_request_dict_messages(self): + """Test converting dict messages.""" + messages = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + converted = self.converter.convert_request(messages) + + self.assertEqual(len(converted), 2) + self.assertEqual(converted[0]["role"], "user") + self.assertEqual(converted[0]["content"], "Hello!") + self.assertEqual(converted[1]["role"], "assistant") + self.assertEqual(converted[1]["content"], "Hi there!") + + def test_convert_request_message_objects(self): + """Test converting Message objects.""" + messages = [ + Message(role="user", content="Hello!"), + ] + converted = self.converter.convert_request(messages) + + self.assertEqual(len(converted), 1) + self.assertEqual(converted[0]["role"], "user") + self.assertEqual(converted[0]["content"], "Hello!") + + def test_convert_response_normal_message(self): + """Test converting a normal text response.""" + resp_json = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello! How can I help you?", + } + } + ] + } + + response = self.converter.convert_response(resp_json) + + self.assertIsInstance(response, ChatCompletionResponse) + self.assertEqual( + response.choices[0].message.content, "Hello! How can I help you?" + ) + self.assertEqual(response.choices[0].message.role, "assistant") + + def test_convert_response_with_tool_calls(self): + """Test converting a response with tool calls.""" + resp_json = { + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "Beijing"}', + }, + } + ], + } + } + ] + } + + response = self.converter.convert_response(resp_json) + + self.assertIsInstance(response, ChatCompletionResponse) + self.assertIsNone(response.choices[0].message.content) + self.assertEqual(len(response.choices[0].message.tool_calls), 1) + self.assertEqual(response.choices[0].message.tool_calls[0].id, "call_123") + self.assertEqual(response.choices[0].message.tool_calls[0].type, "function") + self.assertEqual( + response.choices[0].message.tool_calls[0].function.name, "get_weather" + ) + self.assertEqual( + response.choices[0].message.tool_calls[0].function.arguments, + '{"location": "Beijing"}', + ) + + +class TestEasPaiProvider(unittest.TestCase): + """Test suite for the Eas_paiProvider class.""" + + def test_init_with_config(self): + """Test initialization with config parameters.""" + provider = Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/api/predict/test", + api_key="test-token", + ) + + self.assertEqual( + provider.base_url, "https://test.pai-eas.aliyuncs.com/api/predict/test" + ) + self.assertEqual(provider.api_key, "test-token") + + def test_init_with_env_vars(self): + """Test initialization with environment variables.""" + with patch.dict( + "os.environ", + { + "EAS_PAI_BASE_URL": "https://env.pai-eas.aliyuncs.com/api/predict/env", + "EAS_PAI_API_KEY": "env-token", + }, + ): + provider = Eas_paiProvider() + + self.assertEqual( + provider.base_url, + "https://env.pai-eas.aliyuncs.com/api/predict/env", + ) + self.assertEqual(provider.api_key, "env-token") + + def test_init_missing_api_key_raises_error(self): + """Test that missing api_key raises ValueError.""" + with patch.dict("os.environ", {}, clear=True): + with self.assertRaises(ValueError) as context: + Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/api/predict/test" + ) + + self.assertIn("api_key is required", str(context.exception)) + + def test_init_missing_base_url_raises_error(self): + """Test that missing base_url raises ValueError.""" + with patch.dict("os.environ", {}, clear=True): + with self.assertRaises(ValueError) as context: + Eas_paiProvider(api_key="test-token") + + self.assertIn("base_url is required", str(context.exception)) + + def test_chat_completions_create(self): + """Test chat completions create request.""" + provider = Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/api/predict/test", + api_key="test-token", + ) + + mock_response = { + "choices": [ + {"message": {"role": "assistant", "content": "Hello from EAS!"}} + ] + } + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(mock_response).encode("utf-8") + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp) as mock_urlopen: + response = provider.chat_completions_create( + model="qwen-turbo", + messages=[{"role": "user", "content": "Hi"}], + temperature=0.7, + ) + + # Verify the request was made + mock_urlopen.assert_called_once() + request = mock_urlopen.call_args[0][0] + + # Verify request body + body = json.loads(request.data.decode("utf-8")) + self.assertEqual(body["model"], "qwen-turbo") + self.assertEqual(body["messages"], [{"role": "user", "content": "Hi"}]) + self.assertEqual(body["temperature"], 0.7) + + # Verify headers + self.assertEqual(request.get_header("Content-type"), "application/json") + self.assertEqual(request.get_header("Authorization"), "test-token") + + # Verify response + self.assertEqual( + response.choices[0].message.content, "Hello from EAS!" + ) + + def test_chat_completions_create_with_tools(self): + """Test chat completions create request with tools.""" + provider = Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/api/predict/test", + api_key="test-token", + ) + + mock_response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "Shanghai"}', + }, + } + ], + } + } + ] + } + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(mock_response).encode("utf-8") + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather info", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + } + ] + + with patch("urllib.request.urlopen", return_value=mock_resp) as mock_urlopen: + response = provider.chat_completions_create( + model="qwen-turbo", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=tools, + ) + + # Verify tools were included in request + request = mock_urlopen.call_args[0][0] + body = json.loads(request.data.decode("utf-8")) + self.assertEqual(body["tools"], tools) + + # Verify response has tool calls + self.assertEqual(len(response.choices[0].message.tool_calls), 1) + self.assertEqual( + response.choices[0].message.tool_calls[0].function.name, "get_weather" + ) + + def test_chat_completions_url_with_chat_completions_endpoint(self): + """Test URL handling when base_url already contains /chat/completions.""" + provider = Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/v1/chat/completions", + api_key="test-token", + ) + + mock_response = { + "choices": [{"message": {"role": "assistant", "content": "Hi"}}] + } + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(mock_response).encode("utf-8") + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp) as mock_urlopen: + provider.chat_completions_create( + model="qwen-turbo", + messages=[{"role": "user", "content": "Hi"}], + ) + + request = mock_urlopen.call_args[0][0] + self.assertEqual( + request.full_url, + "https://test.pai-eas.aliyuncs.com/v1/chat/completions", + ) + + def test_stream_parameter_is_removed(self): + """Test that stream parameter is removed from kwargs.""" + provider = Eas_paiProvider( + base_url="https://test.pai-eas.aliyuncs.com/api/predict/test", + api_key="test-token", + ) + + mock_response = { + "choices": [{"message": {"role": "assistant", "content": "Hi"}}] + } + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(mock_response).encode("utf-8") + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp) as mock_urlopen: + provider.chat_completions_create( + model="qwen-turbo", + messages=[{"role": "user", "content": "Hi"}], + stream=True, # This should be removed + ) + + request = mock_urlopen.call_args[0][0] + body = json.loads(request.data.decode("utf-8")) + self.assertNotIn("stream", body) + + +if __name__ == "__main__": + unittest.main()