diff --git a/crab/actions/file_actions.py b/crab/actions/file_actions.py index f876631..312c4a6 100644 --- a/crab/actions/file_actions.py +++ b/crab/actions/file_actions.py @@ -11,15 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== -import base64 -from io import BytesIO - from PIL import Image - -from crab.core import action - +from pydantic import Field +from crab.core.decorators import action +from crab.utils.common import base64_to_image @action -def save_base64_image(image: str, path: str = "image.png") -> None: - image = Image.open(BytesIO(base64.b64decode(image))) - image.save(path) +def save_image(image: str = Field(..., description="Base64 encoded image string"), path: str = Field(..., description="Path to save the image")): + """Save a base64 encoded image to a file.""" + img = base64_to_image(image) + img.save(path) + return f"Image saved to {path}" diff --git a/crab/agents/backend_models/glm_model.py b/crab/agents/backend_models/glm_model.py new file mode 100644 index 0000000..6279d12 --- /dev/null +++ b/crab/agents/backend_models/glm_model.py @@ -0,0 +1,158 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== + +import json +from typing import Any +from time import sleep + +from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType + +try: + from zhipuai import ZhipuAI + glm_model_enable = True +except ImportError: + glm_model_enable = False + +class GLMModel(BackendModel): + def __init__( + self, + model: str, + parameters: dict[str, Any] = dict(), + history_messages_len: int = 0, + ) -> None: + if not glm_model_enable: + raise ImportError("Please install zhipuai to use GLMModel") + super().__init__( + model, + parameters, + history_messages_len, + ) + self.client = ZhipuAI() + + def reset(self, system_message: str, action_space: list[Action] | None) -> None: + self.system_message = system_message + self.glm_system_message = { + "role": "system", + "content": system_message, + } + self.action_space = action_space + self.action_schema = self._convert_action_to_schema(self.action_space) + self.token_usage = 0 + self.chat_history = [] + + def chat(self, message: tuple[str, MessageType]): + request_messages = self._convert_to_request_messages(message) + response = self.call_api(request_messages) + + assistant_message = response.choices[0].message + action_list = self._convert_tool_calls_to_action_list(assistant_message) + + output = ChatOutput( + message=assistant_message.content if not action_list else None, + action_list=action_list, + ) + + self.record_message(request_messages[-1], assistant_message) + return output + + def get_token_usage(self): + return self.token_usage + + def record_message(self, new_message: dict, response_message: dict) -> None: + self.chat_history.append([new_message]) + self.chat_history[-1].append(response_message) + + if self.action_schema: + tool_calls = response_message.tool_calls + for tool_call in tool_calls: + self.chat_history[-1].append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": "", + } + ) + + def call_api(self, request_messages: list): + while True: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=request_messages, + **self.parameters, + ) + except Exception as e: + print(f"API call failed: {str(e)}. Retrying in 10 seconds...") + sleep(10) + else: + break + + self.token_usage += response.usage.total_tokens + return response + + @staticmethod + def _convert_action_to_schema(action_space: list[Action] | None): + if action_space is None: + return None + + tools = [] + for action in action_space: + tool = { + "type": "function", + "function": { + "name": action.name, + "description": action.description, + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + for param in action.parameters: + tool["function"]["parameters"]["properties"][param.name] = { + "type": param.type, + "description": param.description, + } + if param.required: + tool["function"]["parameters"]["required"].append(param.name) + tools.append(tool) + return tools + + @staticmethod + def _convert_tool_calls_to_action_list(self, message): + if not message.content or not message.content.startswith("arguments="): + return None + + action_list = [] + parts = message.content.split(", name=") + arguments = json.loads(parts[0].replace("arguments=", "").strip("'")) + name = parts[1].strip("'") + action_output = ActionOutput( + name=name, + args=arguments, + ) + action_list.append(action_output) + return action_list + + @staticmethod + def _convert_message(message: tuple[str, MessageType]): + content, message_type = message + if message_type == MessageType.TEXT: + return {"type": "text", "text": content} + elif message_type == MessageType.IMAGE_URL: + return {"type": "image_url", "image_url": {"url": content}} + else: + raise ValueError(f"Unsupported message type: {message_type}") \ No newline at end of file diff --git a/test/core/test_image_handling.py b/test/core/test_image_handling.py new file mode 100644 index 0000000..fbe33ee --- /dev/null +++ b/test/core/test_image_handling.py @@ -0,0 +1,104 @@ +import pytest +from PIL import Image +import io +import base64 +import os +from unittest.mock import patch, MagicMock +from crab.utils.common import base64_to_image, image_to_base64 +from crab.actions.file_actions import save_image + +import sys +# Mock the entire crab.agents.backend_models module +sys.modules['crab.agents.backend_models'] = MagicMock() +sys.modules['crab.agents.backend_models.openai_model'] = MagicMock() +sys.modules['crab.actions.desktop_actions'] = MagicMock() + +# Create mock classes/functions +class MockOpenAIModel: + def _convert_message(self, message): + return {"type": "image_url", "image_url": {"url": ""}} + +class MockMessageType: + IMAGE_JPG_BASE64 = "image_jpg_base64" + +def mock_screenshot(): + return Image.new('RGB', (100, 100), color='red') + +# Apply mocks +patch('crab.agents.backend_models.openai_model.OpenAIModel', MockOpenAIModel).start() +patch('crab.agents.backend_models.openai_model.MessageType', MockMessageType).start() +patch('crab.actions.desktop_actions.screenshot', mock_screenshot).start() + +class TestImageHandling: + @pytest.fixture(autouse=True) + def setup(self): + self.test_image = Image.new('RGB', (100, 100), color='red') + + def test_image_processing_path(self): + print("\n--- Image Processing Path Test ---") + + # Use self.test_image instead of taking a screenshot + screenshot_image = self.test_image + + # 1. Start with a PIL Image (using self.test_image) + print("1. Starting with a PIL Image") + assert isinstance(self.test_image, Image.Image) + + # 2. Simulate saving the image + print("2. Saving the image") + save_image(self.test_image, "test_image.png") + print(" Image saved successfully") + + # 3. Using self.test_image instead of taking a screenshot + print("3. Using self.test_image instead of taking a screenshot") + screenshot_image = self.test_image + assert isinstance(screenshot_image, Image.Image) + print(" Using self.test_image as PIL Image") + + # 4. Prepare for network transfer (serialize to base64) + print("4. Serializing image for network transfer") + base64_string = image_to_base64(self.test_image) + assert isinstance(base64_string, str) + print(" Image serialized to base64 string") + + # 5. Simulate network transfer + print("5. Simulating network transfer") + received_base64 = base64_string # In reality, this would be sent and received + + # 6. Deserialize after network transfer + print("6. Deserializing image after network transfer") + received_image = base64_to_image(received_base64) + assert isinstance(received_image, Image.Image) + print(" Image deserialized back to PIL Image") + + # 7. Use the image in a backend model (e.g., OpenAI) + print("7. Using image in backend model") + openai_model = MockOpenAIModel() + converted_message = openai_model._convert_message((received_image, MockMessageType.IMAGE_JPG_BASE64)) + assert converted_message["type"] == "image_url" + assert converted_message["image_url"]["url"].startswith("data:image/png;base64,") + print(" Image successfully converted for use in OpenAI model") + + print("--- Image Processing Path Test Completed Successfully ---") + + def test_base64_to_image(self): + # Convert image to base64 + base64_string = image_to_base64(self.test_image) + + # Test base64_to_image function + converted_image = base64_to_image(base64_string) + assert isinstance(converted_image, Image.Image) + assert converted_image.size == (100, 100) + + def test_image_to_base64(self): + # Test image_to_base64 function + base64_string = image_to_base64(self.test_image) + assert isinstance(base64_string, str) + + # Verify that the base64 string can be converted back to an image + converted_image = base64_to_image(base64_string) + assert converted_image.size == (100, 100) + +# Make sure to stop all patches after the tests +def teardown_module(module): + patch.stopall()