Skip to content

RFC: explore sending content parts to LiteLLM instead of always sending strings to be deserialized for custom types #8280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions dspy/adapters/pass_through_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from dspy.adapters.chat_adapter import ChatAdapter
from typing import Any, Dict, NamedTuple, Optional, Type
from dspy.signatures.signature import Signature
from dspy.adapters.utils import (
format_field_value,
get_annotation_name,
get_field_description_string,
parse_value,
translate_field_type,
)
from dspy.adapters.types import BaseType
import itertools


def format_field_value(value) -> list[dict]:
if isinstance(value, str):
return [{"type": "text", "text": value}]
elif isinstance(value, list):
formatted_list = [format_field_value(v) for v in value]
flattened = list(itertools.chain.from_iterable(formatted_list))
return flattened
elif isinstance(value, BaseType) or hasattr(
value, "format"
): # Check if Custom Type
return value.format() # WARN: assumes a list. Dangerous.
else:
return value


class PassThroughChatAdapter(ChatAdapter):
def format(
self,
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
inputs_copy = dict(inputs)

# If the signature and inputs have conversation history, we need to format the conversation history and
# remove the history field from the signature.
history_field_name = self._get_history_field_name(signature)
if history_field_name:
# In order to format the conversation history, we need to remove the history field from the signature.
signature_without_history = signature.delete(history_field_name)
conversation_history = self.format_conversation_history(
signature_without_history,
history_field_name,
inputs_copy,
)

messages = []
system_message = (
f"{self.format_field_description(signature)}\n"
f"{self.format_field_structure(signature)}\n"
f"{self.format_task_description(signature)}"
)
messages.append({"role": "system", "content": system_message})
messages.extend(self.format_demos(signature, demos))
if history_field_name:
# Conversation history and current input
content_parts = self.format_user_message_content(
signature_without_history, inputs_copy, main_request=True
)
messages.extend(conversation_history)
messages.append({"role": "user", "content": content_parts})
else:
# Only current input
content_parts = self.format_user_message_content(
signature, inputs_copy, main_request=True
)
messages.append({"role": "user", "content": content_parts})

return messages

def format_user_message_content(
self,
signature: Type[Signature],
inputs: dict[str, Any],
prefix: str = "",
suffix: str = "",
main_request: bool = False,
) -> list[dict[str, Any]]:
messages = [{"type": "text", "text": prefix}]
for k, v in signature.input_fields.items():
messages.append(
{
"type": "text",
"text": f"[[ ## {k} ## ]]\n",
}
)

if k in inputs:
value = inputs.get(k)
normalized_value = format_field_value(value)
messages.extend(normalized_value)

if main_request:
output_requirements = self.user_message_output_requirements(signature)
if output_requirements is not None:
messages.append({"type": "text", "text": output_requirements})

messages.append({"type": "text", "text": suffix})
return messages
4 changes: 4 additions & 0 deletions dspy/adapters/types/base_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def format(self) -> list[dict[str, Any]]:
def format(self) -> list[dict[str, Any]]:
raise NotImplementedError

# WARN: This is the serialization step that then gets unserialized later
# to support image parts.
@pydantic.model_serializer()
def serialize_model(self):
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
Expand Down Expand Up @@ -81,6 +83,8 @@ def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> li
# Parse the JSON inside the block
custom_type_content = match.group(1).strip()
try:
# WARN: This occasionally fails to correctly deserialize string to original
# pre-CUSTOM_TYPE_START_IDENTIFIER string.
parsed = json_repair.loads(custom_type_content)
for custom_type_content in parsed:
result.append(custom_type_content)
Expand Down
222 changes: 222 additions & 0 deletions dspy/adapters/types/better_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import base64
import io
import mimetypes
import os
from typing import Any, Union
from urllib.parse import urlparse

import pydantic
import requests

from dspy.adapters.types.base_type import BaseType

try:
from PIL import Image as PILImage

PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False


class BetterImage(BaseType):
"""
Instead of being serialized inside CUSTOM_TYPE_START_IDENTIFIER
to get later deserialized for LiteLLM, return array of content parts
in format() and have it directly sent to LiteLLM.

This has the added bonus of setting a precedent type for others
who want to do other formats with N content parts.
"""

url: str

model_config = {
"frozen": True,
"str_strip_whitespace": True,
"validate_assignment": True,
"extra": "forbid",
}

def format(self) -> Union[list[dict[str, Any]], str]:
try:
image_url = encode_image(self.url)
except Exception as e:
raise ValueError(f"Failed to format image for DSPy: {e}")
return [{"type": "image_url", "image_url": {"url": image_url}}]

# NOTE: We keep it as a dict without stringifying it because our
# pass-through adapter sends the parts as is to LiteLLM.
@pydantic.model_serializer()
def serialize_model(self):
return self.format()

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, values):
# Allow the model to accept either a URL string or a dictionary with a single 'url' key
if isinstance(values, str):
# if a string, assume it's the URL directly and wrap it in a dict
return {"url": values}
elif isinstance(values, dict) and set(values.keys()) == {"url"}:
# if it's a dict, ensure it has only the 'url' key
return values
elif isinstance(values, cls):
return values.model_dump()
else:
raise TypeError("Expected a string URL or a dictionary with a key 'url'.")

