diff --git a/dspy/adapters/pass_through_adapter.py b/dspy/adapters/pass_through_adapter.py new file mode 100644 index 0000000000..add953e5ba --- /dev/null +++ b/dspy/adapters/pass_through_adapter.py @@ -0,0 +1,103 @@ +from dspy.adapters.chat_adapter import ChatAdapter +from typing import Any, Dict, NamedTuple, Optional, Type +from dspy.signatures.signature import Signature +from dspy.adapters.utils import ( + format_field_value, + get_annotation_name, + get_field_description_string, + parse_value, + translate_field_type, +) +from dspy.adapters.types import BaseType +import itertools + + +def format_field_value(value) -> list[dict]: + if isinstance(value, str): + return [{"type": "text", "text": value}] + elif isinstance(value, list): + formatted_list = [format_field_value(v) for v in value] + flattened = list(itertools.chain.from_iterable(formatted_list)) + return flattened + elif isinstance(value, BaseType) or hasattr( + value, "format" + ): # Check if Custom Type + return value.format() # WARN: assumes a list. Dangerous. + else: + return value + + +class PassThroughChatAdapter(ChatAdapter): + def format( + self, + signature: Type[Signature], + demos: list[dict[str, Any]], + inputs: dict[str, Any], + ) -> list[dict[str, Any]]: + inputs_copy = dict(inputs) + + # If the signature and inputs have conversation history, we need to format the conversation history and + # remove the history field from the signature. + history_field_name = self._get_history_field_name(signature) + if history_field_name: + # In order to format the conversation history, we need to remove the history field from the signature. + signature_without_history = signature.delete(history_field_name) + conversation_history = self.format_conversation_history( + signature_without_history, + history_field_name, + inputs_copy, + ) + + messages = [] + system_message = ( + f"{self.format_field_description(signature)}\n" + f"{self.format_field_structure(signature)}\n" + f"{self.format_task_description(signature)}" + ) + messages.append({"role": "system", "content": system_message}) + messages.extend(self.format_demos(signature, demos)) + if history_field_name: + # Conversation history and current input + content_parts = self.format_user_message_content( + signature_without_history, inputs_copy, main_request=True + ) + messages.extend(conversation_history) + messages.append({"role": "user", "content": content_parts}) + else: + # Only current input + content_parts = self.format_user_message_content( + signature, inputs_copy, main_request=True + ) + messages.append({"role": "user", "content": content_parts}) + + return messages + + def format_user_message_content( + self, + signature: Type[Signature], + inputs: dict[str, Any], + prefix: str = "", + suffix: str = "", + main_request: bool = False, + ) -> list[dict[str, Any]]: + messages = [{"type": "text", "text": prefix}] + for k, v in signature.input_fields.items(): + messages.append( + { + "type": "text", + "text": f"[[ ## {k} ## ]]\n", + } + ) + + if k in inputs: + value = inputs.get(k) + normalized_value = format_field_value(value) + messages.extend(normalized_value) + + if main_request: + output_requirements = self.user_message_output_requirements(signature) + if output_requirements is not None: + messages.append({"type": "text", "text": output_requirements}) + + messages.append({"type": "text", "text": suffix}) + return messages diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py index f2983ef463..f51b384fbb 100644 --- a/dspy/adapters/types/base_type.py +++ b/dspy/adapters/types/base_type.py @@ -29,6 +29,8 @@ def format(self) -> list[dict[str, Any]]: def format(self) -> list[dict[str, Any]]: raise NotImplementedError + # WARN: This is the serialization step that then gets unserialized later + # to support image parts. @pydantic.model_serializer() def serialize_model(self): return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}" @@ -81,6 +83,8 @@ def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> li # Parse the JSON inside the block custom_type_content = match.group(1).strip() try: + # WARN: This occasionally fails to correctly deserialize string to original + # pre-CUSTOM_TYPE_START_IDENTIFIER string. parsed = json_repair.loads(custom_type_content) for custom_type_content in parsed: result.append(custom_type_content) diff --git a/dspy/adapters/types/better_image.py b/dspy/adapters/types/better_image.py new file mode 100644 index 0000000000..380838b925 --- /dev/null +++ b/dspy/adapters/types/better_image.py @@ -0,0 +1,222 @@ +import base64 +import io +import mimetypes +import os +from typing import Any, Union +from urllib.parse import urlparse + +import pydantic +import requests + +from dspy.adapters.types.base_type import BaseType + +try: + from PIL import Image as PILImage + + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + + +class BetterImage(BaseType): + """ + Instead of being serialized inside CUSTOM_TYPE_START_IDENTIFIER + to get later deserialized for LiteLLM, return array of content parts + in format() and have it directly sent to LiteLLM. + + This has the added bonus of setting a precedent type for others + who want to do other formats with N content parts. + """ + + url: str + + model_config = { + "frozen": True, + "str_strip_whitespace": True, + "validate_assignment": True, + "extra": "forbid", + } + + def format(self) -> Union[list[dict[str, Any]], str]: + try: + image_url = encode_image(self.url) + except Exception as e: + raise ValueError(f"Failed to format image for DSPy: {e}") + return [{"type": "image_url", "image_url": {"url": image_url}}] + + # NOTE: We keep it as a dict without stringifying it because our + # pass-through adapter sends the parts as is to LiteLLM. + @pydantic.model_serializer() + def serialize_model(self): + return self.format() + + @pydantic.model_validator(mode="before") + @classmethod + def validate_input(cls, values): + # Allow the model to accept either a URL string or a dictionary with a single 'url' key + if isinstance(values, str): + # if a string, assume it's the URL directly and wrap it in a dict + return {"url": values} + elif isinstance(values, dict) and set(values.keys()) == {"url"}: + # if it's a dict, ensure it has only the 'url' key + return values + elif isinstance(values, cls): + return values.model_dump() + else: + raise TypeError("Expected a string URL or a dictionary with a key 'url'.") + + # If all my inits just call encode_image, should that be in this class + @classmethod + def from_url(cls, url: str, download: bool = False): + return cls(url=encode_image(url, download)) + + @classmethod + def from_file(cls, file_path: str): + return cls(url=encode_image(file_path)) + + @classmethod + def from_PIL(cls, pil_image): # noqa: N802 + return cls(url=encode_image(pil_image)) + + def __str__(self): + return self.serialize_model() + + def __repr__(self): + if "base64" in self.url: + len_base64 = len(self.url.split("base64,")[1]) + image_type = self.url.split(";")[0].split("/")[-1] + return f"Image(url=data:image/{image_type};base64,)" + return f"Image(url='{self.url}')" + + +def is_url(string: str) -> bool: + """Check if a string is a valid URL.""" + try: + result = urlparse(string) + return all([result.scheme in ("http", "https", "gs"), result.netloc]) + except ValueError: + return False + + +def encode_image( + image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False +) -> str: + """ + Encode an image or file to a base64 data URI. + + Args: + image: The image or file to encode. Can be a PIL Image, file path, URL, or data URI. + download_images: Whether to download images from URLs. + + Returns: + str: The data URI of the file or the URL if download_images is False. + + Raises: + ValueError: If the file type is not supported. + """ + if isinstance(image, dict) and "url" in image: + # NOTE: Not doing other validation for now + return image["url"] + elif isinstance(image, str): + if image.startswith("data:"): + # Already a data URI + return image + elif os.path.isfile(image): + # File path + return _encode_image_from_file(image) + elif is_url(image): + # URL + if download_images: + return _encode_image_from_url(image) + else: + # Return the URL as is + return image + else: + # Unsupported string format + print(f"Unsupported file string: {image}") + raise ValueError(f"Unsupported file string: {image}") + elif PIL_AVAILABLE and isinstance(image, PILImage.Image): + # PIL Image + return _encode_pil_image(image) + elif isinstance(image, bytes): + # Raw bytes + if not PIL_AVAILABLE: + raise ImportError("Pillow is required to process image bytes.") + img = PILImage.open(io.BytesIO(image)) + return _encode_pil_image(img) + elif isinstance(image, Image): + return image.url + else: + print(f"Unsupported image type: {type(image)}") + raise ValueError(f"Unsupported image type: {type(image)}") + + +def _encode_image_from_file(file_path: str) -> str: + """Encode a file from a file path to a base64 data URI.""" + with open(file_path, "rb") as file: + file_data = file.read() + + # Use mimetypes to guess directly from the file path + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type is None: + raise ValueError(f"Could not determine MIME type for file: {file_path}") + + encoded_data = base64.b64encode(file_data).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + + +def _encode_image_from_url(image_url: str) -> str: + """Encode a file from a URL to a base64 data URI.""" + response = requests.get(image_url) + response.raise_for_status() + content_type = response.headers.get("Content-Type", "") + + # Use the content type from the response headers if available + if content_type: + mime_type = content_type + else: + # Try to guess MIME type from URL + mime_type, _ = mimetypes.guess_type(image_url) + if mime_type is None: + raise ValueError(f"Could not determine MIME type for URL: {image_url}") + + encoded_data = base64.b64encode(response.content).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + + +def _encode_pil_image(image: "PILImage") -> str: + """Encode a PIL Image object to a base64 data URI.""" + buffered = io.BytesIO() + file_format = image.format or "PNG" + image.save(buffered, format=file_format) + + # Get the correct MIME type using the image format + file_extension = file_format.lower() + mime_type, _ = mimetypes.guess_type(f"file.{file_extension}") + if mime_type is None: + raise ValueError( + f"Could not determine MIME type for image format: {file_format}" + ) + + encoded_data = base64.b64encode(buffered.getvalue()).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + + +def _get_file_extension(path_or_url: str) -> str: + """Extract the file extension from a file path or URL.""" + extension = os.path.splitext(urlparse(path_or_url).path)[1].lstrip(".").lower() + return extension or "png" # Default to 'png' if no extension found + + +def is_image(obj) -> bool: + """Check if the object is an image or a valid media file reference.""" + if PIL_AVAILABLE and isinstance(obj, PILImage.Image): + return True + if isinstance(obj, str): + if obj.startswith("data:"): + return True + elif os.path.isfile(obj): + return True + elif is_url(obj): + return True + return False