# If all my inits just call encode_image, should that be in this class
@classmethod
def from_url(cls, url: str, download: bool = False):
return cls(url=encode_image(url, download))

@classmethod
def from_file(cls, file_path: str):
return cls(url=encode_image(file_path))

@classmethod
def from_PIL(cls, pil_image): # noqa: N802
return cls(url=encode_image(pil_image))

def __str__(self):
return self.serialize_model()

def __repr__(self):
if "base64" in self.url:
len_base64 = len(self.url.split("base64,")[1])
image_type = self.url.split(";")[0].split("/")[-1]
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({len_base64!s})>)"
return f"Image(url='{self.url}')"


def is_url(string: str) -> bool:
"""Check if a string is a valid URL."""
try:
result = urlparse(string)
return all([result.scheme in ("http", "https", "gs"), result.netloc])
except ValueError:
return False


def encode_image(
image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False
) -> str:
"""
Encode an image or file to a base64 data URI.

Args:
image: The image or file to encode. Can be a PIL Image, file path, URL, or data URI.
download_images: Whether to download images from URLs.

Returns:
str: The data URI of the file or the URL if download_images is False.

Raises:
ValueError: If the file type is not supported.
"""
if isinstance(image, dict) and "url" in image:
# NOTE: Not doing other validation for now
return image["url"]
elif isinstance(image, str):
if image.startswith("data:"):
# Already a data URI
return image
elif os.path.isfile(image):
# File path
return _encode_image_from_file(image)
elif is_url(image):
# URL
if download_images:
return _encode_image_from_url(image)
else:
# Return the URL as is
return image
else:
# Unsupported string format
print(f"Unsupported file string: {image}")
raise ValueError(f"Unsupported file string: {image}")
elif PIL_AVAILABLE and isinstance(image, PILImage.Image):
# PIL Image
return _encode_pil_image(image)
elif isinstance(image, bytes):
# Raw bytes
if not PIL_AVAILABLE:
raise ImportError("Pillow is required to process image bytes.")
img = PILImage.open(io.BytesIO(image))
return _encode_pil_image(img)
elif isinstance(image, Image):
return image.url
else:
print(f"Unsupported image type: {type(image)}")
raise ValueError(f"Unsupported image type: {type(image)}")


def _encode_image_from_file(file_path: str) -> str:
"""Encode a file from a file path to a base64 data URI."""
with open(file_path, "rb") as file:
file_data = file.read()

# Use mimetypes to guess directly from the file path
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
raise ValueError(f"Could not determine MIME type for file: {file_path}")

encoded_data = base64.b64encode(file_data).decode("utf-8")
return f"data:{mime_type};base64,{encoded_data}"


def _encode_image_from_url(image_url: str) -> str:
"""Encode a file from a URL to a base64 data URI."""
response = requests.get(image_url)
response.raise_for_status()
content_type = response.headers.get("Content-Type", "")

# Use the content type from the response headers if available
if content_type:
mime_type = content_type
else:
# Try to guess MIME type from URL
mime_type, _ = mimetypes.guess_type(image_url)
if mime_type is None:
raise ValueError(f"Could not determine MIME type for URL: {image_url}")

encoded_data = base64.b64encode(response.content).decode("utf-8")
return f"data:{mime_type};base64,{encoded_data}"


def _encode_pil_image(image: "PILImage") -> str:
"""Encode a PIL Image object to a base64 data URI."""
buffered = io.BytesIO()
file_format = image.format or "PNG"
image.save(buffered, format=file_format)

# Get the correct MIME type using the image format
file_extension = file_format.lower()
mime_type, _ = mimetypes.guess_type(f"file.{file_extension}")
if mime_type is None:
raise ValueError(
f"Could not determine MIME type for image format: {file_format}"
)

encoded_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return f"data:{mime_type};base64,{encoded_data}"


def _get_file_extension(path_or_url: str) -> str:
"""Extract the file extension from a file path or URL."""
extension = os.path.splitext(urlparse(path_or_url).path)[1].lstrip(".").lower()
return extension or "png" # Default to 'png' if no extension found


def is_image(obj) -> bool:
"""Check if the object is an image or a valid media file reference."""
if PIL_AVAILABLE and isinstance(obj, PILImage.Image):
return True
if isinstance(obj, str):
if obj.startswith("data:"):
return True
elif os.path.isfile(obj):
return True
elif is_url(obj):
return True
return False
Loading