diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py index 6d7c4fe49..b5ec9cb7f 100644 --- a/adalflow/adalflow/components/model_client/openai_client.py +++ b/adalflow/adalflow/components/model_client/openai_client.py @@ -1,1255 +1,1276 @@ -"""OpenAI ModelClient integration.""" - -import os -import base64 -from typing import ( - Dict, - Sequence, - Optional, - List, - Any, - TypeVar, - Callable, - Generator as GeneratorType, - Union, - Literal, - Iterable, - AsyncIterable, - AsyncGenerator, -) -import re - -import logging -import backoff - -# optional import -from adalflow.utils.lazy_import import safe_import, OptionalPackages - - -openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) - -from openai import ( - OpenAI, - AsyncOpenAI, -) # , Stream # COMMENTED OUT - USING RESPONSE API ONLY -from openai import ( - APITimeoutError, - InternalServerError, - RateLimitError, - UnprocessableEntityError, - BadRequestError, -) -from openai.types import ( - Completion, - CreateEmbeddingResponse, - Image, -) - -# from openai.types.chat import ChatCompletionChunk, ChatCompletion # COMMENTED OUT - USING RESPONSE API ONLY -from openai.types.responses import Response, ResponseUsage -from adalflow.core.model_client import ModelClient -from adalflow.core.types import ( - ModelType, - EmbedderOutput, - ResponseUsage as AdalFlowResponseUsage, - InputTokensDetails, - OutputTokensDetails, - GeneratorOutput, -) -from dataclasses import dataclass - -from adalflow.components.model_client.utils import ( - parse_embedding_response, - format_content_for_response_api, -) - -log = logging.getLogger(__name__) -T = TypeVar("T") - - -@dataclass -class ParsedResponseContent: - """Structured container for parsed response content from OpenAI Response API. - - This dataclass provides a consistent interface for accessing different types - of content that can be returned by the Response API, including text, images, - tool calls, reasoning chains, and more. - - Attributes: - text: The main text content from the response - images: List of image data (base64 or URLs) from image generation - tool_calls: List of other tool call results - reasoning: Reasoning chain from reasoning models - code_outputs: Outputs from code interpreter - raw_output: The original output array for advanced processing - """ - text: Optional[str] = None - images: Optional[Union[str, List[str]]] = None - tool_calls: Optional[List[Dict[str, Any]]] = None - reasoning: Optional[List[Dict[str, Any]]] = None - code_outputs: Optional[List[Dict[str, Any]]] = None - raw_output: Optional[Any] = None - - def __bool__(self) -> bool: - """Check if there's any content.""" - return any([ - self.text, - self.images, - self.tool_calls, - self.reasoning, - self.code_outputs - ]) - - -# OLD CHAT COMPLETION PARSING FUNCTIONS (COMMENTED OUT) -# # completion parsing functions and you can combine them into one single chat completion parser -# def get_first_message_content(completion: ChatCompletion) -> str: -# r"""When we only need the content of the first message. -# It is the default parser for chat completion.""" -# log.debug(f"raw completion: {completion}") -# return completion.choices[0].message.content - - -def get_response_output_text(response: Response) -> str: - """Used to extract the data field for the reasoning model""" - log.debug(f"raw response: {response}") - return response.output_text - - -def parse_response_output(response: Response) -> ParsedResponseContent: - """Parse response output that may include various types of content and tool calls. - - The output array can contain: - - Output messages (with nested content items) - - Tool calls (file search, function, web search, computer use, etc.) - - Reasoning chains - - Image generation calls - - Code interpreter calls - - And more... - - Returns: - ParsedResponseContent: Structured content with typed access to all response data - """ - log.debug(f"raw response from api: {response}") - - content = ParsedResponseContent() - - # Store raw output for advanced users - if hasattr(response, 'output'): - content.raw_output = response.output - - # First try to use output_text if available (SDK convenience property) - if hasattr(response, 'output_text') and response.output_text: - content.text = response.output_text - # Parse the output array manually if no output_text - if hasattr(response, 'output') and response.output: - parsed = _parse_output_array(response.output) - content.text = content.text or parsed.get("text") - content.images = parsed.get("images", []) - content.tool_calls = parsed.get("tool_calls") - content.reasoning = parsed.get("reasoning") - content.code_outputs = parsed.get("code_outputs") - - return content - - - -def _parse_message(item) -> Dict[str, Any]: - """Parse a message item from the output array. - - Args: - item: A message item with type="message" and content array - - Returns: - Dict with parsed text and images from the message - """ - result = {"text": None} - - if hasattr(item, 'content') and isinstance(item.content, list): - # now pick the longer response - text_parts = [] - - for content_item in item.content: - content_type = getattr(content_item, 'type', None) - - if content_type == "output_text": - if hasattr(content_item, 'text'): - text_parts.append(content_item.text) - - if text_parts: - result["text"] = max(text_parts, key=len) if len(text_parts) > 1 else text_parts[0] - - return result - - -def _parse_reasoning(item) -> Dict[str, Any]: - """Parse a reasoning item from the output array. - - Args: - item: A reasoning item with type="reasoning" and summary array - - Returns: - Dict with extracted reasoning text and full structure - """ - result = {"reasoning": None} - - # Extract text from reasoning summary if available - if hasattr(item, 'summary') and isinstance(item.summary, list): - summary_texts = [] - for summary_item in item.summary: - if hasattr(summary_item, 'type') and summary_item.type == "summary_text": - if hasattr(summary_item, 'text'): - summary_texts.append(summary_item.text) - - if summary_texts: - # Store reasoning text separately for later combination - result["reasoning"] = "\n".join(summary_texts) - - return result - - -def _parse_image(item) -> Dict[str, Any]: - """Parse an image generation call item from the output array. - - Args: - item: An image generation item with type="image_generation_call" and result field - - Returns: - Dict with extracted image data - """ - result = {"images": None} - - if hasattr(item, 'result'): - # The result contains the base64 image data or URL - result["images"] = item.result - - return result - - -def _parse_tool_call(item) -> Dict[str, Any]: - """Parse a tool call item from the output array. - - Args: - item: A tool call item (various types ending in _call or containing tool_call) - - Returns: - Dict with tool call information - """ - item_type = getattr(item, 'type', None) - - if item_type == "image_generation_call": - # Handle image generation - extract the result which contains the image data - if hasattr(item, 'result'): - # The result contains the base64 image data or URL - return {"images": item.result} - elif item_type == "code_interpreter_tool_call": - return {"code_outputs": [_serialize_item(item)]} - else: - # Generic tool call - return { - "tool_calls": [{ - "type": item_type, - "content": _serialize_item(item) - }] - } - - return {} - - -def _parse_output_array(output_array) -> Dict[str, Any]: - """Parse the entire output array, processing all elements. - - The output array typically contains: - 1. Reasoning (optional) - thinking/reasoning before the response - 2. Message - the actual response with content - 3. Tool calls (optional) - any tool invocations - - Returns: - Dict with keys: text, images, tool_calls, reasoning, code_outputs - """ - result = { - "text": None, - "images": None, - "tool_calls": None, - "reasoning": None, - "code_outputs": None - } - - if not output_array: - return result - - # Process all items in the array - all_images = [] - all_tool_calls = [] - all_code_outputs = [] - all_reasoning = None - text = None - - for item in output_array: - item_type = getattr(item, 'type', None) - - if item_type == "reasoning": - # Parse reasoning item - parsed = _parse_reasoning(item) - if parsed.get("reasoning"): - all_reasoning = parsed["reasoning"] - - elif item_type == "message": - # Parse message item - parsed = _parse_message(item) - if parsed.get("text"): - text = parsed["text"] - - elif item_type == "image_generation_call": - # Parse image generation call separately - parsed = _parse_image(item) - if parsed.get("images"): - all_images.append(parsed["images"]) - - elif item_type and ('call' in item_type or 'tool' in item_type): - # Parse other tool calls - parsed = _parse_tool_call(item) - if parsed.get("tool_calls"): - all_tool_calls.extend(parsed["tool_calls"]) - if parsed.get("code_outputs"): - all_code_outputs.extend(parsed["code_outputs"]) - - - result["text"] = text if text else None # TODO: they can potentially send multiple complete text messages, we might need to save all of them and only return the first that can convert to outpu parser - - # Set other fields if they have content - result["images"] = all_images - if all_tool_calls: - result["tool_calls"] = all_tool_calls - if all_reasoning: - result["reasoning"] = all_reasoning - if all_code_outputs: - result["code_outputs"] = all_code_outputs - - return result - - -def _serialize_item(item) -> Dict[str, Any]: - """Convert an output item to a serializable dict.""" - result = {} - for attr in dir(item): - if not attr.startswith('_'): - value = getattr(item, attr, None) - if value is not None and not callable(value): - result[attr] = value - return result - - -# def _get_chat_completion_usage(completion: ChatCompletion) -> OpenAICompletionUsage: -# return completion.usage - - -# A simple heuristic to estimate token count for estimating number of tokens in a Streaming response -def estimate_token_count(text: str) -> int: - """ - Estimate the token count of a given text. - - Args: - text (str): The text to estimate token count for. - - Returns: - int: Estimated token count. - """ - # Split the text into tokens using spaces as a simple heuristic - tokens = text.split() - - # Return the number of tokens - return len(tokens) - - -# OLD CHAT COMPLETION STREAMING FUNCTIONS (COMMENTED OUT) -# def parse_stream_chat_completion(completion: ChatCompletionChunk) -> str: -# r"""Parse the completion chunks of the chat completion API.""" -# output = completion.choices[0].delta.content -# if hasattr(completion, "citations"): -# citations = completion.citations -# return output, citations -# return output - - -# def handle_streaming_chat_completion(generator: Stream[ChatCompletionChunk]): -# r"""Handle the streaming completion.""" -# for completion in generator: -# log.debug(f"Raw chunk completion: {completion}") -# parsed_content = parse_stream_chat_completion(completion) -# yield parsed_content - - -async def handle_streaming_response( - stream: AsyncIterable[Any], -) -> AsyncGenerator[str, None]: - """ - Async generator that processes a stream of SSE events from client.responses.create(..., stream=True). - - Args: - stream: An async iterable of SSE events from the OpenAI API - - Yields: - str: Non-empty text fragments parsed from the stream events - """ - async for event in stream: - yield event - - -def handle_streaming_response_sync(stream: Iterable) -> GeneratorType: - """ - Synchronous version: Iterate over an SSE stream from client.responses.create(..., stream=True), - logging each raw event and yielding non-empty text fragments. - """ - # already compatible as this is the OpenAI client - for event in stream: - yield event - - - - -class OpenAIClient(ModelClient): - __doc__ = r"""A component wrapper for the OpenAI API client. - - Support both embedding and response API, including multimodal capabilities. - - - Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client. - (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project. - - Note: - We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API. - We do not know how OpenAI is doing the formating or what prompt they have added. - Instead - - use :ref:`OutputParser` for response parsing and formating. - - For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them. - The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini). - - For image generation, use model_type=ModelType.IMAGE_GENERATION and provide: - - model: "dall-e-3" or "dall-e-2" - - prompt: Text description of the image to generate - - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2 - - quality: "standard" or "hd" (DALL-E 3 only) - - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2) - - response_format: "url" or "b64_json" - - Examples: - Basic text generation:: - - from adalflow.components.model_client import OpenAIClient - from adalflow.core import Generator - - # Initialize client (uses OPENAI_API_KEY env var by default) - client = OpenAIClient() - - # Create a generator for text - generator = Generator( - model_client=client, - model_kwargs={"model": "gpt-4o-mini"} - ) - - # Generate response - response = generator(prompt_kwargs={"input_str": "What is machine learning?"}) - print(response.data) - - Multimodal with URL image:: - - # Vision model with image from URL - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o", - "images": "https://example.com/chart.jpg" - } - ) - - response = generator( - prompt_kwargs={"input_str": "Analyze this chart and explain the trends"} - ) - - Multimodal with local images:: - - # Multiple local images - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o", - "images": [ - "/path/to/image1.jpg", - "/path/to/image2.png" - ] - } - ) - - response = generator( - prompt_kwargs={"input_str": "Compare these two images"} - ) - - Pre-formatted images with custom encoding:: - - import base64 - from adalflow.core.functional import encode_image - - # Option 1: Using the encode_image helper - base64_img = encode_image("/path/to/image.jpg") - - # Option 2: Manual base64 encoding - with open("/path/to/image.png", "rb") as f: - base64_img = base64.b64encode(f.read()).decode('utf-8') - - # Use pre-formatted image data - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o", - "images": [ - # Pre-formatted as base64 data URI - f"data:image/png;base64,{base64_img}", - # Or as a dict with type and image_url - { - "type": "input_image", - "image_url": f"data:image/jpeg;base64,{base64_img}" - }, - # Mix with regular URLs - "https://example.com/chart.jpg" - ] - } - ) - - response = generator( - prompt_kwargs={"input_str": "Analyze these images"} - ) - - Reasoning models (O1, O3):: - - from adalflow.core.types import ModelType - - # O3 reasoning model with effort configuration - generator = Generator( - model_client=OpenAIClient(), - model_type=ModelType.LLM_REASONING, - model_kwargs={ - "model": "o3", - "reasoning": { - "effort": "medium", # low, medium, high - "summary": "auto" # detailed, auto, none - } - } - ) - - response = generator( - prompt_kwargs={"input_str": "Solve this complex problem: ..."} - ) - - Image generation with DALL-E (legacy method):: - - from adalflow.core.types import ModelType - - # Generate an image using ModelType.IMAGE_GENERATION - generator = Generator( - model_client=OpenAIClient(), - model_type=ModelType.IMAGE_GENERATION, - model_kwargs={ - "model": "dall-e-3", - "size": "1024x1792", - "quality": "hd", - "n": 1 - } - ) - - response = generator( - prompt_kwargs={"input_str": "A futuristic city with flying cars at sunset"} - ) - # response.data contains the image URL or base64 data - - Image generation via tools (new API):: - - import base64 - - # Generate images using the new tools API - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o-mini", # or any model that supports tools - "tools": [{"type": "image_generation"}] - } - ) - - # Generate an image - response = generator( - prompt_kwargs={ - "input_str": "Generate an image of a gray tabby cat hugging an otter with an orange scarf" - } - ) - - # Access the generated image(s) - if isinstance(response.data, list): - # Multiple images - for i, img_base64 in enumerate(response.data): - with open(f"generated_{i}.png", "wb") as f: - f.write(base64.b64decode(img_base64)) - elif isinstance(response.data, str): - # Single image - with open("generated.png", "wb") as f: - f.write(base64.b64decode(response.data)) - elif isinstance(response.data, dict) and "images" in response.data: - # Mixed response with text and images - print("Text:", response.data["text"]) - for i, img_base64 in enumerate(response.data["images"]): - with open(f"generated_{i}.png", "wb") as f: - f.write(base64.b64decode(img_base64)) - - Embeddings:: - - from adalflow.core import Embedder - - # Create embedder - embedder = Embedder( - model_client=OpenAIClient(), - model_kwargs={"model": "text-embedding-3-small"} - ) - - # Generate embeddings - embeddings = embedder(input=["Hello world", "Machine learning"]) - print(embeddings.data) # List of embedding vectors - - Streaming responses:: - - from adalflow.components.model_client.utils import extract_text_from_response_stream - - # Enable streaming - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o", - "stream": True - } - ) - - # Stream the response - response = generator(prompt_kwargs={"input_str": "Tell me a story"}) - - # Extract text from Response API streaming events - for event in response.raw_response: - text = extract_text_from_response_stream(event) - if text: - print(text, end="") - - Custom API endpoint:: - - # Use with third-party providers or local models - client = OpenAIClient( - base_url="https://api.custom-provider.com/v1/", - api_key="your-api-key", - headers={"X-Custom-Header": "value"} - ) - - Args: - api_key (Optional[str], optional): OpenAI API key. Defaults to `None`. - non_streaming_chat_completion_parser (Callable[[Completion], Any], optional): Legacy parser for chat completions. - Defaults to `None` (deprecated). - streaming_chat_completion_parser (Callable[[Completion], Any], optional): Legacy parser for streaming chat completions. - Defaults to `None` (deprecated). - non_streaming_response_parser (Callable[[Response], Any], optional): The parser for non-streaming responses. - Defaults to `get_response_output_text`. - streaming_response_parser (Callable[[Response], Any], optional): The parser for streaming responses. - Defaults to `handle_streaming_response`. - input_type (Literal["text", "messages"]): Input type for the client. Defaults to "text". - base_url (str): The API base URL to use when initializing the client. - Defaults to `"https://api.openai.com/v1/"`, but can be customized for third-party API providers or self-hosted models. - env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. - organization (Optional[str], optional): OpenAI organization key. Defaults to None. - headers (Optional[Dict[str, str]], optional): Additional headers to include in API requests. Defaults to None. - - References: - - OpenAI API Overview: https://platform.openai.com/docs/introduction, https://platform.openai.com/docs/guides/images-vision?api-mode=responses - - Embeddings Guide: https://platform.openai.com/docs/guides/embeddings - - Chat Completion Models: https://platform.openai.com/docs/guides/text-generation - - Response api: https://platform.openai.com/docs/api-reference/responses/create, Analyze images and use them as input and/or generate images as output - - Vision Models: https://platform.openai.com/docs/guides/vision - - Image Generation: https://platform.openai.com/docs/guides/images - - reasoning: https://platform.openai.com/docs/guides/reasoning - - Note: - - Ensure each OpenAIClient instance is used by one generator only. - """ - - def __init__( - self, - api_key: Optional[str] = None, - # OLD CHAT COMPLETION PARSER PARAMS (kept for backward compatibility) - non_streaming_chat_completion_parser: Optional[ - Callable[[Completion], Any] - ] = None, # non-streaming parser - deprecated but accepted - streaming_chat_completion_parser: Optional[ - Callable[[Completion], Any] - ] = None, # streaming parser - deprecated but accepted - # Response API parsers (used for reasoning models) - non_streaming_response_parser: Optional[Callable[[Response], Any]] = None, - streaming_response_parser: Optional[Callable[[Response], Any]] = None, - input_type: Literal["text", "messages"] = "text", - base_url: str = "https://api.openai.com/v1/", - env_api_key_name: str = "OPENAI_API_KEY", - organization: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - ): - r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument. - - Args: - api_key (Optional[str], optional): OpenAI API key. Defaults to None. - non_streaming_chat_completion_parser (Optional[Callable[[Completion], Any]], optional): DEPRECATED - Legacy parser for chat completions. Ignored, kept for backward compatibility. Defaults to None. - streaming_chat_completion_parser (Optional[Callable[[Completion], Any]], optional): DEPRECATED - Legacy parser for streaming chat completions. Ignored, kept for backward compatibility. Defaults to None. - non_streaming_response_parser (Optional[Callable[[Response], Any]], optional): Parser for non-streaming responses. Defaults to None. - streaming_response_parser (Optional[Callable[[Response], Any]], optional): Parser for streaming responses. Defaults to None. - input_type (Literal["text", "messages"]): Input type for the client. Defaults to "text". - base_url (str): The API base URL to use when initializing the client. - env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. - organization (Optional[str], optional): OpenAI organization key. Defaults to None. - headers (Optional[Dict[str, str]], optional): Additional headers to include in API requests. Defaults to None. - """ - # Log deprecation warning if old parsers are provided - if non_streaming_chat_completion_parser is not None: - log.warning( - "non_streaming_chat_completion_parser is deprecated and will be ignored. " - "The OpenAI client now uses the Response API exclusively." - ) - if streaming_chat_completion_parser is not None: - log.warning( - "streaming_chat_completion_parser is deprecated and will be ignored. " - "The OpenAI client now uses the Response API exclusively." - ) - - super().__init__() - self._api_key = api_key - self.base_url = base_url - self._env_api_key_name = env_api_key_name - self.organization = organization - self.headers = headers or {} - self.sync_client = self.init_sync_client() - self.async_client = None # only initialize if the async call is called - self._input_type = input_type - self._api_kwargs = {} # add api kwargs when the OpenAI Client is called - - # Response API parsers (RESPONSE API ONLY NOW) - # (used for both synchronous and asynchronous (stream + non-streaming) calls via Response API) - self.non_streaming_response_parser = ( - non_streaming_response_parser or get_response_output_text - ) - # Separate sync and async streaming parsers - self.streaming_response_parser_sync = handle_streaming_response_sync - self.streaming_response_parser_async = ( - streaming_response_parser or handle_streaming_response - ) - - # Default parsers (will be set dynamically based on sync/async context) - self.response_parser = self.non_streaming_response_parser - self.streaming_response_parser = ( - self.streaming_response_parser_async - ) # Default to async - # self.chat_completion_parser = self.non_streaming_chat_completion_parser # COMMENTED OUT - - def init_sync_client(self): - api_key = self._api_key or os.getenv(self._env_api_key_name) - if not api_key: - raise ValueError( - f"Environment variable {self._env_api_key_name} must be set" - ) - return OpenAI( - api_key=api_key, - base_url=self.base_url, - organization=self.organization, - default_headers=self.headers, - ) - - def init_async_client(self): - api_key = self._api_key or os.getenv(self._env_api_key_name) - if not api_key: - raise ValueError( - f"Environment variable {self._env_api_key_name} must be set" - ) - return AsyncOpenAI( - api_key=api_key, - base_url=self.base_url, - organization=self.organization, - default_headers=self.headers, - ) - - # NEW RESPONSE API ONLY FUNCTION - def parse_chat_completion( - self, - completion: Union[Response, AsyncIterable], - ) -> "GeneratorOutput": - """Parse the Response API completion and put it into the raw_response. - Fully migrated to Response API only.""" - - parser = self.response_parser - log.info(f"completion/response: {completion}, parser: {parser}") - - # Check if this is a Response with complex output (tools, images, etc.) - if isinstance(completion, Response): - parsed_content = parse_response_output(completion) - usage = self.track_completion_usage(completion) - - data = parsed_content.text - - thinking = None - if parsed_content.reasoning: - thinking = str(parsed_content.reasoning) - - - return GeneratorOutput( - data=data, # only text - thinking=thinking, - images=parsed_content.images, # List of image data (base64 or URLs) - tool_use=None, # Will be populated when we handle function tool calls - error=None, - raw_response=data, - usage=usage - ) - # Regular response handling (streaming or other) - data = parser(completion) - usage = self.track_completion_usage(completion) - return GeneratorOutput(data=None, error=None, raw_response=data, usage=usage) - - - # NEW RESPONSE API ONLY FUNCTION - def track_completion_usage( - self, - completion: Union[Response, AsyncIterable], - ) -> ResponseUsage: - """Track usage for Response API only.""" - if isinstance(completion, Response): - # Handle Response object with ResponseUsage structure - input_tokens_details = InputTokensDetails( - cached_tokens=getattr(completion.usage, "cached_tokens", 0) - ) - - output_tokens_details = OutputTokensDetails( - reasoning_tokens=getattr(completion.usage, "reasoning_tokens", 0) - ) - - return AdalFlowResponseUsage( - input_tokens=completion.usage.input_tokens, - input_tokens_details=input_tokens_details, - output_tokens=completion.usage.output_tokens, - output_tokens_details=output_tokens_details, - total_tokens=completion.usage.total_tokens, - ) - - # otherwise return the AdalFlowResponseUsage with None values with log warnings - elif hasattr(completion, "__aiter__") or hasattr(completion, "__iter__"): - log.debug( - "Cannot track usage for generator/iterator. Usage tracking should be handled when consuming the stream." - ) - else: - log.debug(f"Unknown completion type: {type(completion)}") - - return AdalFlowResponseUsage( - input_tokens=None, - input_tokens_details=InputTokensDetails(cached_tokens=0), - output_tokens=None, - output_tokens_details=OutputTokensDetails(reasoning_tokens=0), - total_tokens=None, - ) - - def parse_embedding_response( - self, response: CreateEmbeddingResponse - ) -> EmbedderOutput: - r"""Parse the embedding response to a structure Adalflow components can understand. - - Should be called in ``Embedder``. - """ - try: - return parse_embedding_response(response) - except Exception as e: - log.error(f"Error parsing the embedding response: {e}") - return EmbedderOutput(data=[], error=str(e), raw_response=response) - - def _convert_llm_inputs_to_messages( - self, - input: Optional[Any] = None, - images: Optional[Any] = None, - detail: Optional[str] = "auto", - ) -> List[Dict[str, str]]: - # convert input to messages - messages: List[Dict[str, str]] = [] - if self._input_type == "messages": - system_start_tag = "" - system_end_tag = "" - user_start_tag = "" - user_end_tag = "" - - # new regex pattern to ignore special characters such as \n - pattern = ( - rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*" - rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}" - ) - - # Compile the regular expression - regex = re.compile(pattern, re.DOTALL) - - # re.DOTALL is to allow . to match newline so that (.*?) does not match in a single line - regex = re.compile(pattern, re.DOTALL) - # Match the pattern - match = regex.match(input) - system_prompt, input_str = None, None - - if match: - system_prompt = match.group(1) - input_str = match.group(2) - else: - print("No match found.") - if system_prompt and input_str: - messages.append({"role": "system", "content": system_prompt}) - if images: - content = [{"type": "text", "text": input_str}] - if isinstance(images, (str, dict)): - images = [images] - for img in images: - content.append(self._prepare_image_content(img, detail)) - messages.append({"role": "user", "content": content}) - else: - messages.append({"role": "user", "content": input_str}) - if len(messages) == 0: - if images: - content = [{"type": "text", "text": input}] - if isinstance(images, (str, dict)): - images = [images] - for img in images: - content.append(self._prepare_image_content(img, detail)) - messages.append({"role": "user", "content": content}) - else: - messages.append({"role": "system", "content": input}) - return messages - - # adapted for the response api - def convert_inputs_to_api_kwargs( - self, - input: Optional[Any] = None, - model_kwargs: Dict = {}, - model_type: ModelType = ModelType.UNDEFINED, - ) -> Dict: - r""" - Specify the API input type and output api_kwargs that will be used in _call and _acall methods. - Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format. - For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them. - The model specified in model_kwargs["model"] must support multimodal capabilities when using images. - - Args: - input: The input text or messages to process - model_kwargs: Additional parameters including: - - images: Optional image source(s) as path, URL, or list of them - - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto' - - model: The model to use (must support multimodal inputs if images are provided) - model_type: The type of model (EMBEDDER or LLM) - - Returns: - Dict: API-specific kwargs for the model call - """ - - final_model_kwargs = model_kwargs.copy() - if model_type == ModelType.EMBEDDER: - if isinstance(input, str): - input = [input] - # convert input to input - if not isinstance(input, Sequence): - raise TypeError("input must be a sequence of text") - final_model_kwargs["input"] = input - elif model_type == ModelType.LLM or model_type == ModelType.LLM_REASONING: - # Check if images are provided for multimodal input - images = final_model_kwargs.pop("images", None) - - if images: - # Use helper function to format content with images - content = format_content_for_response_api(input, images) - - # For responses.create API, wrap in user message format - final_model_kwargs["input"] = [ - { - "role": "user", - "content": content - } - ] - else: - # Text-only input - final_model_kwargs["input"] = input - else: - raise ValueError(f"model_type {model_type} is not supported") - return final_model_kwargs - - def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput: - """Parse the image generation response into a GeneratorOutput.""" - try: - # Extract URLs or base64 data from the response - data = [img.url or img.b64_json for img in response] - # For single image responses, unwrap from list - if len(data) == 1: - data = data[0] - return GeneratorOutput( - data=data, - raw_response=str(response), - ) - except Exception as e: - log.error(f"Error parsing image generation response: {e}") - return GeneratorOutput(data=None, error=str(e), raw_response=str(response)) - - @backoff.on_exception( - backoff.expo, - ( - APITimeoutError, - InternalServerError, - RateLimitError, - UnprocessableEntityError, - BadRequestError, - ), - max_time=5, - ) - def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): - """ - kwargs is the combined input and model_kwargs. Support streaming call. - For reasoning model, users can add "reasoning" key to the api_kwargs to pass the reasoning config. - eg: - model_kwargs = { - "model": "gpt-4o-reasoning", - "reasoning": { - "effort": "medium", # low, medium, highc - "summary": "auto", #detailed, auto, none - } - } - """ - self._api_kwargs = api_kwargs - if model_type == ModelType.EMBEDDER: - return self.sync_client.embeddings.create(**api_kwargs) - # OLD CHAT COMPLETION CALLS (COMMENTED OUT) - # elif model_type == ModelType.LLM: - # if "stream" in api_kwargs and api_kwargs.get("stream", False): - # log.debug("streaming call") - # self.chat_completion_parser = self.streaming_chat_completion_parser - # return self.sync_client.chat.completions.create(**api_kwargs) - # else: - # log.debug("non-streaming call") - # self.chat_completion_parser = self.non_streaming_chat_completion_parser - # return self.sync_client.chat.completions.create(**api_kwargs) - elif model_type == ModelType.LLM_REASONING or model_type == ModelType.LLM: - if "stream" in api_kwargs and api_kwargs.get("stream", False): - log.debug("streaming call") - self.response_parser = ( - self.streaming_response_parser_sync - ) # Use sync streaming parser - return self.sync_client.responses.create(**api_kwargs) - else: - log.debug("non-streaming call") - self.response_parser = self.non_streaming_response_parser - return self.sync_client.responses.create(**api_kwargs) - - else: - raise ValueError(f"model_type {model_type} is not supported") - - @backoff.on_exception( - backoff.expo, - ( - APITimeoutError, - InternalServerError, - RateLimitError, - UnprocessableEntityError, - BadRequestError, - ), - max_time=5, - ) - async def acall( - self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED - ): - """ - kwargs is the combined input and model_kwargs. Support async streaming call. - - This method now relies on the OpenAI Responses API to handle streaming and non-streaming calls - with the asynchronous client - """ - # store the api kwargs in the client - self._api_kwargs = api_kwargs - if self.async_client is None: - self.async_client = self.init_async_client() - if model_type == ModelType.EMBEDDER: - return await self.async_client.embeddings.create(**api_kwargs) - # old chat completions api calls (commented out) - # elif model_type == ModelType.LLM: - # return await self.async_client.chat.completions.create(**api_kwargs) - # elif model_type == ModelType.LLM_REASONING: - # if "stream" in api_kwargs and api_kwargs.get("stream", False): - # log.debug("async streaming call") - # self.response_parser = self.streaming_response_parser - # # setting response parser as async streaming parser for Response API - # return await self.async_client.responses.create(**api_kwargs) - # else: - # log.debug("async non-streaming call") - # self.response_parser = self.non_streaming_response_parser - # # setting response parser as async non-streaming parser for Response API - # return await self.async_client.responses.create(**api_kwargs) - elif model_type == ModelType.LLM or model_type == ModelType.LLM_REASONING: - if "stream" in api_kwargs and api_kwargs.get("stream", False): - log.debug("async streaming call") - self.response_parser = ( - self.streaming_response_parser_async - ) # Use async streaming parser - # setting response parser as async streaming parser for Response API - return await self.async_client.responses.create(**api_kwargs) - else: - log.debug("async non-streaming call") - self.response_parser = self.non_streaming_response_parser - # setting response parser as async non-streaming parser for Response API - return await self.async_client.responses.create(**api_kwargs) - elif model_type == ModelType.IMAGE_GENERATION: - # Determine which image API to call based on the presence of image/mask - if "image" in api_kwargs: - if "mask" in api_kwargs: - # Image edit - response = await self.async_client.images.edit(**api_kwargs) - else: - # Image variation - response = await self.async_client.images.create_variation( - **api_kwargs - ) - else: - # Image generation - response = await self.async_client.images.generate(**api_kwargs) - return response.data - else: - raise ValueError(f"model_type {model_type} is not supported") - - @classmethod - def from_dict(cls: type[T], data: Dict[str, Any]) -> T: - obj = super().from_dict(data) - # recreate the existing clients - obj.sync_client = obj.init_sync_client() - obj.async_client = obj.init_async_client() - return obj - - def to_dict(self) -> Dict[str, Any]: - r"""Convert the component to a dictionary.""" - # TODO: not exclude but save yes or no for recreating the clients - exclude = [ - "sync_client", - "async_client", - ] # unserializable object - output = super().to_dict(exclude=exclude) - return output - - def _encode_image(self, image_path: str) -> str: - """Encode image to base64 string. - - Args: - image_path: Path to image file. - - Returns: - Base64 encoded image string. - - Raises: - ValueError: If the file cannot be read or doesn't exist. - """ - try: - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") - except FileNotFoundError: - raise ValueError(f"Image file not found: {image_path}") - except PermissionError: - raise ValueError(f"Permission denied when reading image file: {image_path}") - except Exception as e: - raise ValueError(f"Error encoding image {image_path}: {str(e)}") - - def _prepare_image_content( - self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" - ) -> Dict[str, Any]: - """Prepare image content for API request. - - Args: - image_source: Either a path to local image or a URL. - detail: Image detail level ('auto', 'low', or 'high'). - - Returns: - Formatted image content for API request. - """ - if isinstance(image_source, str): - if image_source.startswith(("http://", "https://")): - return { - "type": "image_url", - "image_url": {"url": image_source, "detail": detail}, - } - else: - base64_image = self._encode_image(image_source) - return { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": detail, - }, - } - return image_source - - -# Example usage: -if __name__ == "__main__": - from adalflow.core import Generator - from adalflow.utils import setup_env - - # log = get_logger(level="DEBUG") - - setup_env() - prompt_kwargs = {"input_str": "What is the meaning of life?"} - - gen = Generator( - model_client=OpenAIClient(), - model_kwargs={"model": "gpt-3.5-turbo", "stream": False}, - ) - gen_response = gen(prompt_kwargs) - print(f"gen_response: {gen_response}") - - # for genout in gen_response.data: - # print(f"genout: {genout}") - - # test that to_dict and from_dict works - # model_client = OpenAIClient() - # model_client_dict = model_client.to_dict() - # from_dict_model_client = OpenAIClient.from_dict(model_client_dict) - # assert model_client_dict == from_dict_model_client.to_dict() - - -if __name__ == "__main__": - - def test_openai_llm(): - import adalflow as adal - - # setup env or pass the api_key - from adalflow.utils import setup_env - - setup_env() - - openai_llm = adal.Generator( - model_client=adal.OpenAIClient(), model_kwargs={"model": "gpt-3.5-turbo"} - ) - resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"}) - print(resopnse) - - def test_openai_reasoning(): - import adalflow as adal - - # setup env or pass the api_key - from adalflow.utils import setup_env - - setup_env() - - from adalflow.core.types import ModelType - - openai_llm = adal.Generator( - model_client=adal.OpenAIClient(), - model_type=ModelType.LLM_REASONING, - model_kwargs={ - "model": "o3", - "reasoning": {"effort": "medium", "summary": "auto"}, - }, - ) - - resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"}) - print(resopnse) - - test_openai_reasoning() +"""OpenAI ModelClient integration.""" + +import os +import base64 +from typing import ( + Dict, + Sequence, + Optional, + List, + Any, + TypeVar, + Callable, + Generator as GeneratorType, + Union, + Literal, + Iterable, + AsyncIterable, + AsyncGenerator, +) +import re + +import logging +import backoff + +# optional import +from adalflow.utils.lazy_import import safe_import, OptionalPackages + + +openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) + +from openai import ( + OpenAI, + AsyncOpenAI, +) # , Stream # COMMENTED OUT - USING RESPONSE API ONLY +from openai import ( + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, + BadRequestError, +) +from openai.types import ( + Completion, + CreateEmbeddingResponse, + Image, +) + +# from openai.types.chat import ChatCompletionChunk, ChatCompletion # COMMENTED OUT - USING RESPONSE API ONLY +from openai.types.responses import Response, ResponseUsage +from adalflow.core.model_client import ModelClient +from adalflow.core.types import ( + ModelType, + EmbedderOutput, + ResponseUsage as AdalFlowResponseUsage, + InputTokensDetails, + OutputTokensDetails, + GeneratorOutput, +) +from dataclasses import dataclass + +from adalflow.components.model_client.utils import ( + parse_embedding_response, + format_content_for_response_api, +) + +log = logging.getLogger(__name__) +T = TypeVar("T") + + +@dataclass +class ParsedResponseContent: + """Structured container for parsed response content from OpenAI Response API. + + This dataclass provides a consistent interface for accessing different types + of content that can be returned by the Response API, including text, images, + tool calls, reasoning chains, and more. + + Attributes: + text: The main text content from the response + images: List of image data (base64 or URLs) from image generation + tool_calls: List of other tool call results + reasoning: Reasoning chain from reasoning models + code_outputs: Outputs from code interpreter + raw_output: The original output array for advanced processing + """ + text: Optional[str] = None + images: Optional[Union[str, List[str]]] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + reasoning: Optional[List[Dict[str, Any]]] = None + code_outputs: Optional[List[Dict[str, Any]]] = None + raw_output: Optional[Any] = None + + def __bool__(self) -> bool: + """Check if there's any content.""" + return any([ + self.text, + self.images, + self.tool_calls, + self.reasoning, + self.code_outputs + ]) + + +# OLD CHAT COMPLETION PARSING FUNCTIONS (COMMENTED OUT) +# # completion parsing functions and you can combine them into one single chat completion parser +# def get_first_message_content(completion: ChatCompletion) -> str: +# r"""When we only need the content of the first message. +# It is the default parser for chat completion.""" +# log.debug(f"raw completion: {completion}") +# return completion.choices[0].message.content + + +def get_response_output_text(response: Response) -> str: + """Used to extract the data field for the reasoning model""" + log.debug(f"raw response: {response}") + return response.output_text + + +def parse_response_output(response: Response) -> ParsedResponseContent: + """Parse response output that may include various types of content and tool calls. + + The output array can contain: + - Output messages (with nested content items) + - Tool calls (file search, function, web search, computer use, etc.) + - Reasoning chains + - Image generation calls + - Code interpreter calls + - And more... + + Returns: + ParsedResponseContent: Structured content with typed access to all response data + """ + log.debug(f"raw response from api: {response}") + + content = ParsedResponseContent() + + # Store raw output for advanced users + if hasattr(response, 'output'): + content.raw_output = response.output + + # First try to use output_text if available (SDK convenience property) + if hasattr(response, 'output_text') and response.output_text: + content.text = response.output_text + # Parse the output array manually if no output_text + if hasattr(response, 'output') and response.output: + parsed = _parse_output_array(response.output) + content.text = content.text or parsed.get("text") + content.images = parsed.get("images", []) + content.tool_calls = parsed.get("tool_calls") + content.reasoning = parsed.get("reasoning") + content.code_outputs = parsed.get("code_outputs") + + return content + + + +def _parse_message(item) -> Dict[str, Any]: + """Parse a message item from the output array. + + Args: + item: A message item with type="message" and content array + + Returns: + Dict with parsed text and images from the message + """ + result = {"text": None} + + if hasattr(item, 'content') and isinstance(item.content, list): + # now pick the longer response + text_parts = [] + + for content_item in item.content: + content_type = getattr(content_item, 'type', None) + + if content_type == "output_text": + if hasattr(content_item, 'text'): + text_parts.append(content_item.text) + + if text_parts: + result["text"] = max(text_parts, key=len) if len(text_parts) > 1 else text_parts[0] + + return result + + +def _parse_reasoning(item) -> Dict[str, Any]: + """Parse a reasoning item from the output array. + + Args: + item: A reasoning item with type="reasoning" and summary array + + Returns: + Dict with extracted reasoning text and full structure + """ + result = {"reasoning": None} + + # Extract text from reasoning summary if available + if hasattr(item, 'summary') and isinstance(item.summary, list): + summary_texts = [] + for summary_item in item.summary: + if hasattr(summary_item, 'type') and summary_item.type == "summary_text": + if hasattr(summary_item, 'text'): + summary_texts.append(summary_item.text) + + if summary_texts: + # Store reasoning text separately for later combination + result["reasoning"] = "\n".join(summary_texts) + + return result + + +def _parse_image(item) -> Dict[str, Any]: + """Parse an image generation call item from the output array. + + Args: + item: An image generation item with type="image_generation_call" and result field + + Returns: + Dict with extracted image data + """ + result = {"images": None} + + if hasattr(item, 'result'): + # The result contains the base64 image data or URL + result["images"] = item.result + + return result + + +def _parse_tool_call(item) -> Dict[str, Any]: + """Parse a tool call item from the output array. + + Args: + item: A tool call item (various types ending in _call or containing tool_call) + + Returns: + Dict with tool call information + """ + item_type = getattr(item, 'type', None) + + if item_type == "image_generation_call": + # Handle image generation - extract the result which contains the image data + if hasattr(item, 'result'): + # The result contains the base64 image data or URL + return {"images": item.result} + elif item_type == "code_interpreter_tool_call": + return {"code_outputs": [_serialize_item(item)]} + else: + # Generic tool call + return { + "tool_calls": [{ + "type": item_type, + "content": _serialize_item(item) + }] + } + + return {} + + +def _parse_output_array(output_array) -> Dict[str, Any]: + """Parse the entire output array, processing all elements. + + The output array typically contains: + 1. Reasoning (optional) - thinking/reasoning before the response + 2. Message - the actual response with content + 3. Tool calls (optional) - any tool invocations + + Returns: + Dict with keys: text, images, tool_calls, reasoning, code_outputs + """ + result = { + "text": None, + "images": None, + "tool_calls": None, + "reasoning": None, + "code_outputs": None + } + + if not output_array: + return result + + # Process all items in the array + all_images = [] + all_tool_calls = [] + all_code_outputs = [] + all_reasoning = None + text = None + + for item in output_array: + item_type = getattr(item, 'type', None) + + if item_type == "reasoning": + # Parse reasoning item + parsed = _parse_reasoning(item) + if parsed.get("reasoning"): + all_reasoning = parsed["reasoning"] + + elif item_type == "message": + # Parse message item + parsed = _parse_message(item) + if parsed.get("text"): + text = parsed["text"] + + elif item_type == "image_generation_call": + # Parse image generation call separately + parsed = _parse_image(item) + if parsed.get("images"): + all_images.append(parsed["images"]) + + elif item_type and ('call' in item_type or 'tool' in item_type): + # Parse other tool calls + parsed = _parse_tool_call(item) + if parsed.get("tool_calls"): + all_tool_calls.extend(parsed["tool_calls"]) + if parsed.get("code_outputs"): + all_code_outputs.extend(parsed["code_outputs"]) + + + result["text"] = text if text else None # TODO: they can potentially send multiple complete text messages, we might need to save all of them and only return the first that can convert to outpu parser + + # Set other fields if they have content + result["images"] = all_images + if all_tool_calls: + result["tool_calls"] = all_tool_calls + if all_reasoning: + result["reasoning"] = all_reasoning + if all_code_outputs: + result["code_outputs"] = all_code_outputs + + return result + + +def _serialize_item(item) -> Dict[str, Any]: + """Convert an output item to a serializable dict.""" + result = {} + for attr in dir(item): + if not attr.startswith('_'): + value = getattr(item, attr, None) + if value is not None and not callable(value): + result[attr] = value + return result + + +# def _get_chat_completion_usage(completion: ChatCompletion) -> OpenAICompletionUsage: +# return completion.usage + + +# A simple heuristic to estimate token count for estimating number of tokens in a Streaming response +def estimate_token_count(text: str) -> int: + """ + Estimate the token count of a given text. + + Args: + text (str): The text to estimate token count for. + + Returns: + int: Estimated token count. + """ + # Split the text into tokens using spaces as a simple heuristic + tokens = text.split() + + # Return the number of tokens + return len(tokens) + + +# OLD CHAT COMPLETION STREAMING FUNCTIONS (COMMENTED OUT) +# def parse_stream_chat_completion(completion: ChatCompletionChunk) -> str: +# r"""Parse the completion chunks of the chat completion API.""" +# output = completion.choices[0].delta.content +# if hasattr(completion, "citations"): +# citations = completion.citations +# return output, citations +# return output + + +# def handle_streaming_chat_completion(generator: Stream[ChatCompletionChunk]): +# r"""Handle the streaming completion.""" +# for completion in generator: +# log.debug(f"Raw chunk completion: {completion}") +# parsed_content = parse_stream_chat_completion(completion) +# yield parsed_content + + +async def handle_streaming_response( + stream: AsyncIterable[Any], +) -> AsyncGenerator[str, None]: + """ + Async generator that processes a stream of SSE events from client.responses.create(..., stream=True). + + Args: + stream: An async iterable of SSE events from the OpenAI API + + Yields: + str: Non-empty text fragments parsed from the stream events + """ + async for event in stream: + yield event + + +def handle_streaming_response_sync(stream: Iterable) -> GeneratorType: + """ + Synchronous version: Iterate over an SSE stream from client.responses.create(..., stream=True), + logging each raw event and yielding non-empty text fragments. + """ + # already compatible as this is the OpenAI client + for event in stream: + yield event + + + + +class OpenAIClient(ModelClient): + __doc__ = r"""A component wrapper for the OpenAI API client. + + Support both embedding and response API, including multimodal capabilities. + + + Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client. + (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project. + + Note: + We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API. + We do not know how OpenAI is doing the formating or what prompt they have added. + Instead + - use :ref:`OutputParser` for response parsing and formating. + + For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them. + The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini). + + For image generation, use model_type=ModelType.IMAGE_GENERATION and provide: + - model: "dall-e-3" or "dall-e-2" + - prompt: Text description of the image to generate + - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2 + - quality: "standard" or "hd" (DALL-E 3 only) + - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2) + - response_format: "url" or "b64_json" + + Examples: + Basic text generation:: + + from adalflow.components.model_client import OpenAIClient + from adalflow.core import Generator + + # Initialize client (uses OPENAI_API_KEY env var by default) + client = OpenAIClient() + + # Create a generator for text + generator = Generator( + model_client=client, + model_kwargs={"model": "gpt-4o-mini"} + ) + + # Generate response + response = generator(prompt_kwargs={"input_str": "What is machine learning?"}) + print(response.data) + + Multimodal with URL image:: + + # Vision model with image from URL + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "images": "https://example.com/chart.jpg" + } + ) + + response = generator( + prompt_kwargs={"input_str": "Analyze this chart and explain the trends"} + ) + + Multimodal with local images:: + + # Multiple local images + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "images": [ + "/path/to/image1.jpg", + "/path/to/image2.png" + ] + } + ) + + response = generator( + prompt_kwargs={"input_str": "Compare these two images"} + ) + + Pre-formatted images with custom encoding:: + + import base64 + from adalflow.core.functional import encode_image + + # Option 1: Using the encode_image helper + base64_img = encode_image("/path/to/image.jpg") + + # Option 2: Manual base64 encoding + with open("/path/to/image.png", "rb") as f: + base64_img = base64.b64encode(f.read()).decode('utf-8') + + # Use pre-formatted image data + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "images": [ + # Pre-formatted as base64 data URI + f"data:image/png;base64,{base64_img}", + # Or as a dict with type and image_url + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_img}" + }, + # Mix with regular URLs + "https://example.com/chart.jpg" + ] + } + ) + + response = generator( + prompt_kwargs={"input_str": "Analyze these images"} + ) + + Reasoning models (O1, O3):: + + from adalflow.core.types import ModelType + + # O3 reasoning model with effort configuration + generator = Generator( + model_client=OpenAIClient(), + model_type=ModelType.LLM_REASONING, + model_kwargs={ + "model": "o3", + "reasoning": { + "effort": "medium", # low, medium, high + "summary": "auto" # detailed, auto, none + } + } + ) + + response = generator( + prompt_kwargs={"input_str": "Solve this complex problem: ..."} + ) + + Image generation with DALL-E (legacy method):: + + from adalflow.core.types import ModelType + + # Generate an image using ModelType.IMAGE_GENERATION + generator = Generator( + model_client=OpenAIClient(), + model_type=ModelType.IMAGE_GENERATION, + model_kwargs={ + "model": "dall-e-3", + "size": "1024x1792", + "quality": "hd", + "n": 1 + } + ) + + response = generator( + prompt_kwargs={"input_str": "A futuristic city with flying cars at sunset"} + ) + # response.data contains the image URL or base64 data + + Image generation via tools (new API):: + + import base64 + + # Generate images using the new tools API + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o-mini", # or any model that supports tools + "tools": [{"type": "image_generation"}] + } + ) + + # Generate an image + response = generator( + prompt_kwargs={ + "input_str": "Generate an image of a gray tabby cat hugging an otter with an orange scarf" + } + ) + + # Access the generated image(s) + if isinstance(response.data, list): + # Multiple images + for i, img_base64 in enumerate(response.data): + with open(f"generated_{i}.png", "wb") as f: + f.write(base64.b64decode(img_base64)) + elif isinstance(response.data, str): + # Single image + with open("generated.png", "wb") as f: + f.write(base64.b64decode(response.data)) + elif isinstance(response.data, dict) and "images" in response.data: + # Mixed response with text and images + print("Text:", response.data["text"]) + for i, img_base64 in enumerate(response.data["images"]): + with open(f"generated_{i}.png", "wb") as f: + f.write(base64.b64decode(img_base64)) + + Embeddings:: + + from adalflow.core import Embedder + + # Create embedder + embedder = Embedder( + model_client=OpenAIClient(), + model_kwargs={"model": "text-embedding-3-small"} + ) + + # Generate embeddings + embeddings = embedder(input=["Hello world", "Machine learning"]) + print(embeddings.data) # List of embedding vectors + + Streaming responses:: + + from adalflow.components.model_client.utils import extract_text_from_response_stream + + # Enable streaming + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "stream": True + } + ) + + # Stream the response + response = generator(prompt_kwargs={"input_str": "Tell me a story"}) + + # Extract text from Response API streaming events + for event in response.raw_response: + text = extract_text_from_response_stream(event) + if text: + print(text, end="") + + Custom API endpoint:: + + # Use with third-party providers or local models + client = OpenAIClient( + base_url="https://api.custom-provider.com/v1/", + api_key="your-api-key", + headers={"X-Custom-Header": "value"} + ) + + Args: + api_key (Optional[str], optional): OpenAI API key. Defaults to `None`. + non_streaming_chat_completion_parser (Callable[[Completion], Any], optional): Legacy parser for chat completions. + Defaults to `None` (deprecated). + streaming_chat_completion_parser (Callable[[Completion], Any], optional): Legacy parser for streaming chat completions. + Defaults to `None` (deprecated). + non_streaming_response_parser (Callable[[Response], Any], optional): The parser for non-streaming responses. + Defaults to `get_response_output_text`. + streaming_response_parser (Callable[[Response], Any], optional): The parser for streaming responses. + Defaults to `handle_streaming_response`. + input_type (Literal["text", "messages"]): Input type for the client. Defaults to "text". + base_url (str): The API base URL to use when initializing the client. + Defaults to `"https://api.openai.com/v1/"`, but can be customized for third-party API providers or self-hosted models. + env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. + organization (Optional[str], optional): OpenAI organization key. Defaults to None. + headers (Optional[Dict[str, str]], optional): Additional headers to include in API requests. Defaults to None. + + References: + - OpenAI API Overview: https://platform.openai.com/docs/introduction, https://platform.openai.com/docs/guides/images-vision?api-mode=responses + - Embeddings Guide: https://platform.openai.com/docs/guides/embeddings + - Chat Completion Models: https://platform.openai.com/docs/guides/text-generation + - Response api: https://platform.openai.com/docs/api-reference/responses/create, Analyze images and use them as input and/or generate images as output + - Vision Models: https://platform.openai.com/docs/guides/vision + - Image Generation: https://platform.openai.com/docs/guides/images + - reasoning: https://platform.openai.com/docs/guides/reasoning + + Note: + - Ensure each OpenAIClient instance is used by one generator only. + """ + + def __init__( + self, + api_key: Optional[str] = None, + # OLD CHAT COMPLETION PARSER PARAMS (kept for backward compatibility) + non_streaming_chat_completion_parser: Optional[ + Callable[[Completion], Any] + ] = None, # non-streaming parser - deprecated but accepted + streaming_chat_completion_parser: Optional[ + Callable[[Completion], Any] + ] = None, # streaming parser - deprecated but accepted + # Response API parsers (used for reasoning models) + non_streaming_response_parser: Optional[Callable[[Response], Any]] = None, + streaming_response_parser: Optional[Callable[[Response], Any]] = None, + input_type: Literal["text", "messages"] = "text", + base_url: str = "https://api.openai.com/v1/", + env_api_key_name: str = "OPENAI_API_KEY", + organization: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): + r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument. + + Args: + api_key (Optional[str], optional): OpenAI API key. Defaults to None. + non_streaming_chat_completion_parser (Optional[Callable[[Completion], Any]], optional): DEPRECATED - Legacy parser for chat completions. Ignored, kept for backward compatibility. Defaults to None. + streaming_chat_completion_parser (Optional[Callable[[Completion], Any]], optional): DEPRECATED - Legacy parser for streaming chat completions. Ignored, kept for backward compatibility. Defaults to None. + non_streaming_response_parser (Optional[Callable[[Response], Any]], optional): Parser for non-streaming responses. Defaults to None. + streaming_response_parser (Optional[Callable[[Response], Any]], optional): Parser for streaming responses. Defaults to None. + input_type (Literal["text", "messages"]): Input type for the client. Defaults to "text". + base_url (str): The API base URL to use when initializing the client. + env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. + organization (Optional[str], optional): OpenAI organization key. Defaults to None. + headers (Optional[Dict[str, str]], optional): Additional headers to include in API requests. Defaults to None. + """ + # Log deprecation warning if old parsers are provided + if non_streaming_chat_completion_parser is not None: + log.warning( + "non_streaming_chat_completion_parser is deprecated and will be ignored. " + "The OpenAI client now uses the Response API exclusively." + ) + if streaming_chat_completion_parser is not None: + log.warning( + "streaming_chat_completion_parser is deprecated and will be ignored. " + "The OpenAI client now uses the Response API exclusively." + ) + + super().__init__() + self._api_key = api_key + self.base_url = base_url + self._env_api_key_name = env_api_key_name + self.organization = organization + self.headers = headers or {} + self.sync_client = self.init_sync_client() + self.async_client = None # only initialize if the async call is called + self._input_type = input_type + self._api_kwargs = {} # add api kwargs when the OpenAI Client is called + + # Response API parsers (RESPONSE API ONLY NOW) + # (used for both synchronous and asynchronous (stream + non-streaming) calls via Response API) + self.non_streaming_response_parser = ( + non_streaming_response_parser or get_response_output_text + ) + # Separate sync and async streaming parsers + self.streaming_response_parser_sync = handle_streaming_response_sync + self.streaming_response_parser_async = ( + streaming_response_parser or handle_streaming_response + ) + + # Default parsers (will be set dynamically based on sync/async context) + self.response_parser = self.non_streaming_response_parser + self.streaming_response_parser = ( + self.streaming_response_parser_async + ) # Default to async + # self.chat_completion_parser = self.non_streaming_chat_completion_parser # COMMENTED OUT + + def init_sync_client(self): + api_key = self._api_key or os.getenv(self._env_api_key_name) + if not api_key: + raise ValueError( + f"Environment variable {self._env_api_key_name} must be set" + ) + return OpenAI( + api_key=api_key, + base_url=self.base_url, + organization=self.organization, + default_headers=self.headers, + ) + + def init_async_client(self): + api_key = self._api_key or os.getenv(self._env_api_key_name) + if not api_key: + raise ValueError( + f"Environment variable {self._env_api_key_name} must be set" + ) + return AsyncOpenAI( + api_key=api_key, + base_url=self.base_url, + organization=self.organization, + default_headers=self.headers, + ) + + # NEW RESPONSE API ONLY FUNCTION + def parse_chat_completion( + self, + completion: Union[Response, AsyncIterable], + ) -> "GeneratorOutput": + """Parse the Response API completion and put it into the raw_response. + Fully migrated to Response API only.""" + + parser = self.response_parser + log.info(f"completion/response: {completion}, parser: {parser}") + + # Check if this is a Response with complex output (tools, images, etc.) + if isinstance(completion, Response): + parsed_content = parse_response_output(completion) + usage = self.track_completion_usage(completion) + + data = parsed_content.text + + thinking = None + if parsed_content.reasoning: + thinking = str(parsed_content.reasoning) + + + return GeneratorOutput( + data=data, # only text + thinking=thinking, + images=parsed_content.images, # List of image data (base64 or URLs) + tool_use=None, # Will be populated when we handle function tool calls + error=None, + raw_response=data, + usage=usage + ) + # Regular response handling (streaming or other) + data = parser(completion) + usage = self.track_completion_usage(completion) + return GeneratorOutput(data=None, error=None, raw_response=data, usage=usage) + + + # NEW RESPONSE API ONLY FUNCTION + def track_completion_usage( + self, + completion: Union[Response, AsyncIterable], + ) -> ResponseUsage: + """Track usage for Response API only.""" + if isinstance(completion, Response): + # Handle Response object with ResponseUsage structure + input_tokens_details = InputTokensDetails( + cached_tokens=getattr(completion.usage, "cached_tokens", 0) + ) + + output_tokens_details = OutputTokensDetails( + reasoning_tokens=getattr(completion.usage, "reasoning_tokens", 0) + ) + + return AdalFlowResponseUsage( + input_tokens=completion.usage.input_tokens, + input_tokens_details=input_tokens_details, + output_tokens=completion.usage.output_tokens, + output_tokens_details=output_tokens_details, + total_tokens=completion.usage.total_tokens, + ) + + # otherwise return the AdalFlowResponseUsage with None values with log warnings + elif hasattr(completion, "__aiter__") or hasattr(completion, "__iter__"): + log.debug( + "Cannot track usage for generator/iterator. Usage tracking should be handled when consuming the stream." + ) + else: + log.debug(f"Unknown completion type: {type(completion)}") + + return AdalFlowResponseUsage( + input_tokens=None, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens=None, + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + total_tokens=None, + ) + + def parse_embedding_response( + self, response: CreateEmbeddingResponse + ) -> EmbedderOutput: + r"""Parse the embedding response to a structure Adalflow components can understand. + + Should be called in ``Embedder``. + """ + try: + return parse_embedding_response(response) + except Exception as e: + log.error(f"Error parsing the embedding response: {e}") + return EmbedderOutput(data=[], error=str(e), raw_response=response) + + def _convert_llm_inputs_to_messages( + self, + input: Optional[Any] = None, + images: Optional[Any] = None, + detail: Optional[str] = "auto", + ) -> List[Dict[str, str]]: + # convert input to messages + messages: List[Dict[str, str]] = [] + if self._input_type == "messages": + system_start_tag = "" + system_end_tag = "" + user_start_tag = "" + user_end_tag = "" + + # new regex pattern to ignore special characters such as \n + pattern = ( + rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*" + rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}" + ) + + # Compile the regular expression + regex = re.compile(pattern, re.DOTALL) + + # re.DOTALL is to allow . to match newline so that (.*?) does not match in a single line + regex = re.compile(pattern, re.DOTALL) + # Match the pattern + match = regex.match(input) + system_prompt, input_str = None, None + + if match: + system_prompt = match.group(1) + input_str = match.group(2) + else: + print("No match found.") + if system_prompt and input_str: + messages.append({"role": "system", "content": system_prompt}) + if images: + content = [{"type": "text", "text": input_str}] + if isinstance(images, (str, dict)): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + messages.append({"role": "user", "content": content}) + else: + messages.append({"role": "user", "content": input_str}) + if len(messages) == 0: + if images: + content = [{"type": "text", "text": input}] + if isinstance(images, (str, dict)): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + messages.append({"role": "user", "content": content}) + else: + messages.append({"role": "system", "content": input}) + return messages + + # adapted for the response api + def convert_inputs_to_api_kwargs( + self, + input: Optional[Any] = None, + model_kwargs: Dict = {}, + model_type: ModelType = ModelType.UNDEFINED, + ) -> Dict: + r""" + Specify the API input type and output api_kwargs that will be used in _call and _acall methods. + Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format. + For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them. + The model specified in model_kwargs["model"] must support multimodal capabilities when using images. + + Args: + input: The input text or messages to process + model_kwargs: Additional parameters including: + - images: Optional image source(s) as path, URL, or list of them + - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto' + - model: The model to use (must support multimodal inputs if images are provided) + model_type: The type of model (EMBEDDER or LLM) + + Returns: + Dict: API-specific kwargs for the model call + """ + + final_model_kwargs = model_kwargs.copy() + if model_type == ModelType.EMBEDDER: + if isinstance(input, str): + input = [input] + # convert input to input + if not isinstance(input, Sequence): + raise TypeError("input must be a sequence of text") + final_model_kwargs["input"] = input + elif model_type == ModelType.LLM or model_type == ModelType.LLM_REASONING: + # Check if images are provided for multimodal input + images = final_model_kwargs.pop("images", None) + + if images: + # Use helper function to format content with images + content = format_content_for_response_api(input, images) + + # For responses.create API, wrap in user message format + final_model_kwargs["input"] = [ + { + "role": "user", + "content": content + } + ] + else: + # Text-only input + final_model_kwargs["input"] = input + + # Reasoning models (o1, o3-mini, etc.) do not support certain + # Chat Completion parameters in the Responses API. + # Strip them to avoid BadRequestError / unexpected keyword argument. + if model_type == ModelType.LLM_REASONING: + _REASONING_UNSUPPORTED_PARAMS = { + "frequency_penalty", + "presence_penalty", + "temperature", + "top_p", + } + removed = { + k: final_model_kwargs.pop(k) + for k in _REASONING_UNSUPPORTED_PARAMS + if k in final_model_kwargs + } + if removed: + log.warning( + f"Reasoning model does not support {sorted(removed.keys())}; " + f"removed from api_kwargs. Model: {final_model_kwargs.get('model', 'unknown')}" + ) + else: + raise ValueError(f"model_type {model_type} is not supported") + return final_model_kwargs + + def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput: + """Parse the image generation response into a GeneratorOutput.""" + try: + # Extract URLs or base64 data from the response + data = [img.url or img.b64_json for img in response] + # For single image responses, unwrap from list + if len(data) == 1: + data = data[0] + return GeneratorOutput( + data=data, + raw_response=str(response), + ) + except Exception as e: + log.error(f"Error parsing image generation response: {e}") + return GeneratorOutput(data=None, error=str(e), raw_response=str(response)) + + @backoff.on_exception( + backoff.expo, + ( + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, + BadRequestError, + ), + max_time=5, + ) + def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + """ + kwargs is the combined input and model_kwargs. Support streaming call. + For reasoning model, users can add "reasoning" key to the api_kwargs to pass the reasoning config. + eg: + model_kwargs = { + "model": "gpt-4o-reasoning", + "reasoning": { + "effort": "medium", # low, medium, highc + "summary": "auto", #detailed, auto, none + } + } + """ + self._api_kwargs = api_kwargs + if model_type == ModelType.EMBEDDER: + return self.sync_client.embeddings.create(**api_kwargs) + # OLD CHAT COMPLETION CALLS (COMMENTED OUT) + # elif model_type == ModelType.LLM: + # if "stream" in api_kwargs and api_kwargs.get("stream", False): + # log.debug("streaming call") + # self.chat_completion_parser = self.streaming_chat_completion_parser + # return self.sync_client.chat.completions.create(**api_kwargs) + # else: + # log.debug("non-streaming call") + # self.chat_completion_parser = self.non_streaming_chat_completion_parser + # return self.sync_client.chat.completions.create(**api_kwargs) + elif model_type == ModelType.LLM_REASONING or model_type == ModelType.LLM: + if "stream" in api_kwargs and api_kwargs.get("stream", False): + log.debug("streaming call") + self.response_parser = ( + self.streaming_response_parser_sync + ) # Use sync streaming parser + return self.sync_client.responses.create(**api_kwargs) + else: + log.debug("non-streaming call") + self.response_parser = self.non_streaming_response_parser + return self.sync_client.responses.create(**api_kwargs) + + else: + raise ValueError(f"model_type {model_type} is not supported") + + @backoff.on_exception( + backoff.expo, + ( + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, + BadRequestError, + ), + max_time=5, + ) + async def acall( + self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED + ): + """ + kwargs is the combined input and model_kwargs. Support async streaming call. + + This method now relies on the OpenAI Responses API to handle streaming and non-streaming calls + with the asynchronous client + """ + # store the api kwargs in the client + self._api_kwargs = api_kwargs + if self.async_client is None: + self.async_client = self.init_async_client() + if model_type == ModelType.EMBEDDER: + return await self.async_client.embeddings.create(**api_kwargs) + # old chat completions api calls (commented out) + # elif model_type == ModelType.LLM: + # return await self.async_client.chat.completions.create(**api_kwargs) + # elif model_type == ModelType.LLM_REASONING: + # if "stream" in api_kwargs and api_kwargs.get("stream", False): + # log.debug("async streaming call") + # self.response_parser = self.streaming_response_parser + # # setting response parser as async streaming parser for Response API + # return await self.async_client.responses.create(**api_kwargs) + # else: + # log.debug("async non-streaming call") + # self.response_parser = self.non_streaming_response_parser + # # setting response parser as async non-streaming parser for Response API + # return await self.async_client.responses.create(**api_kwargs) + elif model_type == ModelType.LLM or model_type == ModelType.LLM_REASONING: + if "stream" in api_kwargs and api_kwargs.get("stream", False): + log.debug("async streaming call") + self.response_parser = ( + self.streaming_response_parser_async + ) # Use async streaming parser + # setting response parser as async streaming parser for Response API + return await self.async_client.responses.create(**api_kwargs) + else: + log.debug("async non-streaming call") + self.response_parser = self.non_streaming_response_parser + # setting response parser as async non-streaming parser for Response API + return await self.async_client.responses.create(**api_kwargs) + elif model_type == ModelType.IMAGE_GENERATION: + # Determine which image API to call based on the presence of image/mask + if "image" in api_kwargs: + if "mask" in api_kwargs: + # Image edit + response = await self.async_client.images.edit(**api_kwargs) + else: + # Image variation + response = await self.async_client.images.create_variation( + **api_kwargs + ) + else: + # Image generation + response = await self.async_client.images.generate(**api_kwargs) + return response.data + else: + raise ValueError(f"model_type {model_type} is not supported") + + @classmethod + def from_dict(cls: type[T], data: Dict[str, Any]) -> T: + obj = super().from_dict(data) + # recreate the existing clients + obj.sync_client = obj.init_sync_client() + obj.async_client = obj.init_async_client() + return obj + + def to_dict(self) -> Dict[str, Any]: + r"""Convert the component to a dictionary.""" + # TODO: not exclude but save yes or no for recreating the clients + exclude = [ + "sync_client", + "async_client", + ] # unserializable object + output = super().to_dict(exclude=exclude) + return output + + def _encode_image(self, image_path: str) -> str: + """Encode image to base64 string. + + Args: + image_path: Path to image file. + + Returns: + Base64 encoded image string. + + Raises: + ValueError: If the file cannot be read or doesn't exist. + """ + try: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + except FileNotFoundError: + raise ValueError(f"Image file not found: {image_path}") + except PermissionError: + raise ValueError(f"Permission denied when reading image file: {image_path}") + except Exception as e: + raise ValueError(f"Error encoding image {image_path}: {str(e)}") + + def _prepare_image_content( + self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" + ) -> Dict[str, Any]: + """Prepare image content for API request. + + Args: + image_source: Either a path to local image or a URL. + detail: Image detail level ('auto', 'low', or 'high'). + + Returns: + Formatted image content for API request. + """ + if isinstance(image_source, str): + if image_source.startswith(("http://", "https://")): + return { + "type": "image_url", + "image_url": {"url": image_source, "detail": detail}, + } + else: + base64_image = self._encode_image(image_source) + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": detail, + }, + } + return image_source + + +# Example usage: +if __name__ == "__main__": + from adalflow.core import Generator + from adalflow.utils import setup_env + + # log = get_logger(level="DEBUG") + + setup_env() + prompt_kwargs = {"input_str": "What is the meaning of life?"} + + gen = Generator( + model_client=OpenAIClient(), + model_kwargs={"model": "gpt-3.5-turbo", "stream": False}, + ) + gen_response = gen(prompt_kwargs) + print(f"gen_response: {gen_response}") + + # for genout in gen_response.data: + # print(f"genout: {genout}") + + # test that to_dict and from_dict works + # model_client = OpenAIClient() + # model_client_dict = model_client.to_dict() + # from_dict_model_client = OpenAIClient.from_dict(model_client_dict) + # assert model_client_dict == from_dict_model_client.to_dict() + + +if __name__ == "__main__": + + def test_openai_llm(): + import adalflow as adal + + # setup env or pass the api_key + from adalflow.utils import setup_env + + setup_env() + + openai_llm = adal.Generator( + model_client=adal.OpenAIClient(), model_kwargs={"model": "gpt-3.5-turbo"} + ) + resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"}) + print(resopnse) + + def test_openai_reasoning(): + import adalflow as adal + + # setup env or pass the api_key + from adalflow.utils import setup_env + + setup_env() + + from adalflow.core.types import ModelType + + openai_llm = adal.Generator( + model_client=adal.OpenAIClient(), + model_type=ModelType.LLM_REASONING, + model_kwargs={ + "model": "o3", + "reasoning": {"effort": "medium", "summary": "auto"}, + }, + ) + + resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"}) + print(resopnse) + + test_openai_reasoning() diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index 0384b3a72..8b56896fd 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -1,1593 +1,1603 @@ -"""Generator is a user-facing orchestration component with a simple and unified interface for LLM prediction. - -It is a pipeline that consists of three subcomponents.""" - -import json -import re -import time -from pathlib import Path - -from typing import Any, Dict, Optional, Union, Callable, Tuple, List, AsyncGenerator -from collections.abc import AsyncIterable -import logging -from dataclasses import dataclass, field - -from openai.types.responses import ResponseCompletedEvent -from adalflow.core.tokenizer import Tokenizer - - -from adalflow.core.types import ( - ModelType, - GeneratorOutput, - GeneratorOutputType, -) -from adalflow.core.component import Component, DataComponent -from adalflow.optim.grad_component import GradComponent -from adalflow.core.base_data_class import DataClass - - -from adalflow.optim.parameter import ( - Parameter, - OutputParameter, -) -from adalflow.optim.gradient import GradientContext, Gradient -from adalflow.optim.types import ParameterType - -from adalflow.core.prompt_builder import Prompt -from adalflow.core.functional import compose_model_kwargs -from adalflow.core.model_client import ModelClient -from adalflow.core.default_prompt_template import DEFAULT_ADALFLOW_SYSTEM_PROMPT -from adalflow.optim.function import BackwardContext -from adalflow.utils.cache import CachedEngine -from adalflow.tracing.callback_manager import CallbackManager -from adalflow.tracing import generator_span -from adalflow.utils.global_config import get_adalflow_default_root_path -from adalflow.core.string_parser import JsonParser - -from adalflow.optim.text_grad.backend_engine_prompt import ( - FEEDBACK_ENGINE_TEMPLATE, - LLM_CONVERSATION_TEMPLATE, - ALL_PRED_INFO, - OUTPUT_INSTRUCTION, - VARIABLE_AND_PEERS_INFO, - CONVERSATION_START_INSTRUCTION_CHAIN, - OBJECTIVE_INSTRUCTION_BASE, - OBJECTIVE_INSTRUCTION_CHAIN, -) - -__all__ = ["Generator", "BackwardEngine", "create_teacher_generator"] - - -log = logging.getLogger(__name__) - - -PromptArgType = Dict[str, Union[str, Parameter]] - - -@dataclass -class BackwardPassSetup(DataClass): - all_pred_at_once: bool = field( - default=False, metadata={"desc": "Backward all predecessors at once."} - ) - threshold_score_to_compute_grad_for_errors: float = field( - default=0.9, - metadata={"desc": "Threshold score to compute gradient for errors."}, - ) - compute_grad_for_errors_only: bool = field( - default=True, metadata={"desc": "Compute gradient for errors only."} - ) - - -# TODO: better debug mode -class Generator(GradComponent, CachedEngine, CallbackManager): - __doc__ = """An user-facing orchestration component for LLM prediction. - - It is also a GradComponent that can be used for backpropagation through the LLM model. - - By orchestrating the following three components along with their required arguments, - it enables any LLM prediction with required task output format. - - Prompt - - Model client - - Output processors - - Args: - model_client (ModelClient): The model client to use for the generator. - model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}. Please refer to :ref:`ModelClient` for the details on how to set the model_kwargs for your specific model if it is from our library. - model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM. When using reasoning models which calls different api, you should set it to ModelType.LLM_REASONING. - template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT`. - prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None. - output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None. - trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to []. - - Note: - 1. The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output. - And we encourage you to only use it to parse the response to data format you will use later. - 2. For structured output, you should avoid using `stream` as the output_processors can only be run after all the data is available. - """ - - model_type: ModelType = ModelType.LLM - model_client: ModelClient # for better type checking - - _use_cache: bool = False - _kwargs: Dict[str, Any] = ( - {} - ) # to create teacher generator from student TODO: might reaccess this - - backward_pass_setup: BackwardPassSetup = ( - BackwardPassSetup() - ) # default setup for the backward pass - - def __init__( - self, - *, - # args for the model - model_client: ModelClient, # will be intialized in the main script - model_kwargs: PromptArgType = {}, - model_type: Optional[ModelType] = ModelType.LLM, - # args for the prompt - template: Optional[str] = None, - prompt_kwargs: Optional[Dict] = {}, - # args for the output processing - output_processors: Optional[DataComponent] = None, - name: Optional[str] = None, - # args for the cache - cache_path: Optional[str] = None, - use_cache: bool = True, - ) -> None: - r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables: - - task_desc_str - - tools_str - - example_str - - chat_history_str - - context_str - - steps_str - You can preset the prompt kwargs to fill in the variables in the prompt using prompt_kwargs. - But you can replace the prompt and set any variables you want and use the prompt_kwargs to fill in the variables. - """ - - if not isinstance(model_client, ModelClient): - raise TypeError( - f"{type(self).__name__} requires a ModelClient instance for model_client, please pass it as OpenAIClient() or GroqAPIClient() for example.\ - Got {model_client} instead." - ) - - template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT - - # create the cache path and initialize the cache engine - - self.set_cache_path( - cache_path, model_client, model_kwargs.get("model", "default") - ) - - CachedEngine.__init__(self, cache_path=self.cache_path) - - Component.__init__(self) - GradComponent.__init__(self, desc="Generate a response using LLM model.") - CallbackManager.__init__(self) - - self.name = name or self.__class__.__name__ - self.template = template - self.prompt_kwargs = prompt_kwargs.copy() - - self.model_kwargs = model_kwargs.copy() - # init the model client - self.model_client = model_client - self.model_type = model_type - - self.output_processors = output_processors - - if output_processors and (not isinstance(output_processors, DataComponent)): - raise ValueError( - f"output_processors should be a DataComponent instance, got {type(output_processors)}" - ) - - self.set_parameters(prompt_kwargs) - - # end of trainable parameters - self.backward_engine: "BackwardEngine" = None - log.info(f"Generator {self.name} initialized.") - # to support better testing on the parts beside of the model call - self.mock_output: bool = False - self.mock_output_data: str = "mock data" - - self._use_cache = use_cache - - self._kwargs = { - "model_client": model_client, - "model_kwargs": model_kwargs, - "template": template, - "prompt_kwargs": prompt_kwargs, - "output_processors": output_processors, - "name": name, - "cache_path": cache_path, - "use_cache": use_cache, - } - self._teacher: Optional["Generator"] = None - self._trace_api_kwargs: Dict[str, Any] = ( - {} - ) # used by dynamic computation graph and backpropagation - - self._tokenizer: Tokenizer = Tokenizer() - self._estimated_token_count: int = 0 - - @property - def use_cache(self): - return self._use_cache - - - @property - def estimated_token_count(self) -> int: - """Property to access the estimated token count from the last prompt. - - Returns: - int: The estimated token count from the last processed prompt. - Returns 0 if no prompt has been processed yet. - """ - return self._estimated_token_count - - def update_default_backward_pass_setup(self, setup: BackwardPassSetup): - self.backward_pass_setup = setup - - def set_cache_path(self, cache_path: str, model_client: object, model: str): - """Set the cache path for the generator.""" - - # Construct a valid model string using the client class name and model - self.model_str = f"{model_client.__class__.__name__}_{model}" - - # Remove any characters that are not allowed in file names (cross-platform) - # On Windows, characters like `:<>?/\|*` are prohibited. - self.model_str = re.sub(r"[^a-zA-Z0-9_\-]", "_", self.model_str) - - _cache_path = ( - get_adalflow_default_root_path() if cache_path is None else cache_path - ) - - # Use pathlib to handle paths more safely across OS - self.cache_path = Path(_cache_path) / f"cache_{self.model_str}.db" - - log.debug(f"Cache path set to: {self.cache_path}") - - def get_cache_path(self) -> str: - r"""Get the cache path for the generator.""" - return self.cache_path - - @staticmethod - def _get_default_mapping( - output: "GeneratorOutput" = None, - ) -> Tuple[Dict[str, Callable], List[str]]: - - if ( - output.data - and isinstance(output.data, DataClass) - and len(output.data.get_output_fields()) > 0 - ): - output_fields = output.data.get_output_fields() - - output_mapping = { - f: lambda x, f=f: getattr(x.data, f) for f in output_fields - } - elif output.raw_response: - output_fields = ["raw_response"] - output_mapping = {f: lambda x, f=f: getattr(x, f) for f in output_fields} - output_fields = ["Answer"] - output_mapping["Example"] = output_mapping["raw_response"] - del output_mapping["raw_response"] - - return output_mapping, output_fields - - def set_mock_output( - self, mock_output: bool = True, mock_output_data: str = "mock data" - ): - self.mock_output = mock_output - self.mock_output_data = mock_output_data - - def reset_mock_output(self): - self.mock_output = False - self.mock_output_data = "mock data" - - def set_parameters(self, prompt_kwargs: PromptArgType): - r"""Set name for each paramter and set all context for each other. - Make all parameters attributes to the generator for finding them easily - for optimizers and other components. - """ - for key, p in prompt_kwargs.items(): - if isinstance(p, Parameter): - if not p.name or p.name == "": - p.name = key - peers = [ - p - for k, p in prompt_kwargs.items() - if isinstance(p, Parameter) and k != key - # and p.param_type == ParameterType.PROMPT - ] - p.set_peers(peers) - setattr(self, key, p) - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Generator": - r"""Create a Generator instance from the config dictionary. - - Example: - - .. code-block:: python - - config = { - "model_client": { - "component_name": "OpenAIClient", - "component_config": {} - }, - "model_kwargs": {"model": "gpt-3.5-turbo", "temperature": 0} - } - generator = Generator.from_config(config) - """ - # create init_kwargs from the config - assert "model_client" in config, "model_client is required in the config" - return super().from_config(config) - - def _compose_model_kwargs(self, **model_kwargs) -> Dict: - r""" - The model configuration exclude the input itself. - Combine the default model, model_kwargs with the passed model_kwargs. - Example: - model_kwargs = {"temperature": 0.5, "model": "gpt-3.5-turbo"} - self.model_kwargs = {"model": "gpt-3.5-turbo"} - combine_kwargs(model_kwargs) => {"temperature": 0.5, "model": "gpt-3.5-turbo"} - - """ - combined_model_kwargs = self.model_kwargs.copy() - - if model_kwargs: - combined_model_kwargs.update(model_kwargs) - return combined_model_kwargs - - # TODO: use prompt_kwargs as users are already familiar with it - def print_prompt(self, **kwargs) -> str: - prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) - return prompt.print_prompt(**kwargs) - - def get_prompt(self, **kwargs) -> str: - prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) - return prompt.call(**kwargs) - - def _extra_repr(self) -> str: - prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) - s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}, prompt={prompt}" - return s - - def _post_call(self, completion: Any) -> GeneratorOutput: - r"""Get string completion and process it with the output_processors.""" - # parse chat completion will only fill the raw_response - output: GeneratorOutput = self.model_client.parse_chat_completion(completion) - # save the api response - output.api_response = completion - - # Now adding the data field to the output - data = output.raw_response - - # TODO implement support for synchronous iterator in the future - if self.output_processors: - if data: - try: - data = self.output_processors(data) - output.data = data - except Exception as e: - log.error(f"Error processing the output processors: {e}") - output.error = str(e) - else: - output.data = data - - return output - - async def _output_processing( - self, data: AsyncIterable, generator_output: GeneratorOutput - ) -> AsyncGenerator[Any, None]: - r"""Create an async generator from the async_iterable returned by model client and yield from them - apply output processors and store in generator output. - Consume the raw_response and yield each event. - Store the final output text in the generator output data field. - """ - - # the raw response of the generator output should be an async iterable - if not isinstance(data, AsyncIterable): - raise ValueError("The data is not an async iterable") - - final_output_text = "" - - # iterate over the events in the data and yield each event and store the final output text - # this assumes that the raw_response of the generator output stores a - # async iterable of events based on the OpenAI Responses API documentation for streaming - async for event in data: - log.debug(f"Raw event: {event!r}") - yield event - - if ResponseCompletedEvent and isinstance(event, ResponseCompletedEvent): - resp = event.response - log.debug(f"Response completed: {event.response.output_text}") - if getattr(resp, "output_text", None): - final_output_text = resp.output_text - - if self.output_processors: - if final_output_text: - try: - final_output = self.output_processors(final_output_text) - generator_output.data = final_output - except Exception as e: - log.error(f"Error processing the output processors: {e}") - generator_output.error = str(e) - else: - generator_output.data = None - - async def _async_post_call(self, completion: Any) -> GeneratorOutput: - r"""Get completion and depending on whether the client is streaming from the model client, create a GeneratorOutput or an AsyncGenerator. - when the client is not streaming, the post call return a GeneratorOutput where the final output after applying the output processors is stored under the data field. - When the client is streaming, the post call returns an Async Generator which will yield the events under the raw response and then also - the final Generator Output. The final Generator Output will have under data the final output after applying the output processors. - """ - # parse chat completion will only fill the raw_response - output: GeneratorOutput = self.model_client.parse_chat_completion(completion) - # save the api response - output.api_response = completion - - log.info( - f"Response from the Model Client Stream before being processed: {output.raw_response}" - ) - - # Now adding the data field to the output and setting as the raw response - data = output.raw_response - - # Handle async iterables from OpenAI Agent/Responses API streaming - # return an async generator - if isinstance(output.raw_response, AsyncIterable): - original_raw_response = ( - output.raw_response - ) # pass in the raw response to the output processing to avoid circular dependency - output.raw_response = self._output_processing(original_raw_response, output) - else: - # return a GeneratorOutput if the raw response is not an async iterable - # process the model client's final response with the output processors - log.info(f"Response from the Model Client before being processed: {data}") - if self.output_processors: - if data: - try: - data = self.output_processors(data) - output.data = data - except Exception as e: - log.error(f"Error processing the output processors: {e}") - output.error = str(e) - else: - output.data = data - log.info(f"Response from the Model Client after being processed: {data}") - - return output - - def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]: - r"""Prepare the input, prompt_kwargs, model_kwargs for the model call.""" - # 1. render the prompt from the template - prompt_str = self.get_prompt(**prompt_kwargs) - # prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) - - # prompt_str = prompt.call(**prompt_kwargs).strip() - - # 2. combine the model_kwargs with the default model_kwargs - composed_model_kwargs = self._compose_model_kwargs(**model_kwargs) - - max_tokens = composed_model_kwargs.get("max_tokens", None) - prompt_tokens = self._tokenizer.count_tokens(prompt_str) - self._estimated_token_count = prompt_tokens - use_prompt_str = prompt_str - if max_tokens is not None: - if prompt_tokens > max_tokens: - use_prompt_str = prompt_str[:max_tokens] - log.warning( - f"Prompt is too long: {prompt_tokens} tokens, max tokens: {max_tokens}. Truncated prompt to: {use_prompt_str}" - ) - # delete max_tokens from the model_kwargs - del composed_model_kwargs["max_tokens"] - - # 3. convert app's inputs to api inputs - api_kwargs = self.model_client.convert_inputs_to_api_kwargs( - # rename from input since input is a builtin object - input=use_prompt_str, - model_kwargs=composed_model_kwargs, - model_type=self.model_type, - ) - return api_kwargs - - def _model_client_call(self, api_kwargs: Dict, use_cache: bool = False) -> Any: - # call the model client - try: - # check the cache - index_content = json.dumps(api_kwargs) # + f"training: {self.training}" - if use_cache: - # print(f"check cache first: {no_cache}") - - cached_completion = self._check_cache(index_content) - if cached_completion is not None: - return cached_completion - - completion = self.model_client.call( - api_kwargs=api_kwargs, model_type=self.model_type - ) - - # prepare cache - skip caching for streaming responses which contain unpickleable threading objects - if use_cache and not api_kwargs.get("stream", False): - self._save_cache(index_content, completion) - return completion - except Exception as e: - log.error(f"Error calling the model: {e}") - raise e - - async def _async_model_client_call( - self, api_kwargs: Dict, use_cache: bool = False - ) -> Any: - # async call the model client with caching support - try: - # check the cache - index_content = json.dumps(api_kwargs) # + f"training: {self.training}" - if use_cache: - # Check cache first - cache operations are sync - cached_completion = self._check_cache(index_content) - if cached_completion is not None: - log.debug("Cache hit for async call") - return cached_completion - - completion = await self.model_client.acall( - api_kwargs=api_kwargs, model_type=self.model_type - ) - # save to cache - skip caching for streaming responses which contain unpickleable threading objects - if use_cache and not api_kwargs.get("stream", False): - self._save_cache(index_content, completion) - return completion - except Exception as e: - log.error(f"Error calling the model: {e}") - raise e - - ############################################################################################################## - ### Forward, backwards, teacher generator, create demo data instance, - # are for training and backpropagation - ############################################################################################################## - - def create_demo_data_instance( - self, - input_prompt_kwargs: Dict[str, Any], - output: GeneratorOutput, - id: Optional[str] = None, - ): - r"""Automatically create a demo data instance from the input and output of the generator. - Used to trace the demos for the demo paramter in the prompt_kwargs. - Part of the few-shot learning. - """ - from adalflow.core.base_data_class import DynamicDataClassFactory - - # map the input fields - demo_data = {"id": id, "score": None} # add score to trace the prediction score - demo_data_class_output_mapping, output_fields = self._get_default_mapping( - output - ) - - for k, v in input_prompt_kwargs.items(): - if isinstance(v, Parameter): - demo_data[k] = v.map_to_successor(self) - else: - demo_data[k] = v - # map the output fields - for key, value in demo_data_class_output_mapping.items(): - demo_data[key] = value(output) - - obj = DynamicDataClassFactory.from_dict(demo_data) - obj.set_input_fields([k for k in input_prompt_kwargs.keys()]) - obj.set_output_fields(output_fields) - if obj is None: - raise ValueError(f"Error creating the demo data instance:{demo_data}") - return obj - - def set_backward_engine(self, backward_engine: "BackwardEngine" = None): - if backward_engine is None: - backward_engine = BackwardEngine( - model_client=self.model_client, - model_kwargs=self.model_kwargs, - ) - if self.mock_output: - backward_engine.set_mock_output() - self.backward_engine = backward_engine - - def set_teacher_generator(self, teacher: "Generator" = None): - self._teacher = teacher - print(f"Teacher generator set: {self._teacher}, teacher {teacher}") - log.debug(f"Teacher generator set: {self._teacher}") - - # def set_data_map_func(self, map_func: Callable = None): - # def default_map_func(data: "GeneratorOutputType") -> str: - # return ( - # data.data - # if data.data - # else self.failure_message_to_backward_engine(data) - # ) - - # self.data_map_func = map_func or default_map_func - - # log.debug(f"Data map function set: {self.data_map_func}") - - # TODO: limit to only one demo parameter. - @staticmethod - def find_demo_parameter(prompt_kwargs: Dict) -> Optional[Parameter]: - from adalflow.optim.parameter import Parameter, ParameterType - - for p in prompt_kwargs.values(): - if isinstance(p, Parameter) and p.param_type == ParameterType.DEMOS: - return p - return None - - def forward( - self, - prompt_kwargs: Optional[ - Dict[str, Union[str, Parameter]] - ] = {}, # the input need to be passed to the prompt - model_kwargs: Optional[Dict] = {}, - id: Optional[str] = None, - ) -> "Parameter": - r"""Customized forward pass on top of the GradComponent forward method.""" - # 1. convert prompt_kwargs to parameter if it is not - for k, v in prompt_kwargs.items(): - if not isinstance(v, Parameter): - prompt_kwargs[k] = Parameter( - data=v, - name=f"{self.name}_{k}", - requires_opt=False, - param_type=ParameterType.INPUT, - data_id=id, - ) - - # 2. call the model - unwrapped_prompt_kwargs: Dict[str, Any] = {} - for k, v in prompt_kwargs.items(): - if isinstance(v, Parameter): - if v.param_type == ParameterType.INPUT: - v.data_id = id - unwrapped_prompt_kwargs[k] = v.map_to_successor(self) - else: - unwrapped_prompt_kwargs[k] = v - log.debug( - f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}" - ) - log.debug(f"prompt template: {self.template}") - - output: GeneratorOutputType = None - input_args = {} - if self.mock_output: - output = GeneratorOutput(data=self.mock_output_data) - else: - if self.teacher_mode and not isinstance(self, BackwardEngine): - if not self._teacher: - log.debug( - f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}" - ) - log.debug(f"names: {self.name}") - raise ValueError("Teacher generator is not set.") - log.info(f"Using teacher: {self._teacher}") - input_args = { - "prompt_kwargs": compose_model_kwargs( - self._teacher.prompt_kwargs, unwrapped_prompt_kwargs - ), - "model_kwargs": compose_model_kwargs( - self._teacher.model_kwargs, model_kwargs - ), - } - output = self._teacher.call(**input_args, id=id) - else: - input_args = { - "prompt_kwargs": compose_model_kwargs( - self.prompt_kwargs, unwrapped_prompt_kwargs - ), - "model_kwargs": compose_model_kwargs( - self.model_kwargs, model_kwargs - ), - } - - output = self.call(**input_args, id=id) - if not isinstance(output, GeneratorOutput): - raise ValueError( - f"Output should be of type GeneratorOutput, got {type(output)}" - ) - # 2. Generate a Parameter object from the output - combined_prompt_kwargs = compose_model_kwargs(self.prompt_kwargs, prompt_kwargs) - # if self.data_map_func is None: - # self.set_data_map_func() - - predecessors = [ - p for p in combined_prompt_kwargs.values() if isinstance(p, Parameter) - ] - - log.debug(f"Predecessors: {predecessors} for generator {self.name}") - - def data_to_prompt_map_fn(data: Parameter) -> str: - """GeneratorOutput will show the raw response instead of just the final data. - The backward engine and optimizer should look at all reasoning to decide the gradient. - """ - data: GeneratorOutput = data.data - if data.error is not None: - return f"Response: {data.raw_response} parsed with error: {data.error}" - return f" {data.raw_response}" - - # TODO: all parameter should just wrap the whole output. - # this is for training. - param_data = output - response: Parameter = OutputParameter( - data=param_data, - name=self.name + "_output", - role_desc=f"Output from (llm) {self.name}", - param_type=ParameterType.GENERATOR_OUTPUT, - data_id=id, - full_response=output, # the data structure - data_in_prompt=data_to_prompt_map_fn, - ) - response.set_predecessors(predecessors) - response.trace_forward_pass( - input_args=input_args, full_response=output, id=self.id, name=self.name - ) - # setattr(response, "full_response", output) - # *** special to the generator *** - response.trace_api_kwargs(api_kwargs=self._trace_api_kwargs) - # attach the demo to the demo parameter - # if self.tracing: - demo_param = self.find_demo_parameter(combined_prompt_kwargs) - - if demo_param: - if id is None: - raise ValueError( - "ID is required for tracing. Please pass it to your Geneartor call." - ) - - demo = self.create_demo_data_instance( - prompt_kwargs, - output, - id=id, - ) - demo_param.add_dataclass_to_trace(demo, is_teacher=self.teacher_mode) - else: - log.debug( - "No demo parameter found in the prompt_kwargs. You can not trace the demo data." - ) - - # **** end of the special to the generator **** - - # if not self.backward_engine: - # # self.set_backward_engine() - # log.debug(f"Backward engine: {self.backward_engine}") - - # attach a funtion to compute gradient for predecessors - log.debug(f"disable_backward_engine: {self._disable_backward_engine}") - - response.set_grad_fn( - BackwardContext( - backward_fn=self.backward, - backward_engine=self.backward_engine, - response=response, - prompt_kwargs=prompt_kwargs, - template=self.template, - prompt_str=self.get_prompt(**combined_prompt_kwargs), - disable_backward_engine=self._disable_backward_engine, - id=id, - ) - ) - return response - - def backward( - self, - response: Parameter, # the output of the forward pass - prompt_kwargs: Dict, - template: str, - prompt_str: str, - backward_engine: Optional["Generator"] = None, - id: Optional[str] = None, # the id of the input - disable_backward_engine: bool = False, - ) -> Parameter: - - log.info(f"Generator: Backward: {response.name}") - - backward_pass_setup = ( - backward_engine.backward_pass_setup if backward_engine else None - ) - log.debug( - f"backward pass setup: {backward_pass_setup}, name: {self.name}", - color="red", - ) - - children_params = response.predecessors - is_intermediate_node = True - if response.get_gradient_and_context_text().strip() == "": - log.info(f"Generator: Backward: No gradient found for {response}.") - - # backward score to the demo parameter - for pred in children_params: - # if pred.requires_opt: - if response.score is not None: - pred.set_score(response.score) - log.debug( - f"backpropagate the score {response.score} to {pred.name}, is_teacher: {self.teacher_mode}" - ) - if pred.param_type == ParameterType.DEMOS: - # Accumulate the score to the demo - pred.add_score_to_trace( - trace_id=id, score=response.score, is_teacher=self.teacher_mode - ) - log.debug(f"Pred: {pred.name}, traces: {pred._traces}") - - # 1.backward for text-gradients - if backward_engine: - log.debug( - f"Generator: Backward engine is set for the generator. {backward_engine}" - ) - # if response.backward_engine_disabled: - # for pred in children_params: - # pred.backward_engine_disabled = True - # return - - all_pred_at_once = backward_pass_setup.all_pred_at_once - - if not all_pred_at_once: - for pred in children_params: - if not pred.requires_opt or pred.param_type == ParameterType.DEMOS: - log.debug( - f"EvalFnToTextLoss: Skipping {pred} as it does not require optimization." - ) - continue - - self._backward_through_one_predecessor( - pred=pred, - response=response, - prompt_kwargs=prompt_kwargs, - # template=template, - backward_engine=backward_engine, - prompt_str=prompt_str, - backward_pass_setup=backward_pass_setup, - is_intermediate_node=is_intermediate_node, - disable_backward_engine=disable_backward_engine, - ) - else: - backward = False - for pred in children_params: - if pred.requires_opt and pred.param_type in [ - ParameterType.PROMPT, - ParameterType.GENERATOR_OUTPUT, - ParameterType.RETRIEVER_OUTPUT, - ParameterType.OUTPUT, - ]: - backward = True - break - if backward: - # 2nd approach, backward all that need opt at once. - self._backward_through_all_predecessors( - children_params=children_params, - response=response, - prompt_kwargs=prompt_kwargs, - template=template, - backward_engine=backward_engine, - prompt_str=prompt_str, - backward_pass_setup=backward_pass_setup, - is_intermediate_node=is_intermediate_node, - ) - else: - log.debug("Backward engine is not set for the generator. No text gradient.") - - @staticmethod - def _backward_through_all_predecessors( - children_params: List[Parameter], - response: Parameter, - prompt_kwargs: Dict[str, str], - backward_engine: "BackwardEngine", - backward_pass_setup: BackwardPassSetup, - is_intermediate_node: bool = False, - ): - parser = JsonParser() - # instruction and objective is the same for all the children - instruction_str, objective_str = None, None - - # 1. Generate the conversation input and output - input_prompt_kwargs = { - k: v.get_prompt_data() if isinstance(v, Parameter) else v - for k, v in prompt_kwargs.items() - } - - print(f"gt: {response.get_gt()}") - - # TODO: pass all the parameters and even the templates - conversation_prompt_kwargs = { - "input_value": input_prompt_kwargs, - "llm_output": response.get_prompt_data(), - } - - conversation_str = Prompt( - prompt_kwargs=conversation_prompt_kwargs, - template=LLM_CONVERSATION_TEMPLATE, - )() - - all_pred_info = Prompt( - prompt_kwargs={"variables": [p.get_param_info() for p in children_params]}, - template=ALL_PRED_INFO, - )() - - conv_ins_template = None # CONVERSATION_START_INSTRUCTION_BASE - obj_ins_template = OBJECTIVE_INSTRUCTION_BASE - if is_intermediate_node: # TODO: this will always be true - conv_ins_template = CONVERSATION_START_INSTRUCTION_CHAIN - obj_ins_template = OBJECTIVE_INSTRUCTION_CHAIN - response_gradient = response.get_gradients_str() - # response_gradient = response.get_gradients_component_schema() - # response_gradient = response.get_gradients_component_schema( - # skip_correct_sample=False - # ) - if not response_gradient: - raise ValueError( - f"Generator: No gradient found for {response}. Please check the response." - ) - - # replace variable and peers with all_pred_info - - instruction_str = Prompt( - template=conv_ins_template, - prompt_kwargs={ - "variable_and_peers_info": all_pred_info, - "conversation_str": conversation_str, - }, - )() - objective_str = Prompt( - template=obj_ins_template, - prompt_kwargs={ - "response_desc": response.role_desc, - "response_gradient": response_gradient, - "instruction_to_backward_engine": response.instruction_to_backward_engine, - }, - )() - - backward_engine_prompt_kwargs = { - "conversation_sec": instruction_str, - "objective_instruction_sec": objective_str, - "output_format_str": OUTPUT_INSTRUCTION, - } - - backward_engine_prompt_str = backward_engine.get_prompt( - **backward_engine_prompt_kwargs - ) - # print(f"Backward engine prompt: {backward_engine_prompt_str}") - - gradient_output: GeneratorOutput = None - response_gradient_list = [""] * len(children_params) - if ( - backward_pass_setup.compute_grad_for_errors_only - and response.score is not None - and float(response.score) - > backward_pass_setup.threshold_score_to_compute_grad_for_errors - ): - manual_response_1 = f"Eval score: {response.score}. No noticeable error." - response_gradient_list = [manual_response_1] * len(children_params) - raw_response = str(response_gradient_list) - gradient_output = GeneratorOutput( - data=response_gradient_list, raw_response=raw_response - ) - else: - - gradient_output: GeneratorOutput = backward_engine( - prompt_kwargs=backward_engine_prompt_kwargs - ) - if not isinstance(gradient_output, GeneratorOutput): - raise ValueError( - f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead." - ) - - # parse the list of gradients - - try: - response_gradient_list = parser.call(gradient_output.data) - except Exception as e: - log.error(f"Error parsing the response_gradient_list: {e}") - failure_message = backward_engine.failure_message_to_optimizer( - gradient_output - ) - if failure_message: - response_gradient_list = [failure_message] * len(children_params) - - log.debug(f"failure_message: {failure_message}") - - # computes gradient for each prompt predecessor - for i, pred in enumerate(children_params): - if not pred.requires_opt or pred.param_type == ParameterType.DEMOS: - log.debug( - f"Generator: Skipping {pred} as it does not require optimization." - ) - continue - - gradient_data = ( - response_gradient_list[i] - if response_gradient_list and len(response_gradient_list) > i - else "Failed to get the gradient." - ) - - var_gradient = Gradient( - data=gradient_data, - data_id=response.data_id, - score=response.score, # add score to gradient - from_response=response, - to_pred=pred, - ) - var_gradient.add_context( - GradientContext( - input_output=conversation_str, - response_desc=response.role_desc, - variable_desc=pred.role_desc, # the only difference for each pred - ) - ) - var_gradient.add_prompt(backward_engine_prompt_str) - pred.add_gradient(var_gradient) - if response.score is not None: - pred.set_score(response.score) - - @staticmethod - def _backward_through_one_predecessor( - pred: Parameter, - response: Parameter, - prompt_kwargs: Dict[str, str], - backward_engine: "BackwardEngine", - prompt_str: str, - backward_pass_setup: BackwardPassSetup, - is_intermediate_node: bool = False, - disable_backward_engine: bool = False, - ): - """Creating gradient/textual feedback for prompt type parameters.""" - if not pred.requires_opt: - if response.score is not None: - pred.set_score(response.score) - log.debug( - f"Generator: Skipping {pred} as it does not require optimization." - ) - return - - if pred.check_if_already_computed_gradient_respect_to(response.id): - log.debug( - f"Generator: Skipping {pred} as the gradient is already computed." - ) - - return - - if backward_engine is None: - log.error( - "EvalFnToTextLoss: backward_engine is required for text prompt optimization." - ) - raise ValueError( - "EvalFnToTextLoss: backward_engine is required for text prompt optimization." - ) - - instruction_str, objective_str = None, None - - # 1. Generate the conversation string - input_prompt_kwargs = { - k: v.get_prompt_data() if isinstance(v, Parameter) else v - for k, v in prompt_kwargs.items() - } - - conversation_prompt_kwargs = { - "input_value": input_prompt_kwargs, - "llm_output": response.get_prompt_data(), - "gt": response.get_gt(), - } - - conversation_str = Prompt( - prompt_kwargs=conversation_prompt_kwargs, - template=LLM_CONVERSATION_TEMPLATE, - )() - - variable_dict = pred.get_param_info() - - peers = [p.get_param_info() for p in pred.peers] - - variable_and_peers_info = Prompt( - prompt_kwargs={"variable": variable_dict, "peers": peers}, - template=VARIABLE_AND_PEERS_INFO, - )() - - # generator is almost always intermediate node - conv_ins_template = None # CONVERSATION_START_INSTRUCTION_BASE - obj_ins_template = OBJECTIVE_INSTRUCTION_BASE - if is_intermediate_node: # TODO: this will always be true - conv_ins_template = CONVERSATION_START_INSTRUCTION_CHAIN - obj_ins_template = OBJECTIVE_INSTRUCTION_CHAIN - response_gradient = response.get_gradients_str() - # response_gradient = response.get_gradients_component_schema() - if not response_gradient: - raise ValueError( - f"Generator: No gradient found for {response}. Please check the response. pred: {pred}" - ) - predecessors = [ - pred.get_param_info() - for pred in response.predecessors - if pred not in pred.peers - ] - instruction_str = Prompt( - template=conv_ins_template, - prompt_kwargs={ - "variable_and_peers_info": variable_and_peers_info, - "conversation_str": conversation_str, - "predecessors": predecessors, - }, - )() - log.info(f"Conversation start instruction base str: {instruction_str}") - objective_str = Prompt( - template=obj_ins_template, - prompt_kwargs={ - "response_desc": response.role_desc, - "response_gradient": response_gradient, - "instruction_to_backward_engine": pred.instruction_to_backward_engine, - }, - )() - - backward_engine_prompt_kwargs = { - "conversation_sec": instruction_str, - "objective_instruction_sec": objective_str, - } - backward_engine_prompt_str = backward_engine.get_prompt( - **backward_engine_prompt_kwargs - ) - # print(f"Backward engine prompt: {backward_engine_prompt_str}") - gradient_value = None - if not disable_backward_engine: - gradient_output: GeneratorOutput = None - if ( - backward_pass_setup.compute_grad_for_errors_only - and response.score is not None - and float(response.score) - > backward_pass_setup.threshold_score_to_compute_grad_for_errors - ): - log.debug( - f"EvalFnToTextLoss: Skipping {pred} as the score is high enough." - ) - # TODO: plus score descriptions - manual_response = f"Eval score: {response.score}. No noticeable error." - gradient_output = GeneratorOutput( - data=manual_response, raw_response=manual_response - ) - else: - - gradient_output: GeneratorOutput = backward_engine( - prompt_kwargs=backward_engine_prompt_kwargs - ) - prompt_str = backward_engine.get_prompt( # noqa F841 - **backward_engine_prompt_kwargs - ) - # printc(f"Backward engine prompt: {prompt_str}") - if not isinstance(gradient_output, GeneratorOutput): - raise ValueError( - f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead." - ) - # printc(f"Backward engine gradient: {gradient_output}") - - # USE this to trace each node's input and output, all nodes can be visualized - log.info( - f"Generator Backward Engine Prompt: {backward_engine.get_prompt( **backward_engine_prompt_kwargs)}" - ) - gradient_value = ( - gradient_output.data - or backward_engine.failure_message_to_optimizer(gradient_output) - ) - var_gradient = Gradient( - data=gradient_value, - data_id=response.data_id, - score=response.score, # add score to gradient - from_response=response, - to_pred=pred, - ) - # Component-level input and output. - var_gradient.add_context( - GradientContext( - input_output=conversation_str, - response_desc=response.role_desc, - variable_desc=pred.role_desc, # parameter_desc - ) - ) - var_gradient.add_prompt(backward_engine_prompt_str) - pred.add_gradient(var_gradient) - if response.score is not None: - pred.set_score(response.score) - - def _run_callbacks( - self, - output: GeneratorOutput, - input: Dict, - prompt_kwargs: Dict, - model_kwargs: Dict, - ): - self.trigger_callbacks( - "on_complete", - output=output, - input=input, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - if output.error: - self.trigger_callbacks( - "on_failure", - output=output, - input=input, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - else: - self.trigger_callbacks( - "on_success", - output=output, - input=input, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - - def call( - self, - prompt_kwargs: Optional[Dict] = {}, # supports both str and parameter value - model_kwargs: Optional[Dict] = {}, # can take images = {input: }, tools=[{"type": "image_generation"}], for image generation - use_cache: Optional[bool] = None, - id: Optional[str] = None, - ) -> GeneratorOutputType: - r""" - Call the model_client by formatting prompt from the prompt_kwargs, - and passing the combined model_kwargs to the model client. - """ - prompt_str = self.get_prompt(**prompt_kwargs) - with generator_span( - generator_id="generator" + (id if id else ""), - model_kwargs=self._compose_model_kwargs(**model_kwargs), - prompt_kwargs=prompt_kwargs, - prompt_template_with_keywords=prompt_str, - ) as generator_span_data: - generation_time = time.time() - if self.mock_output: - return GeneratorOutput(data=self.mock_output_data, id=id, input=prompt_str) - - log.debug(f"prompt_kwargs: {prompt_kwargs}") - log.debug(f"model_kwargs: {model_kwargs}") - - api_kwargs = self._pre_call(prompt_kwargs, model_kwargs) - - log.debug(f"api_kwargs: {api_kwargs}") - output: GeneratorOutputType = None - # call the model client - - completion = None - use_cache = use_cache if use_cache is not None else self._use_cache - try: - completion = self._model_client_call( - api_kwargs=api_kwargs, use_cache=use_cache - ) - except Exception as e: - log.error(f"Error calling the model: {e}") - output = GeneratorOutput(error=str(e), id=id, input=prompt_str) - # process the completion - if completion is not None: - try: - log.debug(f"Entering _post_call with completion: {completion}") - output = self._post_call(completion) - except Exception as e: - log.error(f"Error processing the output: {e}") - output = GeneratorOutput( - raw_response=str(completion), error=str(e), id=id, input=prompt_str - ) - - # User only need to use one of them, no need to use them all. - output.id = id - output.input = prompt_str - self._run_callbacks( - output, - input=api_kwargs, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - - log.info(f"output: {output}") - self._trace_api_kwargs = api_kwargs # tracing - - generator_span_data.span_data.update_attributes( - {"raw_response": output.raw_response} - ) - generator_span_data.span_data.update_attributes( - {"api_response": output.api_response} - ) - generator_span_data.span_data.update_attributes( - {"final_response": output.data} - ) - generator_span_data.span_data.update_attributes( - {"generation_time_in_seconds": time.time() - generation_time} - ) - generator_span_data.span_data.update_attributes( - {"token_usage": output.usage} - ) - generator_span_data.span_data.update_attributes({"api_kwargs": api_kwargs}) - - return output - - # TODO: training is not supported in async call yet - async def acall( - self, - prompt_kwargs: Optional[Dict] = {}, - model_kwargs: Optional[Dict] = {}, - use_cache: Optional[bool] = None, - id: Optional[str] = None, - ) -> Union[GeneratorOutputType, AsyncGenerator[GeneratorOutputType, None]]: - r"""Async call the model with the input and model_kwargs. - - :warning:: - Training is not supported in async call yet. - """ - prompt_str = self.get_prompt(**prompt_kwargs) - - with generator_span( - generator_id="generator" + (id if id else ""), - model_kwargs=self._compose_model_kwargs(**model_kwargs), - prompt_kwargs=prompt_kwargs, - prompt_template_with_keywords=prompt_str, - ) as generator_span_data: - - if self.mock_output: - generator_span_data.span_data.update_attributes( - {"final_response": self.mock_output_data} - ) - return GeneratorOutput(data=self.mock_output_data, id=id, input=prompt_str) - - generation_time = time.time() - log.info(f"prompt_kwargs: {prompt_kwargs}") - log.info(f"model_kwargs: {model_kwargs}") - - api_kwargs = self._pre_call(prompt_kwargs, model_kwargs) - output: GeneratorOutputType = None - # call the model client - completion = None - use_cache = use_cache if use_cache is not None else self._use_cache - - log.info(f"api_kwargs: {api_kwargs}") - - try: - completion = await self._async_model_client_call( - api_kwargs=api_kwargs, use_cache=use_cache - ) - except Exception as e: - log.error(f"Error calling the model: {e}") - output = GeneratorOutput(error=str(e), id=id, input=prompt_str) - - if completion is not None: - try: - # set ouput id in async post call instead - output = await self._async_post_call(completion) - except Exception as e: - log.error(f"Error processing the output: {e}") - output = GeneratorOutput( - raw_response=str(completion), error=str(e), id=id, input=prompt_str - ) - - # User only need to use one of them, no need to use them all. - output.id = id - output.input = prompt_str - log.info(f"output: {output}") - # if the output is not an async generator then set its id and run call backs - # TODO support when output is an Async Generator - if not isinstance(output, AsyncGenerator): - self._run_callbacks( - output, - input=api_kwargs, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - self._trace_api_kwargs = api_kwargs # tracing - - # Update generator span attributes similar to call method - generator_span_data.span_data.update_attributes( - {"raw_response": output.raw_response if output else None} - ) - generator_span_data.span_data.update_attributes( - {"api_response": output.api_response if output else None} - ) - generator_span_data.span_data.update_attributes( - {"final_response": output.data if output else None} - ) - generator_span_data.span_data.update_attributes( - {"generation_time_in_seconds": time.time() - generation_time} - ) - generator_span_data.span_data.update_attributes( - {"token_usage": output.usage if output else None} - ) - generator_span_data.span_data.update_attributes({"api_kwargs": api_kwargs}) - - return output - - def __call__(self, *args, **kwargs) -> Union[GeneratorOutputType, Any]: - if self.training: - log.debug("Training mode") - return self.forward(*args, **kwargs) - else: - log.debug("Inference mode") - return self.call(*args, **kwargs) - - def _extra_repr(self) -> str: - # Create the string for model_kwargs - s = f"model_kwargs={self.model_kwargs}, " - - # Create the string for trainable prompt_kwargs - prompt_kwargs_repr = [ - k - for k, v in self.prompt_kwargs.items() - if isinstance(v, Parameter) and v.requires_opt - ] - - s += f"trainable_prompt_kwargs={prompt_kwargs_repr}" - return s - - def to_dict(self) -> Dict[str, Any]: - r"""Convert the generator to a dictionary.""" - # TODO: exclude default functions - return super().to_dict() - - @staticmethod - def failure_message_to_backward_engine( - gradient_response: GeneratorOutput, - ) -> Optional[str]: - response_value = None - if gradient_response.error or not gradient_response.data: - response_value = f"Error: {gradient_response.error}, Raw response: {gradient_response.raw_response}" - return response_value - - -class BackwardEngine(Generator): # it is a generator with defaule template - - __doc__ = """A Generator with a default template for the backward pass in auto-differentiation. - - As a component, the forward pass is simply the same as the call method. - So it will always return GeneratorOutputType instead of Parameter. - - If you want to customize the template, you can create your own backward engine. - Yet, we will forever keep the training mode to False for the backward engine. - This is achieved by making forward the same as call. - """ - - def __init__(self, **kwargs): - if kwargs is None: - kwargs = {} - kwargs["template"] = FEEDBACK_ENGINE_TEMPLATE - - super().__init__(**kwargs) - self.name = "BackwardEngine" - self.teacher_mode = False - - def call(self, **kwargs) -> GeneratorOutputType: - r"""Catch the rate limit error and raise it.""" - output = super().call(**kwargs) - if output and output.error is not None and "429" in output.error: - raise ValueError(f"Error in the backward engine: {output.error}") - return output - - def forward(self, **kwargs): - r"""Forward pass for the backward engine.""" - return self.call(**kwargs) - - @staticmethod - def failure_message_to_optimizer( - gradient_response: GeneratorOutput, - ) -> Optional[str]: - gradient_value_data = None - if gradient_response.error or not gradient_response.data: - gradient_value_data = f"The backward engine failed to compute the gradient. Raw response: {gradient_response.raw_response}, Error: {gradient_response.error}" - - return gradient_value_data - - -def create_teacher_generator( - student: Generator, - model_client: ModelClient, - model_kwargs: Dict[str, Any], - template: Optional[str] = None, -) -> Generator: - r"""Create a teacher generator from the student generator. - - Note: - Teacher generator will have no parameters. - If you want to keep it to be the same as the student, just create one each time your student has been updated. - Or else, task.parameters will list teacher parameters. - - Args: - student (Generator): The student generator. - model_client (ModelClient): The model client to use for the teacher generator. - model_kwargs (Dict[str, Any]): The model kwargs to pass to the model client. - name (str, optional): The name of the teacher generator. Defaults to "teacher". - - Returns: - Generator: The teacher generator. - """ - kwargs = student._kwargs.copy() - kwargs["model_client"] = model_client - kwargs["model_kwargs"] = model_kwargs - if template: - kwargs["template"] = template - kwargs["name"] = f"{student.name}_teacher" - - prompt_kwargs_str: Dict[str, str] = {} - for key, p in kwargs["prompt_kwargs"].items(): - if isinstance(p, Parameter): - prompt_kwargs_str[key] = str(p.data) - else: - prompt_kwargs_str[key] = p - kwargs["prompt_kwargs"] = prompt_kwargs_str - teacher = Generator( - **kwargs, - ) - return teacher - - -if __name__ == "__main__": - # test the generator with backward engine - # TODO: move this to external local tests before packaging - from adalflow.components.model_client import ( - GroqAPIClient, - OpenAIClient, - GoogleGenAIClient, - AnthropicAPIClient, - ) - from adalflow.utils import setup_env - from adalflow.core.model_client import ModelClient - - setup_env() - # log = get_logger(level="DEBUG") - llama3_model = { - "model_client": GroqAPIClient(), - "model_kwargs": { - "model": "llama-3.1-8b-instant", - }, - } - gpt_3_model = { - "model_client": OpenAIClient(), - "model_kwargs": { - "model": "gpt-3.5-turbo", - }, - } - gemini_model = { - "model_client": GoogleGenAIClient(), - "model_kwargs": { - "model": "gemini-1.0-pro", - }, - } - claude_model = { - "model_client": AnthropicAPIClient(), - "model_kwargs": { - "model": "claude-3-opus-20240229", - "max_tokens": 100, - }, - } - from adalflow.tracing.generator_call_logger import GeneratorCallLogger - from functools import partial - - # setup the logger - call_logger = GeneratorCallLogger(save_dir="traces") - - def on_complete(output, input, prompt_kwargs, model_kwargs, logger_call: Callable): - logger_call( - output=output, - input=input, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - ) - - for model in [llama3_model, gpt_3_model, gemini_model, claude_model]: - generator = Generator(**model) - - teacher = create_teacher_generator(generator, **claude_model) - - call_logger.register_generator("generator", "generator_call") - # setup the callback - logger_call = partial(call_logger.log_call, name="generator") - generator.register_callback( - "on_complete", partial(on_complete, logger_call=logger_call) - ) - - output = generator( - prompt_kwargs={ - "input_str": "Hello, world!", - } - ) - break - - # test the backward engine - # TODO: test ollama and transformer client to update the change +"""Generator is a user-facing orchestration component with a simple and unified interface for LLM prediction. + +It is a pipeline that consists of three subcomponents.""" + +import json +import re +import time +from pathlib import Path + +from typing import Any, Dict, Optional, Union, Callable, Tuple, List, AsyncGenerator +from collections.abc import AsyncIterable +import logging +from dataclasses import dataclass, field + +from openai.types.responses import ResponseCompletedEvent +from adalflow.core.tokenizer import Tokenizer + + +from adalflow.core.types import ( + ModelType, + GeneratorOutput, + GeneratorOutputType, +) +from adalflow.core.component import Component, DataComponent +from adalflow.optim.grad_component import GradComponent +from adalflow.core.base_data_class import DataClass + + +from adalflow.optim.parameter import ( + Parameter, + OutputParameter, +) +from adalflow.optim.gradient import GradientContext, Gradient +from adalflow.optim.types import ParameterType + +from adalflow.core.prompt_builder import Prompt +from adalflow.core.functional import compose_model_kwargs +from adalflow.core.model_client import ModelClient +from adalflow.core.default_prompt_template import DEFAULT_ADALFLOW_SYSTEM_PROMPT +from adalflow.optim.function import BackwardContext +from adalflow.utils.cache import CachedEngine +from adalflow.tracing.callback_manager import CallbackManager +from adalflow.tracing import generator_span +from adalflow.utils.global_config import get_adalflow_default_root_path +from adalflow.core.string_parser import JsonParser + +from adalflow.optim.text_grad.backend_engine_prompt import ( + FEEDBACK_ENGINE_TEMPLATE, + LLM_CONVERSATION_TEMPLATE, + ALL_PRED_INFO, + OUTPUT_INSTRUCTION, + VARIABLE_AND_PEERS_INFO, + CONVERSATION_START_INSTRUCTION_CHAIN, + OBJECTIVE_INSTRUCTION_BASE, + OBJECTIVE_INSTRUCTION_CHAIN, +) + +__all__ = ["Generator", "BackwardEngine", "create_teacher_generator"] + + +log = logging.getLogger(__name__) + + +PromptArgType = Dict[str, Union[str, Parameter]] + + +@dataclass +class BackwardPassSetup(DataClass): + all_pred_at_once: bool = field( + default=False, metadata={"desc": "Backward all predecessors at once."} + ) + threshold_score_to_compute_grad_for_errors: float = field( + default=0.9, + metadata={"desc": "Threshold score to compute gradient for errors."}, + ) + compute_grad_for_errors_only: bool = field( + default=True, metadata={"desc": "Compute gradient for errors only."} + ) + + +# TODO: better debug mode +class Generator(GradComponent, CachedEngine, CallbackManager): + __doc__ = """An user-facing orchestration component for LLM prediction. + + It is also a GradComponent that can be used for backpropagation through the LLM model. + + By orchestrating the following three components along with their required arguments, + it enables any LLM prediction with required task output format. + - Prompt + - Model client + - Output processors + + Args: + model_client (ModelClient): The model client to use for the generator. + model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}. Please refer to :ref:`ModelClient` for the details on how to set the model_kwargs for your specific model if it is from our library. + model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM. When using reasoning models which calls different api, you should set it to ModelType.LLM_REASONING. + template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT`. + prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None. + output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None. + trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to []. + + Note: + 1. The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output. + And we encourage you to only use it to parse the response to data format you will use later. + 2. For structured output, you should avoid using `stream` as the output_processors can only be run after all the data is available. + """ + + model_type: ModelType = ModelType.LLM + model_client: ModelClient # for better type checking + + _use_cache: bool = False + _kwargs: Dict[str, Any] = ( + {} + ) # to create teacher generator from student TODO: might reaccess this + + backward_pass_setup: BackwardPassSetup = ( + BackwardPassSetup() + ) # default setup for the backward pass + + def __init__( + self, + *, + # args for the model + model_client: ModelClient, # will be intialized in the main script + model_kwargs: PromptArgType = {}, + model_type: Optional[ModelType] = ModelType.LLM, + # args for the prompt + template: Optional[str] = None, + prompt_kwargs: Optional[Dict] = {}, + # args for the output processing + output_processors: Optional[DataComponent] = None, + name: Optional[str] = None, + # args for the cache + cache_path: Optional[str] = None, + use_cache: bool = True, + ) -> None: + r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables: + - task_desc_str + - tools_str + - example_str + - chat_history_str + - context_str + - steps_str + You can preset the prompt kwargs to fill in the variables in the prompt using prompt_kwargs. + But you can replace the prompt and set any variables you want and use the prompt_kwargs to fill in the variables. + """ + + if not isinstance(model_client, ModelClient): + raise TypeError( + f"{type(self).__name__} requires a ModelClient instance for model_client, please pass it as OpenAIClient() or GroqAPIClient() for example.\ + Got {model_client} instead." + ) + + template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT + + # create the cache path and initialize the cache engine + + self.set_cache_path( + cache_path, model_client, model_kwargs.get("model", "default") + ) + + CachedEngine.__init__(self, cache_path=self.cache_path) + + Component.__init__(self) + GradComponent.__init__(self, desc="Generate a response using LLM model.") + CallbackManager.__init__(self) + + self.name = name or self.__class__.__name__ + self.template = template + self.prompt_kwargs = prompt_kwargs.copy() + + self.model_kwargs = model_kwargs.copy() + # init the model client + self.model_client = model_client + self.model_type = model_type + + self.output_processors = output_processors + + if output_processors and (not isinstance(output_processors, DataComponent)): + raise ValueError( + f"output_processors should be a DataComponent instance, got {type(output_processors)}" + ) + + self.set_parameters(prompt_kwargs) + + # end of trainable parameters + self.backward_engine: "BackwardEngine" = None + log.info(f"Generator {self.name} initialized.") + # to support better testing on the parts beside of the model call + self.mock_output: bool = False + self.mock_output_data: str = "mock data" + + self._use_cache = use_cache + + self._kwargs = { + "model_client": model_client, + "model_kwargs": model_kwargs, + "template": template, + "prompt_kwargs": prompt_kwargs, + "output_processors": output_processors, + "name": name, + "cache_path": cache_path, + "use_cache": use_cache, + } + self._teacher: Optional["Generator"] = None + self._trace_api_kwargs: Dict[str, Any] = ( + {} + ) # used by dynamic computation graph and backpropagation + + self._tokenizer: Tokenizer = Tokenizer() + self._estimated_token_count: int = 0 + + @property + def use_cache(self): + return self._use_cache + + + @property + def estimated_token_count(self) -> int: + """Property to access the estimated token count from the last prompt. + + Returns: + int: The estimated token count from the last processed prompt. + Returns 0 if no prompt has been processed yet. + """ + return self._estimated_token_count + + def update_default_backward_pass_setup(self, setup: BackwardPassSetup): + self.backward_pass_setup = setup + + def set_cache_path(self, cache_path: str, model_client: object, model: str): + """Set the cache path for the generator.""" + + # Construct a valid model string using the client class name and model + self.model_str = f"{model_client.__class__.__name__}_{model}" + + # Remove any characters that are not allowed in file names (cross-platform) + # On Windows, characters like `:<>?/\|*` are prohibited. + self.model_str = re.sub(r"[^a-zA-Z0-9_\-]", "_", self.model_str) + + _cache_path = ( + get_adalflow_default_root_path() if cache_path is None else cache_path + ) + + # Use pathlib to handle paths more safely across OS + self.cache_path = Path(_cache_path) / f"cache_{self.model_str}.db" + + log.debug(f"Cache path set to: {self.cache_path}") + + def get_cache_path(self) -> str: + r"""Get the cache path for the generator.""" + return self.cache_path + + @staticmethod + def _get_default_mapping( + output: "GeneratorOutput" = None, + ) -> Tuple[Dict[str, Callable], List[str]]: + + if ( + output.data + and isinstance(output.data, DataClass) + and len(output.data.get_output_fields()) > 0 + ): + output_fields = output.data.get_output_fields() + + output_mapping = { + f: lambda x, f=f: getattr(x.data, f) for f in output_fields + } + elif output.raw_response: + output_fields = ["raw_response"] + output_mapping = {f: lambda x, f=f: getattr(x, f) for f in output_fields} + output_fields = ["Answer"] + output_mapping["Example"] = output_mapping["raw_response"] + del output_mapping["raw_response"] + else: + # When both output.data and output.raw_response are None + # (e.g. API call failed entirely), provide a safe fallback + # to prevent UnboundLocalError at the return statement. + output_fields = [] + output_mapping = {} + log.warning( + "Generator output has neither data nor raw_response. " + "Returning empty mapping. This usually indicates a failed API call." + ) + + return output_mapping, output_fields + + def set_mock_output( + self, mock_output: bool = True, mock_output_data: str = "mock data" + ): + self.mock_output = mock_output + self.mock_output_data = mock_output_data + + def reset_mock_output(self): + self.mock_output = False + self.mock_output_data = "mock data" + + def set_parameters(self, prompt_kwargs: PromptArgType): + r"""Set name for each paramter and set all context for each other. + Make all parameters attributes to the generator for finding them easily + for optimizers and other components. + """ + for key, p in prompt_kwargs.items(): + if isinstance(p, Parameter): + if not p.name or p.name == "": + p.name = key + peers = [ + p + for k, p in prompt_kwargs.items() + if isinstance(p, Parameter) and k != key + # and p.param_type == ParameterType.PROMPT + ] + p.set_peers(peers) + setattr(self, key, p) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Generator": + r"""Create a Generator instance from the config dictionary. + + Example: + + .. code-block:: python + + config = { + "model_client": { + "component_name": "OpenAIClient", + "component_config": {} + }, + "model_kwargs": {"model": "gpt-3.5-turbo", "temperature": 0} + } + generator = Generator.from_config(config) + """ + # create init_kwargs from the config + assert "model_client" in config, "model_client is required in the config" + return super().from_config(config) + + def _compose_model_kwargs(self, **model_kwargs) -> Dict: + r""" + The model configuration exclude the input itself. + Combine the default model, model_kwargs with the passed model_kwargs. + Example: + model_kwargs = {"temperature": 0.5, "model": "gpt-3.5-turbo"} + self.model_kwargs = {"model": "gpt-3.5-turbo"} + combine_kwargs(model_kwargs) => {"temperature": 0.5, "model": "gpt-3.5-turbo"} + + """ + combined_model_kwargs = self.model_kwargs.copy() + + if model_kwargs: + combined_model_kwargs.update(model_kwargs) + return combined_model_kwargs + + # TODO: use prompt_kwargs as users are already familiar with it + def print_prompt(self, **kwargs) -> str: + prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) + return prompt.print_prompt(**kwargs) + + def get_prompt(self, **kwargs) -> str: + prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) + return prompt.call(**kwargs) + + def _extra_repr(self) -> str: + prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) + s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}, prompt={prompt}" + return s + + def _post_call(self, completion: Any) -> GeneratorOutput: + r"""Get string completion and process it with the output_processors.""" + # parse chat completion will only fill the raw_response + output: GeneratorOutput = self.model_client.parse_chat_completion(completion) + # save the api response + output.api_response = completion + + # Now adding the data field to the output + data = output.raw_response + + # TODO implement support for synchronous iterator in the future + if self.output_processors: + if data: + try: + data = self.output_processors(data) + output.data = data + except Exception as e: + log.error(f"Error processing the output processors: {e}") + output.error = str(e) + else: + output.data = data + + return output + + async def _output_processing( + self, data: AsyncIterable, generator_output: GeneratorOutput + ) -> AsyncGenerator[Any, None]: + r"""Create an async generator from the async_iterable returned by model client and yield from them + apply output processors and store in generator output. + Consume the raw_response and yield each event. + Store the final output text in the generator output data field. + """ + + # the raw response of the generator output should be an async iterable + if not isinstance(data, AsyncIterable): + raise ValueError("The data is not an async iterable") + + final_output_text = "" + + # iterate over the events in the data and yield each event and store the final output text + # this assumes that the raw_response of the generator output stores a + # async iterable of events based on the OpenAI Responses API documentation for streaming + async for event in data: + log.debug(f"Raw event: {event!r}") + yield event + + if ResponseCompletedEvent and isinstance(event, ResponseCompletedEvent): + resp = event.response + log.debug(f"Response completed: {event.response.output_text}") + if getattr(resp, "output_text", None): + final_output_text = resp.output_text + + if self.output_processors: + if final_output_text: + try: + final_output = self.output_processors(final_output_text) + generator_output.data = final_output + except Exception as e: + log.error(f"Error processing the output processors: {e}") + generator_output.error = str(e) + else: + generator_output.data = None + + async def _async_post_call(self, completion: Any) -> GeneratorOutput: + r"""Get completion and depending on whether the client is streaming from the model client, create a GeneratorOutput or an AsyncGenerator. + when the client is not streaming, the post call return a GeneratorOutput where the final output after applying the output processors is stored under the data field. + When the client is streaming, the post call returns an Async Generator which will yield the events under the raw response and then also + the final Generator Output. The final Generator Output will have under data the final output after applying the output processors. + """ + # parse chat completion will only fill the raw_response + output: GeneratorOutput = self.model_client.parse_chat_completion(completion) + # save the api response + output.api_response = completion + + log.info( + f"Response from the Model Client Stream before being processed: {output.raw_response}" + ) + + # Now adding the data field to the output and setting as the raw response + data = output.raw_response + + # Handle async iterables from OpenAI Agent/Responses API streaming + # return an async generator + if isinstance(output.raw_response, AsyncIterable): + original_raw_response = ( + output.raw_response + ) # pass in the raw response to the output processing to avoid circular dependency + output.raw_response = self._output_processing(original_raw_response, output) + else: + # return a GeneratorOutput if the raw response is not an async iterable + # process the model client's final response with the output processors + log.info(f"Response from the Model Client before being processed: {data}") + if self.output_processors: + if data: + try: + data = self.output_processors(data) + output.data = data + except Exception as e: + log.error(f"Error processing the output processors: {e}") + output.error = str(e) + else: + output.data = data + log.info(f"Response from the Model Client after being processed: {data}") + + return output + + def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]: + r"""Prepare the input, prompt_kwargs, model_kwargs for the model call.""" + # 1. render the prompt from the template + prompt_str = self.get_prompt(**prompt_kwargs) + # prompt = Prompt(template=self.template, prompt_kwargs=self.prompt_kwargs) + + # prompt_str = prompt.call(**prompt_kwargs).strip() + + # 2. combine the model_kwargs with the default model_kwargs + composed_model_kwargs = self._compose_model_kwargs(**model_kwargs) + + max_tokens = composed_model_kwargs.get("max_tokens", None) + prompt_tokens = self._tokenizer.count_tokens(prompt_str) + self._estimated_token_count = prompt_tokens + use_prompt_str = prompt_str + if max_tokens is not None: + if prompt_tokens > max_tokens: + use_prompt_str = prompt_str[:max_tokens] + log.warning( + f"Prompt is too long: {prompt_tokens} tokens, max tokens: {max_tokens}. Truncated prompt to: {use_prompt_str}" + ) + # delete max_tokens from the model_kwargs + del composed_model_kwargs["max_tokens"] + + # 3. convert app's inputs to api inputs + api_kwargs = self.model_client.convert_inputs_to_api_kwargs( + # rename from input since input is a builtin object + input=use_prompt_str, + model_kwargs=composed_model_kwargs, + model_type=self.model_type, + ) + return api_kwargs + + def _model_client_call(self, api_kwargs: Dict, use_cache: bool = False) -> Any: + # call the model client + try: + # check the cache + index_content = json.dumps(api_kwargs) # + f"training: {self.training}" + if use_cache: + # print(f"check cache first: {no_cache}") + + cached_completion = self._check_cache(index_content) + if cached_completion is not None: + return cached_completion + + completion = self.model_client.call( + api_kwargs=api_kwargs, model_type=self.model_type + ) + + # prepare cache - skip caching for streaming responses which contain unpickleable threading objects + if use_cache and not api_kwargs.get("stream", False): + self._save_cache(index_content, completion) + return completion + except Exception as e: + log.error(f"Error calling the model: {e}") + raise e + + async def _async_model_client_call( + self, api_kwargs: Dict, use_cache: bool = False + ) -> Any: + # async call the model client with caching support + try: + # check the cache + index_content = json.dumps(api_kwargs) # + f"training: {self.training}" + if use_cache: + # Check cache first - cache operations are sync + cached_completion = self._check_cache(index_content) + if cached_completion is not None: + log.debug("Cache hit for async call") + return cached_completion + + completion = await self.model_client.acall( + api_kwargs=api_kwargs, model_type=self.model_type + ) + # save to cache - skip caching for streaming responses which contain unpickleable threading objects + if use_cache and not api_kwargs.get("stream", False): + self._save_cache(index_content, completion) + return completion + except Exception as e: + log.error(f"Error calling the model: {e}") + raise e + + ############################################################################################################## + ### Forward, backwards, teacher generator, create demo data instance, + # are for training and backpropagation + ############################################################################################################## + + def create_demo_data_instance( + self, + input_prompt_kwargs: Dict[str, Any], + output: GeneratorOutput, + id: Optional[str] = None, + ): + r"""Automatically create a demo data instance from the input and output of the generator. + Used to trace the demos for the demo paramter in the prompt_kwargs. + Part of the few-shot learning. + """ + from adalflow.core.base_data_class import DynamicDataClassFactory + + # map the input fields + demo_data = {"id": id, "score": None} # add score to trace the prediction score + demo_data_class_output_mapping, output_fields = self._get_default_mapping( + output + ) + + for k, v in input_prompt_kwargs.items(): + if isinstance(v, Parameter): + demo_data[k] = v.map_to_successor(self) + else: + demo_data[k] = v + # map the output fields + for key, value in demo_data_class_output_mapping.items(): + demo_data[key] = value(output) + + obj = DynamicDataClassFactory.from_dict(demo_data) + obj.set_input_fields([k for k in input_prompt_kwargs.keys()]) + obj.set_output_fields(output_fields) + if obj is None: + raise ValueError(f"Error creating the demo data instance:{demo_data}") + return obj + + def set_backward_engine(self, backward_engine: "BackwardEngine" = None): + if backward_engine is None: + backward_engine = BackwardEngine( + model_client=self.model_client, + model_kwargs=self.model_kwargs, + ) + if self.mock_output: + backward_engine.set_mock_output() + self.backward_engine = backward_engine + + def set_teacher_generator(self, teacher: "Generator" = None): + self._teacher = teacher + print(f"Teacher generator set: {self._teacher}, teacher {teacher}") + log.debug(f"Teacher generator set: {self._teacher}") + + # def set_data_map_func(self, map_func: Callable = None): + # def default_map_func(data: "GeneratorOutputType") -> str: + # return ( + # data.data + # if data.data + # else self.failure_message_to_backward_engine(data) + # ) + + # self.data_map_func = map_func or default_map_func + + # log.debug(f"Data map function set: {self.data_map_func}") + + # TODO: limit to only one demo parameter. + @staticmethod + def find_demo_parameter(prompt_kwargs: Dict) -> Optional[Parameter]: + from adalflow.optim.parameter import Parameter, ParameterType + + for p in prompt_kwargs.values(): + if isinstance(p, Parameter) and p.param_type == ParameterType.DEMOS: + return p + return None + + def forward( + self, + prompt_kwargs: Optional[ + Dict[str, Union[str, Parameter]] + ] = {}, # the input need to be passed to the prompt + model_kwargs: Optional[Dict] = {}, + id: Optional[str] = None, + ) -> "Parameter": + r"""Customized forward pass on top of the GradComponent forward method.""" + # 1. convert prompt_kwargs to parameter if it is not + for k, v in prompt_kwargs.items(): + if not isinstance(v, Parameter): + prompt_kwargs[k] = Parameter( + data=v, + name=f"{self.name}_{k}", + requires_opt=False, + param_type=ParameterType.INPUT, + data_id=id, + ) + + # 2. call the model + unwrapped_prompt_kwargs: Dict[str, Any] = {} + for k, v in prompt_kwargs.items(): + if isinstance(v, Parameter): + if v.param_type == ParameterType.INPUT: + v.data_id = id + unwrapped_prompt_kwargs[k] = v.map_to_successor(self) + else: + unwrapped_prompt_kwargs[k] = v + log.debug( + f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}" + ) + log.debug(f"prompt template: {self.template}") + + output: GeneratorOutputType = None + input_args = {} + if self.mock_output: + output = GeneratorOutput(data=self.mock_output_data) + else: + if self.teacher_mode and not isinstance(self, BackwardEngine): + if not self._teacher: + log.debug( + f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}" + ) + log.debug(f"names: {self.name}") + raise ValueError("Teacher generator is not set.") + log.info(f"Using teacher: {self._teacher}") + input_args = { + "prompt_kwargs": compose_model_kwargs( + self._teacher.prompt_kwargs, unwrapped_prompt_kwargs + ), + "model_kwargs": compose_model_kwargs( + self._teacher.model_kwargs, model_kwargs + ), + } + output = self._teacher.call(**input_args, id=id) + else: + input_args = { + "prompt_kwargs": compose_model_kwargs( + self.prompt_kwargs, unwrapped_prompt_kwargs + ), + "model_kwargs": compose_model_kwargs( + self.model_kwargs, model_kwargs + ), + } + + output = self.call(**input_args, id=id) + if not isinstance(output, GeneratorOutput): + raise ValueError( + f"Output should be of type GeneratorOutput, got {type(output)}" + ) + # 2. Generate a Parameter object from the output + combined_prompt_kwargs = compose_model_kwargs(self.prompt_kwargs, prompt_kwargs) + # if self.data_map_func is None: + # self.set_data_map_func() + + predecessors = [ + p for p in combined_prompt_kwargs.values() if isinstance(p, Parameter) + ] + + log.debug(f"Predecessors: {predecessors} for generator {self.name}") + + def data_to_prompt_map_fn(data: Parameter) -> str: + """GeneratorOutput will show the raw response instead of just the final data. + The backward engine and optimizer should look at all reasoning to decide the gradient. + """ + data: GeneratorOutput = data.data + if data.error is not None: + return f"Response: {data.raw_response} parsed with error: {data.error}" + return f" {data.raw_response}" + + # TODO: all parameter should just wrap the whole output. + # this is for training. + param_data = output + response: Parameter = OutputParameter( + data=param_data, + name=self.name + "_output", + role_desc=f"Output from (llm) {self.name}", + param_type=ParameterType.GENERATOR_OUTPUT, + data_id=id, + full_response=output, # the data structure + data_in_prompt=data_to_prompt_map_fn, + ) + response.set_predecessors(predecessors) + response.trace_forward_pass( + input_args=input_args, full_response=output, id=self.id, name=self.name + ) + # setattr(response, "full_response", output) + # *** special to the generator *** + response.trace_api_kwargs(api_kwargs=self._trace_api_kwargs) + # attach the demo to the demo parameter + # if self.tracing: + demo_param = self.find_demo_parameter(combined_prompt_kwargs) + + if demo_param: + if id is None: + raise ValueError( + "ID is required for tracing. Please pass it to your Geneartor call." + ) + + demo = self.create_demo_data_instance( + prompt_kwargs, + output, + id=id, + ) + demo_param.add_dataclass_to_trace(demo, is_teacher=self.teacher_mode) + else: + log.debug( + "No demo parameter found in the prompt_kwargs. You can not trace the demo data." + ) + + # **** end of the special to the generator **** + + # if not self.backward_engine: + # # self.set_backward_engine() + # log.debug(f"Backward engine: {self.backward_engine}") + + # attach a funtion to compute gradient for predecessors + log.debug(f"disable_backward_engine: {self._disable_backward_engine}") + + response.set_grad_fn( + BackwardContext( + backward_fn=self.backward, + backward_engine=self.backward_engine, + response=response, + prompt_kwargs=prompt_kwargs, + template=self.template, + prompt_str=self.get_prompt(**combined_prompt_kwargs), + disable_backward_engine=self._disable_backward_engine, + id=id, + ) + ) + return response + + def backward( + self, + response: Parameter, # the output of the forward pass + prompt_kwargs: Dict, + template: str, + prompt_str: str, + backward_engine: Optional["Generator"] = None, + id: Optional[str] = None, # the id of the input + disable_backward_engine: bool = False, + ) -> Parameter: + + log.info(f"Generator: Backward: {response.name}") + + backward_pass_setup = ( + backward_engine.backward_pass_setup if backward_engine else None + ) + log.debug( + f"backward pass setup: {backward_pass_setup}, name: {self.name}", + color="red", + ) + + children_params = response.predecessors + is_intermediate_node = True + if response.get_gradient_and_context_text().strip() == "": + log.info(f"Generator: Backward: No gradient found for {response}.") + + # backward score to the demo parameter + for pred in children_params: + # if pred.requires_opt: + if response.score is not None: + pred.set_score(response.score) + log.debug( + f"backpropagate the score {response.score} to {pred.name}, is_teacher: {self.teacher_mode}" + ) + if pred.param_type == ParameterType.DEMOS: + # Accumulate the score to the demo + pred.add_score_to_trace( + trace_id=id, score=response.score, is_teacher=self.teacher_mode + ) + log.debug(f"Pred: {pred.name}, traces: {pred._traces}") + + # 1.backward for text-gradients + if backward_engine: + log.debug( + f"Generator: Backward engine is set for the generator. {backward_engine}" + ) + # if response.backward_engine_disabled: + # for pred in children_params: + # pred.backward_engine_disabled = True + # return + + all_pred_at_once = backward_pass_setup.all_pred_at_once + + if not all_pred_at_once: + for pred in children_params: + if not pred.requires_opt or pred.param_type == ParameterType.DEMOS: + log.debug( + f"EvalFnToTextLoss: Skipping {pred} as it does not require optimization." + ) + continue + + self._backward_through_one_predecessor( + pred=pred, + response=response, + prompt_kwargs=prompt_kwargs, + # template=template, + backward_engine=backward_engine, + prompt_str=prompt_str, + backward_pass_setup=backward_pass_setup, + is_intermediate_node=is_intermediate_node, + disable_backward_engine=disable_backward_engine, + ) + else: + backward = False + for pred in children_params: + if pred.requires_opt and pred.param_type in [ + ParameterType.PROMPT, + ParameterType.GENERATOR_OUTPUT, + ParameterType.RETRIEVER_OUTPUT, + ParameterType.OUTPUT, + ]: + backward = True + break + if backward: + # 2nd approach, backward all that need opt at once. + self._backward_through_all_predecessors( + children_params=children_params, + response=response, + prompt_kwargs=prompt_kwargs, + template=template, + backward_engine=backward_engine, + prompt_str=prompt_str, + backward_pass_setup=backward_pass_setup, + is_intermediate_node=is_intermediate_node, + ) + else: + log.debug("Backward engine is not set for the generator. No text gradient.") + + @staticmethod + def _backward_through_all_predecessors( + children_params: List[Parameter], + response: Parameter, + prompt_kwargs: Dict[str, str], + backward_engine: "BackwardEngine", + backward_pass_setup: BackwardPassSetup, + is_intermediate_node: bool = False, + ): + parser = JsonParser() + # instruction and objective is the same for all the children + instruction_str, objective_str = None, None + + # 1. Generate the conversation input and output + input_prompt_kwargs = { + k: v.get_prompt_data() if isinstance(v, Parameter) else v + for k, v in prompt_kwargs.items() + } + + print(f"gt: {response.get_gt()}") + + # TODO: pass all the parameters and even the templates + conversation_prompt_kwargs = { + "input_value": input_prompt_kwargs, + "llm_output": response.get_prompt_data(), + } + + conversation_str = Prompt( + prompt_kwargs=conversation_prompt_kwargs, + template=LLM_CONVERSATION_TEMPLATE, + )() + + all_pred_info = Prompt( + prompt_kwargs={"variables": [p.get_param_info() for p in children_params]}, + template=ALL_PRED_INFO, + )() + + conv_ins_template = None # CONVERSATION_START_INSTRUCTION_BASE + obj_ins_template = OBJECTIVE_INSTRUCTION_BASE + if is_intermediate_node: # TODO: this will always be true + conv_ins_template = CONVERSATION_START_INSTRUCTION_CHAIN + obj_ins_template = OBJECTIVE_INSTRUCTION_CHAIN + response_gradient = response.get_gradients_str() + # response_gradient = response.get_gradients_component_schema() + # response_gradient = response.get_gradients_component_schema( + # skip_correct_sample=False + # ) + if not response_gradient: + raise ValueError( + f"Generator: No gradient found for {response}. Please check the response." + ) + + # replace variable and peers with all_pred_info + + instruction_str = Prompt( + template=conv_ins_template, + prompt_kwargs={ + "variable_and_peers_info": all_pred_info, + "conversation_str": conversation_str, + }, + )() + objective_str = Prompt( + template=obj_ins_template, + prompt_kwargs={ + "response_desc": response.role_desc, + "response_gradient": response_gradient, + "instruction_to_backward_engine": response.instruction_to_backward_engine, + }, + )() + + backward_engine_prompt_kwargs = { + "conversation_sec": instruction_str, + "objective_instruction_sec": objective_str, + "output_format_str": OUTPUT_INSTRUCTION, + } + + backward_engine_prompt_str = backward_engine.get_prompt( + **backward_engine_prompt_kwargs + ) + # print(f"Backward engine prompt: {backward_engine_prompt_str}") + + gradient_output: GeneratorOutput = None + response_gradient_list = [""] * len(children_params) + if ( + backward_pass_setup.compute_grad_for_errors_only + and response.score is not None + and float(response.score) + > backward_pass_setup.threshold_score_to_compute_grad_for_errors + ): + manual_response_1 = f"Eval score: {response.score}. No noticeable error." + response_gradient_list = [manual_response_1] * len(children_params) + raw_response = str(response_gradient_list) + gradient_output = GeneratorOutput( + data=response_gradient_list, raw_response=raw_response + ) + else: + + gradient_output: GeneratorOutput = backward_engine( + prompt_kwargs=backward_engine_prompt_kwargs + ) + if not isinstance(gradient_output, GeneratorOutput): + raise ValueError( + f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead." + ) + + # parse the list of gradients + + try: + response_gradient_list = parser.call(gradient_output.data) + except Exception as e: + log.error(f"Error parsing the response_gradient_list: {e}") + failure_message = backward_engine.failure_message_to_optimizer( + gradient_output + ) + if failure_message: + response_gradient_list = [failure_message] * len(children_params) + + log.debug(f"failure_message: {failure_message}") + + # computes gradient for each prompt predecessor + for i, pred in enumerate(children_params): + if not pred.requires_opt or pred.param_type == ParameterType.DEMOS: + log.debug( + f"Generator: Skipping {pred} as it does not require optimization." + ) + continue + + gradient_data = ( + response_gradient_list[i] + if response_gradient_list and len(response_gradient_list) > i + else "Failed to get the gradient." + ) + + var_gradient = Gradient( + data=gradient_data, + data_id=response.data_id, + score=response.score, # add score to gradient + from_response=response, + to_pred=pred, + ) + var_gradient.add_context( + GradientContext( + input_output=conversation_str, + response_desc=response.role_desc, + variable_desc=pred.role_desc, # the only difference for each pred + ) + ) + var_gradient.add_prompt(backward_engine_prompt_str) + pred.add_gradient(var_gradient) + if response.score is not None: + pred.set_score(response.score) + + @staticmethod + def _backward_through_one_predecessor( + pred: Parameter, + response: Parameter, + prompt_kwargs: Dict[str, str], + backward_engine: "BackwardEngine", + prompt_str: str, + backward_pass_setup: BackwardPassSetup, + is_intermediate_node: bool = False, + disable_backward_engine: bool = False, + ): + """Creating gradient/textual feedback for prompt type parameters.""" + if not pred.requires_opt: + if response.score is not None: + pred.set_score(response.score) + log.debug( + f"Generator: Skipping {pred} as it does not require optimization." + ) + return + + if pred.check_if_already_computed_gradient_respect_to(response.id): + log.debug( + f"Generator: Skipping {pred} as the gradient is already computed." + ) + + return + + if backward_engine is None: + log.error( + "EvalFnToTextLoss: backward_engine is required for text prompt optimization." + ) + raise ValueError( + "EvalFnToTextLoss: backward_engine is required for text prompt optimization." + ) + + instruction_str, objective_str = None, None + + # 1. Generate the conversation string + input_prompt_kwargs = { + k: v.get_prompt_data() if isinstance(v, Parameter) else v + for k, v in prompt_kwargs.items() + } + + conversation_prompt_kwargs = { + "input_value": input_prompt_kwargs, + "llm_output": response.get_prompt_data(), + "gt": response.get_gt(), + } + + conversation_str = Prompt( + prompt_kwargs=conversation_prompt_kwargs, + template=LLM_CONVERSATION_TEMPLATE, + )() + + variable_dict = pred.get_param_info() + + peers = [p.get_param_info() for p in pred.peers] + + variable_and_peers_info = Prompt( + prompt_kwargs={"variable": variable_dict, "peers": peers}, + template=VARIABLE_AND_PEERS_INFO, + )() + + # generator is almost always intermediate node + conv_ins_template = None # CONVERSATION_START_INSTRUCTION_BASE + obj_ins_template = OBJECTIVE_INSTRUCTION_BASE + if is_intermediate_node: # TODO: this will always be true + conv_ins_template = CONVERSATION_START_INSTRUCTION_CHAIN + obj_ins_template = OBJECTIVE_INSTRUCTION_CHAIN + response_gradient = response.get_gradients_str() + # response_gradient = response.get_gradients_component_schema() + if not response_gradient: + raise ValueError( + f"Generator: No gradient found for {response}. Please check the response. pred: {pred}" + ) + predecessors = [ + pred.get_param_info() + for pred in response.predecessors + if pred not in pred.peers + ] + instruction_str = Prompt( + template=conv_ins_template, + prompt_kwargs={ + "variable_and_peers_info": variable_and_peers_info, + "conversation_str": conversation_str, + "predecessors": predecessors, + }, + )() + log.info(f"Conversation start instruction base str: {instruction_str}") + objective_str = Prompt( + template=obj_ins_template, + prompt_kwargs={ + "response_desc": response.role_desc, + "response_gradient": response_gradient, + "instruction_to_backward_engine": pred.instruction_to_backward_engine, + }, + )() + + backward_engine_prompt_kwargs = { + "conversation_sec": instruction_str, + "objective_instruction_sec": objective_str, + } + backward_engine_prompt_str = backward_engine.get_prompt( + **backward_engine_prompt_kwargs + ) + # print(f"Backward engine prompt: {backward_engine_prompt_str}") + gradient_value = None + if not disable_backward_engine: + gradient_output: GeneratorOutput = None + if ( + backward_pass_setup.compute_grad_for_errors_only + and response.score is not None + and float(response.score) + > backward_pass_setup.threshold_score_to_compute_grad_for_errors + ): + log.debug( + f"EvalFnToTextLoss: Skipping {pred} as the score is high enough." + ) + # TODO: plus score descriptions + manual_response = f"Eval score: {response.score}. No noticeable error." + gradient_output = GeneratorOutput( + data=manual_response, raw_response=manual_response + ) + else: + + gradient_output: GeneratorOutput = backward_engine( + prompt_kwargs=backward_engine_prompt_kwargs + ) + prompt_str = backward_engine.get_prompt( # noqa F841 + **backward_engine_prompt_kwargs + ) + # printc(f"Backward engine prompt: {prompt_str}") + if not isinstance(gradient_output, GeneratorOutput): + raise ValueError( + f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead." + ) + # printc(f"Backward engine gradient: {gradient_output}") + + # USE this to trace each node's input and output, all nodes can be visualized + log.info( + f"Generator Backward Engine Prompt: {backward_engine.get_prompt( **backward_engine_prompt_kwargs)}" + ) + gradient_value = ( + gradient_output.data + or backward_engine.failure_message_to_optimizer(gradient_output) + ) + var_gradient = Gradient( + data=gradient_value, + data_id=response.data_id, + score=response.score, # add score to gradient + from_response=response, + to_pred=pred, + ) + # Component-level input and output. + var_gradient.add_context( + GradientContext( + input_output=conversation_str, + response_desc=response.role_desc, + variable_desc=pred.role_desc, # parameter_desc + ) + ) + var_gradient.add_prompt(backward_engine_prompt_str) + pred.add_gradient(var_gradient) + if response.score is not None: + pred.set_score(response.score) + + def _run_callbacks( + self, + output: GeneratorOutput, + input: Dict, + prompt_kwargs: Dict, + model_kwargs: Dict, + ): + self.trigger_callbacks( + "on_complete", + output=output, + input=input, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + if output.error: + self.trigger_callbacks( + "on_failure", + output=output, + input=input, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + else: + self.trigger_callbacks( + "on_success", + output=output, + input=input, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + + def call( + self, + prompt_kwargs: Optional[Dict] = {}, # supports both str and parameter value + model_kwargs: Optional[Dict] = {}, # can take images = {input: }, tools=[{"type": "image_generation"}], for image generation + use_cache: Optional[bool] = None, + id: Optional[str] = None, + ) -> GeneratorOutputType: + r""" + Call the model_client by formatting prompt from the prompt_kwargs, + and passing the combined model_kwargs to the model client. + """ + prompt_str = self.get_prompt(**prompt_kwargs) + with generator_span( + generator_id="generator" + (id if id else ""), + model_kwargs=self._compose_model_kwargs(**model_kwargs), + prompt_kwargs=prompt_kwargs, + prompt_template_with_keywords=prompt_str, + ) as generator_span_data: + generation_time = time.time() + if self.mock_output: + return GeneratorOutput(data=self.mock_output_data, id=id, input=prompt_str) + + log.debug(f"prompt_kwargs: {prompt_kwargs}") + log.debug(f"model_kwargs: {model_kwargs}") + + api_kwargs = self._pre_call(prompt_kwargs, model_kwargs) + + log.debug(f"api_kwargs: {api_kwargs}") + output: GeneratorOutputType = None + # call the model client + + completion = None + use_cache = use_cache if use_cache is not None else self._use_cache + try: + completion = self._model_client_call( + api_kwargs=api_kwargs, use_cache=use_cache + ) + except Exception as e: + log.error(f"Error calling the model: {e}") + output = GeneratorOutput(error=str(e), id=id, input=prompt_str) + # process the completion + if completion is not None: + try: + log.debug(f"Entering _post_call with completion: {completion}") + output = self._post_call(completion) + except Exception as e: + log.error(f"Error processing the output: {e}") + output = GeneratorOutput( + raw_response=str(completion), error=str(e), id=id, input=prompt_str + ) + + # User only need to use one of them, no need to use them all. + output.id = id + output.input = prompt_str + self._run_callbacks( + output, + input=api_kwargs, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + + log.info(f"output: {output}") + self._trace_api_kwargs = api_kwargs # tracing + + generator_span_data.span_data.update_attributes( + {"raw_response": output.raw_response} + ) + generator_span_data.span_data.update_attributes( + {"api_response": output.api_response} + ) + generator_span_data.span_data.update_attributes( + {"final_response": output.data} + ) + generator_span_data.span_data.update_attributes( + {"generation_time_in_seconds": time.time() - generation_time} + ) + generator_span_data.span_data.update_attributes( + {"token_usage": output.usage} + ) + generator_span_data.span_data.update_attributes({"api_kwargs": api_kwargs}) + + return output + + # TODO: training is not supported in async call yet + async def acall( + self, + prompt_kwargs: Optional[Dict] = {}, + model_kwargs: Optional[Dict] = {}, + use_cache: Optional[bool] = None, + id: Optional[str] = None, + ) -> Union[GeneratorOutputType, AsyncGenerator[GeneratorOutputType, None]]: + r"""Async call the model with the input and model_kwargs. + + :warning:: + Training is not supported in async call yet. + """ + prompt_str = self.get_prompt(**prompt_kwargs) + + with generator_span( + generator_id="generator" + (id if id else ""), + model_kwargs=self._compose_model_kwargs(**model_kwargs), + prompt_kwargs=prompt_kwargs, + prompt_template_with_keywords=prompt_str, + ) as generator_span_data: + + if self.mock_output: + generator_span_data.span_data.update_attributes( + {"final_response": self.mock_output_data} + ) + return GeneratorOutput(data=self.mock_output_data, id=id, input=prompt_str) + + generation_time = time.time() + log.info(f"prompt_kwargs: {prompt_kwargs}") + log.info(f"model_kwargs: {model_kwargs}") + + api_kwargs = self._pre_call(prompt_kwargs, model_kwargs) + output: GeneratorOutputType = None + # call the model client + completion = None + use_cache = use_cache if use_cache is not None else self._use_cache + + log.info(f"api_kwargs: {api_kwargs}") + + try: + completion = await self._async_model_client_call( + api_kwargs=api_kwargs, use_cache=use_cache + ) + except Exception as e: + log.error(f"Error calling the model: {e}") + output = GeneratorOutput(error=str(e), id=id, input=prompt_str) + + if completion is not None: + try: + # set ouput id in async post call instead + output = await self._async_post_call(completion) + except Exception as e: + log.error(f"Error processing the output: {e}") + output = GeneratorOutput( + raw_response=str(completion), error=str(e), id=id, input=prompt_str + ) + + # User only need to use one of them, no need to use them all. + output.id = id + output.input = prompt_str + log.info(f"output: {output}") + # if the output is not an async generator then set its id and run call backs + # TODO support when output is an Async Generator + if not isinstance(output, AsyncGenerator): + self._run_callbacks( + output, + input=api_kwargs, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + self._trace_api_kwargs = api_kwargs # tracing + + # Update generator span attributes similar to call method + generator_span_data.span_data.update_attributes( + {"raw_response": output.raw_response if output else None} + ) + generator_span_data.span_data.update_attributes( + {"api_response": output.api_response if output else None} + ) + generator_span_data.span_data.update_attributes( + {"final_response": output.data if output else None} + ) + generator_span_data.span_data.update_attributes( + {"generation_time_in_seconds": time.time() - generation_time} + ) + generator_span_data.span_data.update_attributes( + {"token_usage": output.usage if output else None} + ) + generator_span_data.span_data.update_attributes({"api_kwargs": api_kwargs}) + + return output + + def __call__(self, *args, **kwargs) -> Union[GeneratorOutputType, Any]: + if self.training: + log.debug("Training mode") + return self.forward(*args, **kwargs) + else: + log.debug("Inference mode") + return self.call(*args, **kwargs) + + def _extra_repr(self) -> str: + # Create the string for model_kwargs + s = f"model_kwargs={self.model_kwargs}, " + + # Create the string for trainable prompt_kwargs + prompt_kwargs_repr = [ + k + for k, v in self.prompt_kwargs.items() + if isinstance(v, Parameter) and v.requires_opt + ] + + s += f"trainable_prompt_kwargs={prompt_kwargs_repr}" + return s + + def to_dict(self) -> Dict[str, Any]: + r"""Convert the generator to a dictionary.""" + # TODO: exclude default functions + return super().to_dict() + + @staticmethod + def failure_message_to_backward_engine( + gradient_response: GeneratorOutput, + ) -> Optional[str]: + response_value = None + if gradient_response.error or not gradient_response.data: + response_value = f"Error: {gradient_response.error}, Raw response: {gradient_response.raw_response}" + return response_value + + +class BackwardEngine(Generator): # it is a generator with defaule template + + __doc__ = """A Generator with a default template for the backward pass in auto-differentiation. + + As a component, the forward pass is simply the same as the call method. + So it will always return GeneratorOutputType instead of Parameter. + + If you want to customize the template, you can create your own backward engine. + Yet, we will forever keep the training mode to False for the backward engine. + This is achieved by making forward the same as call. + """ + + def __init__(self, **kwargs): + if kwargs is None: + kwargs = {} + kwargs["template"] = FEEDBACK_ENGINE_TEMPLATE + + super().__init__(**kwargs) + self.name = "BackwardEngine" + self.teacher_mode = False + + def call(self, **kwargs) -> GeneratorOutputType: + r"""Catch the rate limit error and raise it.""" + output = super().call(**kwargs) + if output and output.error is not None and "429" in output.error: + raise ValueError(f"Error in the backward engine: {output.error}") + return output + + def forward(self, **kwargs): + r"""Forward pass for the backward engine.""" + return self.call(**kwargs) + + @staticmethod + def failure_message_to_optimizer( + gradient_response: GeneratorOutput, + ) -> Optional[str]: + gradient_value_data = None + if gradient_response.error or not gradient_response.data: + gradient_value_data = f"The backward engine failed to compute the gradient. Raw response: {gradient_response.raw_response}, Error: {gradient_response.error}" + + return gradient_value_data + + +def create_teacher_generator( + student: Generator, + model_client: ModelClient, + model_kwargs: Dict[str, Any], + template: Optional[str] = None, +) -> Generator: + r"""Create a teacher generator from the student generator. + + Note: + Teacher generator will have no parameters. + If you want to keep it to be the same as the student, just create one each time your student has been updated. + Or else, task.parameters will list teacher parameters. + + Args: + student (Generator): The student generator. + model_client (ModelClient): The model client to use for the teacher generator. + model_kwargs (Dict[str, Any]): The model kwargs to pass to the model client. + name (str, optional): The name of the teacher generator. Defaults to "teacher". + + Returns: + Generator: The teacher generator. + """ + kwargs = student._kwargs.copy() + kwargs["model_client"] = model_client + kwargs["model_kwargs"] = model_kwargs + if template: + kwargs["template"] = template + kwargs["name"] = f"{student.name}_teacher" + + prompt_kwargs_str: Dict[str, str] = {} + for key, p in kwargs["prompt_kwargs"].items(): + if isinstance(p, Parameter): + prompt_kwargs_str[key] = str(p.data) + else: + prompt_kwargs_str[key] = p + kwargs["prompt_kwargs"] = prompt_kwargs_str + teacher = Generator( + **kwargs, + ) + return teacher + + +if __name__ == "__main__": + # test the generator with backward engine + # TODO: move this to external local tests before packaging + from adalflow.components.model_client import ( + GroqAPIClient, + OpenAIClient, + GoogleGenAIClient, + AnthropicAPIClient, + ) + from adalflow.utils import setup_env + from adalflow.core.model_client import ModelClient + + setup_env() + # log = get_logger(level="DEBUG") + llama3_model = { + "model_client": GroqAPIClient(), + "model_kwargs": { + "model": "llama-3.1-8b-instant", + }, + } + gpt_3_model = { + "model_client": OpenAIClient(), + "model_kwargs": { + "model": "gpt-3.5-turbo", + }, + } + gemini_model = { + "model_client": GoogleGenAIClient(), + "model_kwargs": { + "model": "gemini-1.0-pro", + }, + } + claude_model = { + "model_client": AnthropicAPIClient(), + "model_kwargs": { + "model": "claude-3-opus-20240229", + "max_tokens": 100, + }, + } + from adalflow.tracing.generator_call_logger import GeneratorCallLogger + from functools import partial + + # setup the logger + call_logger = GeneratorCallLogger(save_dir="traces") + + def on_complete(output, input, prompt_kwargs, model_kwargs, logger_call: Callable): + logger_call( + output=output, + input=input, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + ) + + for model in [llama3_model, gpt_3_model, gemini_model, claude_model]: + generator = Generator(**model) + + teacher = create_teacher_generator(generator, **claude_model) + + call_logger.register_generator("generator", "generator_call") + # setup the callback + logger_call = partial(call_logger.log_call, name="generator") + generator.register_callback( + "on_complete", partial(on_complete, logger_call=logger_call) + ) + + output = generator( + prompt_kwargs={ + "input_str": "Hello, world!", + } + ) + break + + # test the backward engine + # TODO: test ollama and transformer client to update the change diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py index b4e397bde..d277c84b7 100644 --- a/adalflow/tests/test_generator.py +++ b/adalflow/tests/test_generator.py @@ -1,317 +1,363 @@ -from unittest import IsolatedAsyncioTestCase -from unittest.mock import patch, Mock -import unittest -import os -import shutil -from pathlib import Path - -from openai.types import CompletionUsage -from openai.types.chat import ChatCompletion - -from adalflow.core.types import GeneratorOutput -from adalflow.core.generator import Generator - - -from adalflow.core.model_client import ModelClient -from adalflow.components.model_client.groq_client import GroqAPIClient -from adalflow.tracing import GeneratorStateLogger - - -class TestGenerator(IsolatedAsyncioTestCase): - def setUp(self): - # Assuming that OpenAIClient is correctly mocked and passed to Generator - with patch( - "adalflow.core.model_client.ModelClient", spec=ModelClient - ) as MockAPI: - mock_api_client = Mock(ModelClient) - MockAPI.return_value = mock_api_client - mock_api_client.call.return_value = "Generated text response" - - mock_api_client.parse_chat_completion.return_value = ( - "Generated text response" - ) - self.mock_api_client = mock_api_client - - self.generator = Generator(model_client=mock_api_client) - self.save_dir = "./tests/log" - self.project_name = "TestGenerator" - self.filename = "prompt_logger_test.json" - - def _clean_up(self): - dir_path = os.path.join(self.save_dir, self.project_name) - - # Use shutil.rmtree to remove the directory recursively - shutil.rmtree( - dir_path, ignore_errors=True - ) # ignore_errors will prevent throwing an error if the directory doesn't exist - - def test_generator_call(self): - prompt_kwargs = {"input_str": "Hello, world!"} - model_kwargs = {"model": "gpt-3.5-turbo"} - - output = self.generator.call( - prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs - ) - self.assertIsInstance(output, GeneratorOutput) - print(f"output: {output}") - # Verify GeneratorOutput has expected attributes - self.assertTrue(hasattr(output, "data")) - self.assertTrue(hasattr(output, "raw_response")) - # self.assertEqual(output.data, "Generated text response") - - def test_cache_path(self): - prompt_kwargs = {"input_str": "Hello, world!"} - model_kwargs = {"model": "phi3.5:latest"} - - self.test_generator = Generator( - model_client=self.mock_api_client, - prompt_kwargs=prompt_kwargs, - model_kwargs=model_kwargs, - use_cache=True, - ) - - # Convert the path to a string to avoid the TypeError - cache_path = self.test_generator.get_cache_path() - cache_path_str = str(cache_path) - - print(f"cache path: {cache_path}") - - # Check if the sanitized model string is in the cache path - self.assertIn("phi3_5_latest", cache_path_str) - - # Check if the cache path exists as a file (or directory, depending on your use case) - - self.assertTrue( - Path(cache_path).exists(), f"Cache path {cache_path_str} does not exist" - ) - - def test_generator_prompt_logger_first_record(self): - # prompt_kwargs = {"input_str": "Hello, world!"} - # model_kwargs = {"model": "gpt-3.5-turbo"} - generator = Generator(model_client=self.mock_api_client) - prompt_logger = GeneratorStateLogger( - save_dir=self.save_dir, - project_name=self.project_name, - filename=self.filename, - ) - prompt_logger.log_prompt(generator=generator, name="Test Generator") - # Check if the prompt is logged - self.assertTrue("Test Generator" in prompt_logger._trace_map) - self._clean_up() - - def test_generator_prompt_update(self): - self._clean_up() - generator = Generator(model_client=self.mock_api_client) - prompt_logger = GeneratorStateLogger( - save_dir=self.save_dir, - project_name=self.project_name, - filename=self.filename, - ) - prompt_logger.log_prompt(generator=generator, name="Test Generator") - print(f"""prompt_logger._trace_map: {prompt_logger._trace_map}""") - self.assertTrue("Test Generator" in prompt_logger._trace_map) - - # Update the prompt variable and value - # preset_prompt_kwargs = {"input_str": "Hello, updated world!"} - # generator = Generator( - # model_client=self.mock_api_client, prompt_kwargs=preset_prompt_kwargs - # ) - - # prompt_logger.log_prompt(generator=generator, name="Test Generator") - - # print(f"""preset_prompt_kwargs: {prompt_logger._trace_map["Test Generator"]}""") - # self.assertEqual( - # prompt_logger._trace_map["Test Generator"][1].prompt_states[ - # "prompt_kwargs" - # ]["input_str"], - # "Hello, updated world!", - # ) - - # update the template - # template = "Hello, {{ input_str }}!" - # generator = Generator(model_client=self.mock_api_client, template=template) - # prompt_logger.log_prompt(generator=generator, name="Test Generator") - # self.assertEqual( - # prompt_logger._trace_map["Test Generator"][2].prompt_states["template"], - # "Hello, {{ input_str }}!", - # ) - self._clean_up() - - -def getenv_side_effect(key): - # This dictionary can hold more keys and values as needed - env_vars = {"GROQ_API_KEY": "fake_api_key"} - return env_vars.get(key, None) # Returns None if key is not found - - -class TestGeneratorWithGroqClient(unittest.TestCase): - # @patch("os.getenv", side_effect=getenv_side_effect) - def setUp(self) -> None: - with patch( - "os.getenv", side_effect=getenv_side_effect - ): # Mock the environment variable - self.client = GroqAPIClient() - self.mock_response = { - "id": "cmpl-3Q8Z5J9Z1Z5z5", - "created": 1635820005, - "object": "chat.completion", - "model": "gpt-3.5-turbo", - "choices": [ - { - "message": { - "content": "Hello, world!", - "role": "assistant", - }, - "index": 0, - "finish_reason": "stop", - } - ], - "usage": CompletionUsage( - completion_tokens=10, prompt_tokens=20, total_tokens=30 - ), - } - self.mock_response = ChatCompletion(**self.mock_response) - - @patch.object(GroqAPIClient, "call") - def test_groq_client_call(self, mock_call): - # Mock the response - - mock_call.return_value = self.mock_response - - # Define prompt and model kwargs - prompt_kwargs = {"input_str": "Hello, world!"} - model_kwargs = {"model": "gpt-3.5-turbo"} - template = "Hello, {{ input_str }}!" - - # Initialize the Generator with the mocked client - generator = Generator(model_client=self.client, template=template) - - # Call the generator and get the output - output = generator.call(prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs) - - self.assertIsInstance(output, GeneratorOutput) - print(f"output groq: {output}") - # Verify GeneratorOutput structure for Groq client - self.assertTrue(hasattr(output, "data")) - self.assertTrue(hasattr(output, "raw_response")) - # self.assertEqual(output.data, "Generated text response") - - -class TestGeneratorIntegration(unittest.TestCase): - """Test Generator integration with Agent and Runner workflows.""" - - def setUp(self): - # Mock ModelClient for integration tests - with patch( - "adalflow.core.model_client.ModelClient", spec=ModelClient - ) as MockAPI: - mock_api_client = Mock(ModelClient) - MockAPI.return_value = mock_api_client - mock_api_client.call.return_value = "Integration test response" - mock_api_client.parse_chat_completion.return_value = ( - "Integration test response" - ) - self.mock_api_client = mock_api_client - - def test_generator_output_for_agent_planner(self): - """Test that Generator produces output suitable for Agent planner use.""" - from adalflow.components.output_parsers import JsonOutputParser - from adalflow.core.types import Function - - # Create a generator with Function output parser (like Agent planner) - output_parser = JsonOutputParser( - data_class=Function, - return_data_class=True, - include_fields=["thought", "name", "kwargs"], - ) - - generator = Generator( - model_client=self.mock_api_client, output_processors=output_parser - ) - - # Mock the model client to return a JSON-like response - self.mock_api_client.call.return_value = '{"thought": "I need to search", "name": "search", "kwargs": {"query": "test"}}' - - output = generator.call(prompt_kwargs={"input_str": "test query"}) - - # Verify output is GeneratorOutput - self.assertIsInstance(output, GeneratorOutput) - # Verify it can be used by Agent/Runner workflow - self.assertTrue(hasattr(output, "data")) - - def test_generator_template_integration(self): - """Test Generator with custom template like Agent uses.""" - template = ( - "System: You are a helpful assistant.\nUser: {{input_str}}\nAssistant:" - ) - - generator = Generator(model_client=self.mock_api_client, template=template) - - # Test that generator accepts template and can generate prompt - prompt = generator.get_prompt(input_str="Hello world") - self.assertIn("Hello world", prompt) - self.assertIn("System: You are a helpful assistant", prompt) - - # Test generation works with template - output = generator.call(prompt_kwargs={"input_str": "Hello world"}) - self.assertIsInstance(output, GeneratorOutput) - - def test_generator_async_capability(self): - """Test Generator async methods that Runner.acall uses.""" - - async def async_test(): - # Mock async call - async def async_mock_call(*args, **kwargs): - return "Async response" - - self.mock_api_client.acall = async_mock_call - - generator = Generator(model_client=self.mock_api_client) - - # Test async call - output = await generator.acall(prompt_kwargs={"input_str": "async test"}) - self.assertIsInstance(output, GeneratorOutput) - - import asyncio - - asyncio.run(async_test()) - - def test_generator_training_mode(self): - """Test Generator training mode that Agent.is_training() uses.""" - generator = Generator(model_client=self.mock_api_client) - - # Initially not in training mode - self.assertFalse(generator.training) - - # Set to training mode - generator.training = True - self.assertTrue(generator.training) - - # Can switch back - generator.training = False - self.assertFalse(generator.training) - - def test_generator_prompt_kwargs_persistence(self): - """Test Generator maintains prompt_kwargs like Agent planner needs.""" - initial_prompt_kwargs = { - "tools": "[tool1, tool2]", - "output_format_str": "JSON format", - "task_desc": "Agent task", - "max_steps": 10, - "step_history": [], - } - - generator = Generator( - model_client=self.mock_api_client, prompt_kwargs=initial_prompt_kwargs - ) - - # Verify prompt_kwargs are stored - self.assertEqual(generator.prompt_kwargs, initial_prompt_kwargs) - - # Test that additional kwargs can be passed to call - output = generator.call(prompt_kwargs={"input_str": "test", "current_step": 1}) - self.assertIsInstance(output, GeneratorOutput) - - -if __name__ == "__main__": - unittest.main() +from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch, Mock +import unittest +import os +import shutil +from pathlib import Path + +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion + +from adalflow.core.types import GeneratorOutput +from adalflow.core.generator import Generator + + +from adalflow.core.model_client import ModelClient +from adalflow.components.model_client.groq_client import GroqAPIClient +from adalflow.tracing import GeneratorStateLogger + + +class TestGenerator(IsolatedAsyncioTestCase): + def setUp(self): + # Assuming that OpenAIClient is correctly mocked and passed to Generator + with patch( + "adalflow.core.model_client.ModelClient", spec=ModelClient + ) as MockAPI: + mock_api_client = Mock(ModelClient) + MockAPI.return_value = mock_api_client + mock_api_client.call.return_value = "Generated text response" + + mock_api_client.parse_chat_completion.return_value = ( + "Generated text response" + ) + self.mock_api_client = mock_api_client + + self.generator = Generator(model_client=mock_api_client) + self.save_dir = "./tests/log" + self.project_name = "TestGenerator" + self.filename = "prompt_logger_test.json" + + def _clean_up(self): + dir_path = os.path.join(self.save_dir, self.project_name) + + # Use shutil.rmtree to remove the directory recursively + shutil.rmtree( + dir_path, ignore_errors=True + ) # ignore_errors will prevent throwing an error if the directory doesn't exist + + def test_generator_call(self): + prompt_kwargs = {"input_str": "Hello, world!"} + model_kwargs = {"model": "gpt-3.5-turbo"} + + output = self.generator.call( + prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs + ) + self.assertIsInstance(output, GeneratorOutput) + print(f"output: {output}") + # Verify GeneratorOutput has expected attributes + self.assertTrue(hasattr(output, "data")) + self.assertTrue(hasattr(output, "raw_response")) + # self.assertEqual(output.data, "Generated text response") + + def test_cache_path(self): + prompt_kwargs = {"input_str": "Hello, world!"} + model_kwargs = {"model": "phi3.5:latest"} + + self.test_generator = Generator( + model_client=self.mock_api_client, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + use_cache=True, + ) + + # Convert the path to a string to avoid the TypeError + cache_path = self.test_generator.get_cache_path() + cache_path_str = str(cache_path) + + print(f"cache path: {cache_path}") + + # Check if the sanitized model string is in the cache path + self.assertIn("phi3_5_latest", cache_path_str) + + # Check if the cache path exists as a file (or directory, depending on your use case) + + self.assertTrue( + Path(cache_path).exists(), f"Cache path {cache_path_str} does not exist" + ) + + def test_generator_prompt_logger_first_record(self): + # prompt_kwargs = {"input_str": "Hello, world!"} + # model_kwargs = {"model": "gpt-3.5-turbo"} + generator = Generator(model_client=self.mock_api_client) + prompt_logger = GeneratorStateLogger( + save_dir=self.save_dir, + project_name=self.project_name, + filename=self.filename, + ) + prompt_logger.log_prompt(generator=generator, name="Test Generator") + # Check if the prompt is logged + self.assertTrue("Test Generator" in prompt_logger._trace_map) + self._clean_up() + + def test_generator_prompt_update(self): + self._clean_up() + generator = Generator(model_client=self.mock_api_client) + prompt_logger = GeneratorStateLogger( + save_dir=self.save_dir, + project_name=self.project_name, + filename=self.filename, + ) + prompt_logger.log_prompt(generator=generator, name="Test Generator") + print(f"""prompt_logger._trace_map: {prompt_logger._trace_map}""") + self.assertTrue("Test Generator" in prompt_logger._trace_map) + + # Update the prompt variable and value + # preset_prompt_kwargs = {"input_str": "Hello, updated world!"} + # generator = Generator( + # model_client=self.mock_api_client, prompt_kwargs=preset_prompt_kwargs + # ) + + # prompt_logger.log_prompt(generator=generator, name="Test Generator") + + # print(f"""preset_prompt_kwargs: {prompt_logger._trace_map["Test Generator"]}""") + # self.assertEqual( + # prompt_logger._trace_map["Test Generator"][1].prompt_states[ + # "prompt_kwargs" + # ]["input_str"], + # "Hello, updated world!", + # ) + + # update the template + # template = "Hello, {{ input_str }}!" + # generator = Generator(model_client=self.mock_api_client, template=template) + # prompt_logger.log_prompt(generator=generator, name="Test Generator") + # self.assertEqual( + # prompt_logger._trace_map["Test Generator"][2].prompt_states["template"], + # "Hello, {{ input_str }}!", + # ) + self._clean_up() + + +def getenv_side_effect(key): + # This dictionary can hold more keys and values as needed + env_vars = {"GROQ_API_KEY": "fake_api_key"} + return env_vars.get(key, None) # Returns None if key is not found + + +class TestGeneratorWithGroqClient(unittest.TestCase): + # @patch("os.getenv", side_effect=getenv_side_effect) + def setUp(self) -> None: + with patch( + "os.getenv", side_effect=getenv_side_effect + ): # Mock the environment variable + self.client = GroqAPIClient() + self.mock_response = { + "id": "cmpl-3Q8Z5J9Z1Z5z5", + "created": 1635820005, + "object": "chat.completion", + "model": "gpt-3.5-turbo", + "choices": [ + { + "message": { + "content": "Hello, world!", + "role": "assistant", + }, + "index": 0, + "finish_reason": "stop", + } + ], + "usage": CompletionUsage( + completion_tokens=10, prompt_tokens=20, total_tokens=30 + ), + } + self.mock_response = ChatCompletion(**self.mock_response) + + @patch.object(GroqAPIClient, "call") + def test_groq_client_call(self, mock_call): + # Mock the response + + mock_call.return_value = self.mock_response + + # Define prompt and model kwargs + prompt_kwargs = {"input_str": "Hello, world!"} + model_kwargs = {"model": "gpt-3.5-turbo"} + template = "Hello, {{ input_str }}!" + + # Initialize the Generator with the mocked client + generator = Generator(model_client=self.client, template=template) + + # Call the generator and get the output + output = generator.call(prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs) + + self.assertIsInstance(output, GeneratorOutput) + print(f"output groq: {output}") + # Verify GeneratorOutput structure for Groq client + self.assertTrue(hasattr(output, "data")) + self.assertTrue(hasattr(output, "raw_response")) + # self.assertEqual(output.data, "Generated text response") + + +class TestGeneratorIntegration(unittest.TestCase): + """Test Generator integration with Agent and Runner workflows.""" + + def setUp(self): + # Mock ModelClient for integration tests + with patch( + "adalflow.core.model_client.ModelClient", spec=ModelClient + ) as MockAPI: + mock_api_client = Mock(ModelClient) + MockAPI.return_value = mock_api_client + mock_api_client.call.return_value = "Integration test response" + mock_api_client.parse_chat_completion.return_value = ( + "Integration test response" + ) + self.mock_api_client = mock_api_client + + def test_generator_output_for_agent_planner(self): + """Test that Generator produces output suitable for Agent planner use.""" + from adalflow.components.output_parsers import JsonOutputParser + from adalflow.core.types import Function + + # Create a generator with Function output parser (like Agent planner) + output_parser = JsonOutputParser( + data_class=Function, + return_data_class=True, + include_fields=["thought", "name", "kwargs"], + ) + + generator = Generator( + model_client=self.mock_api_client, output_processors=output_parser + ) + + # Mock the model client to return a JSON-like response + self.mock_api_client.call.return_value = '{"thought": "I need to search", "name": "search", "kwargs": {"query": "test"}}' + + output = generator.call(prompt_kwargs={"input_str": "test query"}) + + # Verify output is GeneratorOutput + self.assertIsInstance(output, GeneratorOutput) + # Verify it can be used by Agent/Runner workflow + self.assertTrue(hasattr(output, "data")) + + def test_generator_template_integration(self): + """Test Generator with custom template like Agent uses.""" + template = ( + "System: You are a helpful assistant.\nUser: {{input_str}}\nAssistant:" + ) + + generator = Generator(model_client=self.mock_api_client, template=template) + + # Test that generator accepts template and can generate prompt + prompt = generator.get_prompt(input_str="Hello world") + self.assertIn("Hello world", prompt) + self.assertIn("System: You are a helpful assistant", prompt) + + # Test generation works with template + output = generator.call(prompt_kwargs={"input_str": "Hello world"}) + self.assertIsInstance(output, GeneratorOutput) + + def test_generator_async_capability(self): + """Test Generator async methods that Runner.acall uses.""" + + async def async_test(): + # Mock async call + async def async_mock_call(*args, **kwargs): + return "Async response" + + self.mock_api_client.acall = async_mock_call + + generator = Generator(model_client=self.mock_api_client) + + # Test async call + output = await generator.acall(prompt_kwargs={"input_str": "async test"}) + self.assertIsInstance(output, GeneratorOutput) + + import asyncio + + asyncio.run(async_test()) + + def test_generator_training_mode(self): + """Test Generator training mode that Agent.is_training() uses.""" + generator = Generator(model_client=self.mock_api_client) + + # Initially not in training mode + self.assertFalse(generator.training) + + # Set to training mode + generator.training = True + self.assertTrue(generator.training) + + # Can switch back + generator.training = False + self.assertFalse(generator.training) + + def test_generator_prompt_kwargs_persistence(self): + """Test Generator maintains prompt_kwargs like Agent planner needs.""" + initial_prompt_kwargs = { + "tools": "[tool1, tool2]", + "output_format_str": "JSON format", + "task_desc": "Agent task", + "max_steps": 10, + "step_history": [], + } + + generator = Generator( + model_client=self.mock_api_client, prompt_kwargs=initial_prompt_kwargs + ) + + # Verify prompt_kwargs are stored + self.assertEqual(generator.prompt_kwargs, initial_prompt_kwargs) + + # Test that additional kwargs can be passed to call + output = generator.call(prompt_kwargs={"input_str": "test", "current_step": 1}) + self.assertIsInstance(output, GeneratorOutput) + + +class TestGetDefaultMapping(unittest.TestCase): + """Test Generator._get_default_mapping handles all edge cases.""" + + def test_data_with_output_fields(self): + """When output.data is a DataClass with output fields, mapping should use them.""" + from adalflow.core.base_data_class import DataClass + from adalflow.core.types import GeneratorOutput + + class SampleOutput(DataClass): + answer: str = "" + score: float = 0.0 + + SampleOutput.set_output_fields(["answer", "score"]) + data = SampleOutput() + data.answer = "test" + data.score = 0.9 + output = GeneratorOutput(data=data) + mapping, fields = Generator._get_default_mapping(output) + self.assertIn("answer", fields) + self.assertIn("score", fields) + self.assertIn("answer", mapping) + self.assertIn("score", mapping) + + def test_raw_response_only(self): + """When only raw_response is present, mapping should map 'Example' to raw_response.""" + output = GeneratorOutput(raw_response="some text") + mapping, fields = Generator._get_default_mapping(output) + self.assertEqual(fields, ["Answer"]) + self.assertIn("Example", mapping) + + def test_both_data_and_raw_response_none(self): + """When both data and raw_response are None (API call failed), + should return empty mapping instead of raising UnboundLocalError.""" + output = GeneratorOutput(data=None, raw_response=None, error="API call failed") + mapping, fields = Generator._get_default_mapping(output) + self.assertEqual(fields, []) + self.assertEqual(mapping, {}) + + def test_data_none_raw_response_empty_string(self): + """When raw_response is an empty string (falsy), should fall through to the else branch.""" + output = GeneratorOutput(data=None, raw_response="") + mapping, fields = Generator._get_default_mapping(output) + self.assertEqual(fields, []) + self.assertEqual(mapping, {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index e0b50a394..b36942c60 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -1,766 +1,811 @@ -import unittest -from unittest.mock import patch, AsyncMock, Mock, MagicMock -import os -import base64 - -from openai.types import Image -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseTextDeltaEvent, - ResponseUsage, -) - -from adalflow.core.types import ModelType, GeneratorOutput -from adalflow.components.model_client.openai_client import OpenAIClient -from adalflow.components.model_client.utils import extract_text_from_response_stream -import asyncio - - -def getenv_side_effect(key): - # This dictionary can hold more keys and values as needed - env_vars = {"OPENAI_API_KEY": "fake_api_key"} - return env_vars.get(key, None) # Returns None if key is not found - - -class TestOpenAIClient(unittest.IsolatedAsyncioTestCase): - def setUp(self): - self.client = OpenAIClient(api_key="fake_api_key") - # Create a mock Response object with the required fields - self.mock_response = Mock(spec=Response) - self.mock_response.id = "resp-3Q8Z5J9Z1Z5z5" - self.mock_response.created_at = 1635820005.0 - self.mock_response.model = "gpt-4o" - self.mock_response.object = "response" - self.mock_response.output_text = "Hello, world!" - self.mock_response.usage = ResponseUsage( - input_tokens=20, - output_tokens=10, - total_tokens=30, - input_tokens_details={"cached_tokens": 0}, - output_tokens_details={"reasoning_tokens": 0}, - ) - self.mock_vision_response = Mock(spec=Response) - self.mock_vision_response.id = "resp-4Q8Z5J9Z1Z5z5" - self.mock_vision_response.created_at = 1635820005.0 - self.mock_vision_response.model = "gpt-4o" - self.mock_vision_response.object = "response" - self.mock_vision_response.output_text = ( - "The image shows a beautiful sunset over mountains." - ) - self.mock_vision_response.usage = ResponseUsage( - input_tokens=25, - output_tokens=15, - total_tokens=40, - input_tokens_details={"cached_tokens": 0}, - output_tokens_details={"reasoning_tokens": 0}, - ) - self.mock_image_response = [ - Image( - url="https://example.com/generated_image.jpg", - b64_json=None, - revised_prompt="A white siamese cat sitting elegantly", - model="dall-e-3", - ) - ] - self.api_kwargs = { - "input": "Hello", - "model": "gpt-4o", - } - self.vision_api_kwargs = { - "input": "Describe this image: https://example.com/image.jpg", - "model": "gpt-4o", - } - self.image_generation_kwargs = { - "model": "dall-e-3", - "prompt": "a white siamese cat", - "size": "1024x1024", - "quality": "standard", - "n": 1, - } - - # Add streaming test data for response API using Mock objects - mock_delta_event1 = Mock(spec=ResponseTextDeltaEvent) - mock_delta_event1.type = "response.output_text.delta" - mock_delta_event1.delta = "Once " - - mock_delta_event2 = Mock(spec=ResponseTextDeltaEvent) - mock_delta_event2.type = "response.output_text.delta" - mock_delta_event2.delta = "upon " - - mock_response_obj = Mock(spec=Response) - mock_response_obj.id = "resp-123" - mock_response_obj.created_at = 1635820005.0 - mock_response_obj.model = "gpt-4" - mock_response_obj.object = "response" - mock_response_obj.output_text = "Once upon " - mock_response_obj.usage = ResponseUsage( - input_tokens=10, - output_tokens=2, - total_tokens=12, - input_tokens_details={"cached_tokens": 0}, - output_tokens_details={"reasoning_tokens": 0}, - ) - - mock_completed_event = Mock(spec=ResponseCompletedEvent) - mock_completed_event.type = "response.completed" - mock_completed_event.response = mock_response_obj - - self.streaming_events = [ - mock_delta_event1, - mock_delta_event2, - mock_completed_event, - ] - - def test_encode_image(self): - # Create a temporary test image file - test_image_path = "test_image.jpg" - test_content = b"fake image content" - try: - with open(test_image_path, "wb") as f: - f.write(test_content) - - # Test successful encoding - encoded = self.client._encode_image(test_image_path) - self.assertEqual(encoded, base64.b64encode(test_content).decode("utf-8")) - - # Test file not found - with self.assertRaises(ValueError) as context: - self.client._encode_image("nonexistent.jpg") - self.assertIn("Image file not found", str(context.exception)) - - finally: - # Cleanup - if os.path.exists(test_image_path): - os.remove(test_image_path) - - def test_prepare_image_content(self): - # Test URL image - url = "https://example.com/image.jpg" - result = self.client._prepare_image_content(url) - self.assertEqual( - result, - {"type": "image_url", "image_url": {"url": url, "detail": "auto"}}, - ) - - # Test with custom detail level - result = self.client._prepare_image_content(url, detail="high") - self.assertEqual( - result, - {"type": "image_url", "image_url": {"url": url, "detail": "high"}}, - ) - - # Test with pre-formatted content - pre_formatted = { - "type": "image_url", - "image_url": {"url": url, "detail": "low"}, - } - result = self.client._prepare_image_content(pre_formatted) - self.assertEqual(result, pre_formatted) - - def test_convert_inputs_to_api_kwargs_with_images(self): - # Test with single image URL - Response API uses message format for multimodal - model_kwargs = { - "model": "gpt-4o", - "images": "https://example.com/image.jpg", - } - result = self.client.convert_inputs_to_api_kwargs( - input="Describe this image", - model_kwargs=model_kwargs, - model_type=ModelType.LLM, - ) - print(result) - # Response API expects message format with content array for multimodal - expected_input = [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "Describe this image"}, - {"type": "input_image", "image_url": "https://example.com/image.jpg"} - ] - } - ] - self.assertEqual(result["input"], expected_input) - self.assertEqual(result["model"], "gpt-4o") - - # Test with multiple images - Response API uses message format for multimodal - model_kwargs = { - "model": "gpt-4o", - "images": [ - "https://example.com/image1.jpg", - "https://example.com/image2.jpg", - ], - "detail": "high", - } - result = self.client.convert_inputs_to_api_kwargs( - input="Compare these images", - model_kwargs=model_kwargs, - model_type=ModelType.LLM, - ) - # Response API expects message format with content array for multimodal - expected_input = [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "Compare these images"}, - {"type": "input_image", "image_url": "https://example.com/image1.jpg"}, - {"type": "input_image", "image_url": "https://example.com/image2.jpg"} - ] - } - ] - self.assertEqual(result["input"], expected_input) - self.assertEqual(result["model"], "gpt-4o") - - @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") - async def test_acall_llm(self, MockAsyncOpenAI): - mock_async_client = AsyncMock() - MockAsyncOpenAI.return_value = mock_async_client - - # Mock the response - - mock_async_client.responses.create = AsyncMock(return_value=self.mock_response) - - # Call the _acall method - - result = await self.client.acall( - api_kwargs=self.api_kwargs, model_type=ModelType.LLM - ) - - # Assertions - MockAsyncOpenAI.assert_called_once() - mock_async_client.responses.create.assert_awaited_once_with(**self.api_kwargs) - self.assertEqual(result, self.mock_response) - - @patch( - "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" - ) - @patch("adalflow.components.model_client.openai_client.OpenAI") - def test_call(self, MockSyncOpenAI, mock_init_sync_client): - mock_sync_client = Mock() - MockSyncOpenAI.return_value = mock_sync_client - mock_init_sync_client.return_value = mock_sync_client - - # Mock the client's api: responses.create - mock_sync_client.responses.create = Mock(return_value=self.mock_response) - - # Set the sync client - self.client.sync_client = mock_sync_client - - # Call the call method - result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) - - # Assertions - mock_sync_client.responses.create.assert_called_once_with(**self.api_kwargs) - self.assertEqual(result, self.mock_response) - - # test parse_response - output = self.client.parse_chat_completion(completion=self.mock_response) - self.assertTrue(isinstance(output, GeneratorOutput)) - self.assertEqual(output.raw_response, "Hello, world!") - self.assertEqual(output.usage.output_tokens, 10) - self.assertEqual(output.usage.input_tokens, 20) - self.assertEqual(output.usage.total_tokens, 30) - - @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") - async def test_acall_llm_with_vision(self, MockAsyncOpenAI): - mock_async_client = AsyncMock() - MockAsyncOpenAI.return_value = mock_async_client - - # Mock the vision model response - mock_async_client.responses.create = AsyncMock( - return_value=self.mock_vision_response - ) - - # Call the _acall method with vision model - result = await self.client.acall( - api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM - ) - - # Assertions - MockAsyncOpenAI.assert_called_once() - mock_async_client.responses.create.assert_awaited_once_with( - **self.vision_api_kwargs - ) - self.assertEqual(result, self.mock_vision_response) - - @patch( - "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" - ) - @patch("adalflow.components.model_client.openai_client.OpenAI") - def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): - mock_sync_client = Mock() - MockSyncOpenAI.return_value = mock_sync_client - mock_init_sync_client.return_value = mock_sync_client - - # Mock the vision model response - mock_sync_client.responses.create = Mock(return_value=self.mock_vision_response) - - # Set the sync client - self.client.sync_client = mock_sync_client - - # Call the call method with vision model - result = self.client.call( - api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM - ) - - # Assertions - mock_sync_client.responses.create.assert_called_once_with( - **self.vision_api_kwargs - ) - self.assertEqual(result, self.mock_vision_response) - - # Test parse_response for vision model - output = self.client.parse_chat_completion(completion=self.mock_vision_response) - self.assertTrue(isinstance(output, GeneratorOutput)) - self.assertEqual( - output.raw_response, "The image shows a beautiful sunset over mountains." - ) - self.assertEqual(output.usage.output_tokens, 15) - self.assertEqual(output.usage.input_tokens, 25) - self.assertEqual(output.usage.total_tokens, 40) - - def test_from_dict_to_dict(self): - test_api_key = "fake_api" - client = OpenAIClient(api_key=test_api_key) - client_dict = client.to_dict() - new_client = OpenAIClient.from_dict(client_dict) - self.assertEqual(new_client.to_dict(), client_dict) - - @patch("adalflow.components.model_client.openai_client.OpenAI") - def test_init_sync_client_with_headers_and_organization(self, MockOpenAI): - headers = {"Custom-Header": "CustomValue"} - organization = "test-organization" - - # First call happens during __init__ - client = OpenAIClient( - api_key="fake_api_key", - headers=headers, - organization=organization, - ) - - # Clear previous calls so we only test the explicit one below - MockOpenAI.reset_mock() - - # Now call init_sync_client explicitly to trigger the OpenAI call - _ = client.init_sync_client() - - # Assert OpenAI was called with correct parameters - MockOpenAI.assert_called_once_with( - api_key="fake_api_key", - base_url="https://api.openai.com/v1/", - organization=organization, - default_headers=headers, - ) - - @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") - async def test_init_async_client_with_headers_and_organization( - self, MockAsyncOpenAI - ): - headers = {"Custom-Header": "CustomValue"} - organization = "test-organization" - - # Manually assign an AsyncMock to the return value - mock_async_client = AsyncMock() - MockAsyncOpenAI.return_value = mock_async_client - - client = OpenAIClient( - api_key="fake_api_key", - headers=headers, - organization=organization, - ) - - async_client = client.init_async_client() # Do NOT await here - - MockAsyncOpenAI.assert_called_once_with( - api_key="fake_api_key", - base_url="https://api.openai.com/v1/", - organization=organization, - default_headers=headers, - ) - self.assertEqual(async_client, mock_async_client) - - @patch("adalflow.components.model_client.openai_client.OpenAI") - def test_call_with_custom_headers_and_organization(self, MockOpenAI): - # Test that headers and organization are passed during a call - headers = {"Custom-Header": "CustomValue"} - organization = "test-organization" - mock_sync_client = Mock() - MockOpenAI.return_value = mock_sync_client - - client = OpenAIClient( - api_key="fake_api_key", - headers=headers, - organization=organization, - ) - client.sync_client = mock_sync_client - - # Mock the API call - mock_sync_client.responses.create = Mock(return_value=self.mock_response) - - # Call the method - result = client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) - - # Assertions - mock_sync_client.responses.create.assert_called_once_with(**self.api_kwargs) - self.assertEqual(result, self.mock_response) - - @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") - async def test_acall_with_custom_headers_and_organization(self, MockAsyncOpenAI): - # Test that headers and organization are passed during an async call - headers = {"Custom-Header": "CustomValue"} - organization = "test-organization" - mock_async_client = AsyncMock() - MockAsyncOpenAI.return_value = mock_async_client - - client = OpenAIClient( - api_key="fake_api_key", - headers=headers, - organization=organization, - ) - client.async_client = mock_async_client - - # Mock the API call - mock_async_client.responses.create = AsyncMock(return_value=self.mock_response) - - # Call the method - result = await client.acall( - api_kwargs=self.api_kwargs, model_type=ModelType.LLM - ) - - # Assertions - mock_async_client.responses.create.assert_awaited_once_with(**self.api_kwargs) - self.assertEqual(result, self.mock_response) - - async def test_async_streaming(self): - """Test the async streaming method for OpenAIClient.""" - # Setup mock - mock_async_client = AsyncMock() - - # Create an async generator for the mock stream - async def mock_stream(): - for event in self.streaming_events: - yield event - await asyncio.sleep(0.01) - - mock_async_client.responses.create.return_value = mock_stream() - self.client.async_client = mock_async_client - - # Test API kwargs for streaming - Response API uses input as string - api_kwargs = { - "model": "gpt-4", - "input": "You are a helpful assistant. Tell me a short story.", - "stream": True, - "max_tokens": 200, - } - - # Call the async streaming method - stream = await self.client.acall(api_kwargs, ModelType.LLM) - - # Verify the streaming parser is set - self.assertEqual( - self.client.response_parser, - self.client.streaming_response_parser, - ) - - # Process the stream - full_response = "" - async for event in stream: - if hasattr(event, "delta"): # Mock ResponseTextDeltaEvent - full_response += event.delta - elif hasattr(event, "response") and hasattr( - event.response, "output_text" - ): # Mock ResponseCompletedEvent - full_response = event.response.output_text - - # Verify the response - self.assertIn("Once upon", full_response) - - # Verify the API was called correctly - mock_async_client.responses.create.assert_called_once_with(**api_kwargs) - - async def test_parser_switching(self): - """Test that parser switching works correctly.""" - # Initially should be non-streaming parser - self.assertEqual( - self.client.response_parser, - self.client.non_streaming_response_parser, - ) - - # Setup mock for streaming call - mock_async_client = AsyncMock() - - async def mock_stream(): - yield self.streaming_events[0] - - mock_async_client.responses.create.return_value = mock_stream() - self.client.async_client = mock_async_client - - # Test streaming call - should switch to streaming parser - await self.client.acall( - {"model": "gpt-4", "input": "Hello", "stream": True}, ModelType.LLM - ) - self.assertEqual( - self.client.response_parser, - self.client.streaming_response_parser, - ) - - # Test non-streaming call - should switch back to non-streaming parser - mock_async_client.responses.create.return_value = self.mock_response - await self.client.acall( - {"model": "gpt-4", "input": "Hello", "stream": False}, ModelType.LLM - ) - self.assertEqual( - self.client.response_parser, - self.client.non_streaming_response_parser, - ) - - def test_reasoning_model_response(self): - """Test parsing of reasoning model responses with reasoning field.""" - # Create a mock Response with reasoning output - mock_reasoning_response = Mock(spec=Response) - mock_reasoning_response.id = "resp-reasoning-123" - mock_reasoning_response.created_at = 1635820005.0 - mock_reasoning_response.model = "o1" - mock_reasoning_response.object = "response" - mock_reasoning_response.output_text = None # Reasoning models may not have output_text - - # Mock output array with reasoning and message - mock_reasoning_item = Mock() - mock_reasoning_item.type = "reasoning" - mock_reasoning_item.id = "rs_123" - mock_reasoning_item.summary = [ - Mock(type="summary_text", text="I'm thinking about the problem step by step...") - ] - - mock_message_item = Mock() - mock_message_item.type = "message" - mock_message_item.content = [ - Mock(type="output_text", text="The answer is 42.") - ] - - mock_reasoning_response.output = [mock_reasoning_item, mock_message_item] - mock_reasoning_response.usage = ResponseUsage( - input_tokens=50, - output_tokens=100, - total_tokens=150, - input_tokens_details={"cached_tokens": 0}, - output_tokens_details={"reasoning_tokens": 80}, - ) - - # Parse the response - result = self.client.parse_chat_completion(mock_reasoning_response) - - # Assertions - self.assertIsInstance(result, GeneratorOutput) - self.assertEqual(result.data, "The answer is 42.") - self.assertIsNotNone(result.thinking) - self.assertIn("thinking about the problem", result.thinking) - # Check that usage was properly tracked - self.assertEqual(result.usage.output_tokens, 100) - self.assertEqual(result.usage.total_tokens, 150) - - def test_multimodal_input_with_images(self): - """Test multimodal input with images (vision models).""" - # Test with URL image - url_kwargs = self.client.convert_inputs_to_api_kwargs( - input="What's in this image?", - model_kwargs={ - "model": "gpt-4o", - "images": "https://example.com/image.jpg" - }, - model_type=ModelType.LLM - ) - - # Should format as message with content array - self.assertIn("input", url_kwargs) - self.assertIsInstance(url_kwargs["input"], list) - self.assertEqual(url_kwargs["input"][0]["role"], "user") - content = url_kwargs["input"][0]["content"] - self.assertIsInstance(content, list) - - # Check text content - text_content = next((c for c in content if c["type"] == "input_text"), None) - self.assertIsNotNone(text_content) - self.assertEqual(text_content["text"], "What's in this image?") - - # Check image content - image_content = next((c for c in content if c["type"] == "input_image"), None) - self.assertIsNotNone(image_content) - self.assertEqual(image_content["image_url"], "https://example.com/image.jpg") - - # Test with base64 image - base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" - base64_kwargs = self.client.convert_inputs_to_api_kwargs( - input="Describe this image", - model_kwargs={ - "model": "gpt-4o", - "images": f"data:image/png;base64,{base64_image}" - }, - model_type=ModelType.LLM - ) - - # Check base64 image content - content = base64_kwargs["input"][0]["content"] - image_content = next((c for c in content if c["type"] == "input_image"), None) - self.assertIsNotNone(image_content) - self.assertTrue(image_content["image_url"].startswith("data:image/png;base64,")) - - def test_image_generation_response(self): - """Test parsing of image generation responses.""" - # Create a mock Response with image generation output - mock_image_response = Mock(spec=Response) - mock_image_response.id = "resp-img-gen-123" - mock_image_response.created_at = 1635820005.0 - mock_image_response.model = "gpt-4o" - mock_image_response.object = "response" - mock_image_response.output_text = None - - # Mock output array with image generation call - mock_image_item = Mock() - mock_image_item.type = "image_generation_call" - mock_image_item.result = "base64_encoded_image_data_here" - - mock_message_item = Mock() - mock_message_item.type = "message" - mock_message_item.content = [ - Mock(type="output_text", text="I've generated an image of a cat for you.") - ] - - mock_image_response.output = [mock_image_item, mock_message_item] - mock_image_response.usage = ResponseUsage( - input_tokens=30, - output_tokens=50, - total_tokens=80, - input_tokens_details={"cached_tokens": 0}, - output_tokens_details={"reasoning_tokens": 0}, - ) - - # Parse the response - result = self.client.parse_chat_completion(mock_image_response) - - # Assertions - self.assertIsInstance(result, GeneratorOutput) - self.assertEqual(result.data, "I've generated an image of a cat for you.") - self.assertIsNotNone(result.images) - self.assertEqual(result.images, ["base64_encoded_image_data_here"]) - - - def test_streaming_with_helper_function(self): - """Test streaming response with text extraction helper.""" - # Create streaming events with proper structure - event1 = Mock() - event1.type = "response.created" - - event2 = Mock() - event2.type = "response.output_text.delta" - event2.delta = "Hello " - - event3 = Mock() - event3.type = "response.output_text.delta" - event3.delta = "world!" - - event4 = Mock() - event4.type = "response.done" - - events = [event1, event2, event3, event4] - - # Test text extraction - extracted_text = [] - for event in events: - text = extract_text_from_response_stream(event) - if text: - extracted_text.append(text) - - # Assertions - self.assertEqual(extracted_text, ["Hello ", "world!"]) - self.assertEqual("".join(extracted_text), "Hello world!") - - async def test_reasoning_model_streaming(self): - """Test streaming with reasoning model responses.""" - # Setup mock - mock_async_client = AsyncMock() - - # Create reasoning streaming events with proper structure - async def mock_reasoning_stream(): - # Reasoning events - event1 = Mock() - event1.type = "reasoning.start" - yield event1 - - event2 = Mock() - event2.type = "reasoning.delta" - event2.delta = "Thinking..." - yield event2 - - # Text output events - event3 = Mock() - event3.type = "response.output_text.delta" - event3.delta = "The answer " - yield event3 - - event4 = Mock() - event4.type = "response.output_text.delta" - event4.delta = "is 42." - yield event4 - - event5 = Mock() - event5.type = "response.done" - yield event5 - - mock_async_client.responses.create.return_value = mock_reasoning_stream() - self.client.async_client = mock_async_client - - # Call with reasoning model - api_kwargs = { - "model": "o1", - "input": "What is the meaning of life?", - "stream": True, - "reasoning": {"effort": "medium", "summary": "auto"} - } - - stream = await self.client.acall(api_kwargs, ModelType.LLM_REASONING) - - # Process the stream - text_chunks = [] - async for event in stream: - text = extract_text_from_response_stream(event) - if text: - text_chunks.append(text) - - # Assertions - self.assertEqual("".join(text_chunks), "The answer is 42.") - - def test_multiple_images_input(self): - """Test multimodal input with multiple images.""" - # Test with multiple images - multi_image_kwargs = self.client.convert_inputs_to_api_kwargs( - input="Compare these images", - model_kwargs={ - "model": "gpt-4o", - "images": [ - "https://example.com/image1.jpg", - "https://example.com/image2.jpg", - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" - ] - }, - model_type=ModelType.LLM - ) - - # Check content array - content = multi_image_kwargs["input"][0]["content"] - - # Should have 1 text + 3 images = 4 items - self.assertEqual(len(content), 4) - - # Count image contents - image_contents = [c for c in content if c["type"] == "input_image"] - self.assertEqual(len(image_contents), 3) - - # Verify each image - self.assertEqual(image_contents[0]["image_url"], "https://example.com/image1.jpg") - self.assertEqual(image_contents[1]["image_url"], "https://example.com/image2.jpg") - self.assertTrue(image_contents[2]["image_url"].startswith("data:image/png;base64,")) - - -if __name__ == "__main__": - unittest.main() +import unittest +from unittest.mock import patch, AsyncMock, Mock, MagicMock +import os +import base64 + +from openai.types import Image +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseTextDeltaEvent, + ResponseUsage, +) + +from adalflow.core.types import ModelType, GeneratorOutput +from adalflow.components.model_client.openai_client import OpenAIClient +from adalflow.components.model_client.utils import extract_text_from_response_stream +import asyncio + + +def getenv_side_effect(key): + # This dictionary can hold more keys and values as needed + env_vars = {"OPENAI_API_KEY": "fake_api_key"} + return env_vars.get(key, None) # Returns None if key is not found + + +class TestOpenAIClient(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.client = OpenAIClient(api_key="fake_api_key") + # Create a mock Response object with the required fields + self.mock_response = Mock(spec=Response) + self.mock_response.id = "resp-3Q8Z5J9Z1Z5z5" + self.mock_response.created_at = 1635820005.0 + self.mock_response.model = "gpt-4o" + self.mock_response.object = "response" + self.mock_response.output_text = "Hello, world!" + self.mock_response.usage = ResponseUsage( + input_tokens=20, + output_tokens=10, + total_tokens=30, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0}, + ) + self.mock_vision_response = Mock(spec=Response) + self.mock_vision_response.id = "resp-4Q8Z5J9Z1Z5z5" + self.mock_vision_response.created_at = 1635820005.0 + self.mock_vision_response.model = "gpt-4o" + self.mock_vision_response.object = "response" + self.mock_vision_response.output_text = ( + "The image shows a beautiful sunset over mountains." + ) + self.mock_vision_response.usage = ResponseUsage( + input_tokens=25, + output_tokens=15, + total_tokens=40, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0}, + ) + self.mock_image_response = [ + Image( + url="https://example.com/generated_image.jpg", + b64_json=None, + revised_prompt="A white siamese cat sitting elegantly", + model="dall-e-3", + ) + ] + self.api_kwargs = { + "input": "Hello", + "model": "gpt-4o", + } + self.vision_api_kwargs = { + "input": "Describe this image: https://example.com/image.jpg", + "model": "gpt-4o", + } + self.image_generation_kwargs = { + "model": "dall-e-3", + "prompt": "a white siamese cat", + "size": "1024x1024", + "quality": "standard", + "n": 1, + } + + # Add streaming test data for response API using Mock objects + mock_delta_event1 = Mock(spec=ResponseTextDeltaEvent) + mock_delta_event1.type = "response.output_text.delta" + mock_delta_event1.delta = "Once " + + mock_delta_event2 = Mock(spec=ResponseTextDeltaEvent) + mock_delta_event2.type = "response.output_text.delta" + mock_delta_event2.delta = "upon " + + mock_response_obj = Mock(spec=Response) + mock_response_obj.id = "resp-123" + mock_response_obj.created_at = 1635820005.0 + mock_response_obj.model = "gpt-4" + mock_response_obj.object = "response" + mock_response_obj.output_text = "Once upon " + mock_response_obj.usage = ResponseUsage( + input_tokens=10, + output_tokens=2, + total_tokens=12, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0}, + ) + + mock_completed_event = Mock(spec=ResponseCompletedEvent) + mock_completed_event.type = "response.completed" + mock_completed_event.response = mock_response_obj + + self.streaming_events = [ + mock_delta_event1, + mock_delta_event2, + mock_completed_event, + ] + + def test_encode_image(self): + # Create a temporary test image file + test_image_path = "test_image.jpg" + test_content = b"fake image content" + try: + with open(test_image_path, "wb") as f: + f.write(test_content) + + # Test successful encoding + encoded = self.client._encode_image(test_image_path) + self.assertEqual(encoded, base64.b64encode(test_content).decode("utf-8")) + + # Test file not found + with self.assertRaises(ValueError) as context: + self.client._encode_image("nonexistent.jpg") + self.assertIn("Image file not found", str(context.exception)) + + finally: + # Cleanup + if os.path.exists(test_image_path): + os.remove(test_image_path) + + def test_prepare_image_content(self): + # Test URL image + url = "https://example.com/image.jpg" + result = self.client._prepare_image_content(url) + self.assertEqual( + result, + {"type": "image_url", "image_url": {"url": url, "detail": "auto"}}, + ) + + # Test with custom detail level + result = self.client._prepare_image_content(url, detail="high") + self.assertEqual( + result, + {"type": "image_url", "image_url": {"url": url, "detail": "high"}}, + ) + + # Test with pre-formatted content + pre_formatted = { + "type": "image_url", + "image_url": {"url": url, "detail": "low"}, + } + result = self.client._prepare_image_content(pre_formatted) + self.assertEqual(result, pre_formatted) + + def test_convert_inputs_to_api_kwargs_with_images(self): + # Test with single image URL - Response API uses message format for multimodal + model_kwargs = { + "model": "gpt-4o", + "images": "https://example.com/image.jpg", + } + result = self.client.convert_inputs_to_api_kwargs( + input="Describe this image", + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + print(result) + # Response API expects message format with content array for multimodal + expected_input = [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "Describe this image"}, + {"type": "input_image", "image_url": "https://example.com/image.jpg"} + ] + } + ] + self.assertEqual(result["input"], expected_input) + self.assertEqual(result["model"], "gpt-4o") + + # Test with multiple images - Response API uses message format for multimodal + model_kwargs = { + "model": "gpt-4o", + "images": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + ], + "detail": "high", + } + result = self.client.convert_inputs_to_api_kwargs( + input="Compare these images", + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + # Response API expects message format with content array for multimodal + expected_input = [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "Compare these images"}, + {"type": "input_image", "image_url": "https://example.com/image1.jpg"}, + {"type": "input_image", "image_url": "https://example.com/image2.jpg"} + ] + } + ] + self.assertEqual(result["input"], expected_input) + self.assertEqual(result["model"], "gpt-4o") + + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_acall_llm(self, MockAsyncOpenAI): + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + # Mock the response + + mock_async_client.responses.create = AsyncMock(return_value=self.mock_response) + + # Call the _acall method + + result = await self.client.acall( + api_kwargs=self.api_kwargs, model_type=ModelType.LLM + ) + + # Assertions + MockAsyncOpenAI.assert_called_once() + mock_async_client.responses.create.assert_awaited_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + @patch( + "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" + ) + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_call(self, MockSyncOpenAI, mock_init_sync_client): + mock_sync_client = Mock() + MockSyncOpenAI.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the client's api: responses.create + mock_sync_client.responses.create = Mock(return_value=self.mock_response) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method + result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.responses.create.assert_called_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + # test parse_response + output = self.client.parse_chat_completion(completion=self.mock_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.raw_response, "Hello, world!") + self.assertEqual(output.usage.output_tokens, 10) + self.assertEqual(output.usage.input_tokens, 20) + self.assertEqual(output.usage.total_tokens, 30) + + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_acall_llm_with_vision(self, MockAsyncOpenAI): + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + # Mock the vision model response + mock_async_client.responses.create = AsyncMock( + return_value=self.mock_vision_response + ) + + # Call the _acall method with vision model + result = await self.client.acall( + api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM + ) + + # Assertions + MockAsyncOpenAI.assert_called_once() + mock_async_client.responses.create.assert_awaited_once_with( + **self.vision_api_kwargs + ) + self.assertEqual(result, self.mock_vision_response) + + @patch( + "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" + ) + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): + mock_sync_client = Mock() + MockSyncOpenAI.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the vision model response + mock_sync_client.responses.create = Mock(return_value=self.mock_vision_response) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method with vision model + result = self.client.call( + api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM + ) + + # Assertions + mock_sync_client.responses.create.assert_called_once_with( + **self.vision_api_kwargs + ) + self.assertEqual(result, self.mock_vision_response) + + # Test parse_response for vision model + output = self.client.parse_chat_completion(completion=self.mock_vision_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual( + output.raw_response, "The image shows a beautiful sunset over mountains." + ) + self.assertEqual(output.usage.output_tokens, 15) + self.assertEqual(output.usage.input_tokens, 25) + self.assertEqual(output.usage.total_tokens, 40) + + def test_from_dict_to_dict(self): + test_api_key = "fake_api" + client = OpenAIClient(api_key=test_api_key) + client_dict = client.to_dict() + new_client = OpenAIClient.from_dict(client_dict) + self.assertEqual(new_client.to_dict(), client_dict) + + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_init_sync_client_with_headers_and_organization(self, MockOpenAI): + headers = {"Custom-Header": "CustomValue"} + organization = "test-organization" + + # First call happens during __init__ + client = OpenAIClient( + api_key="fake_api_key", + headers=headers, + organization=organization, + ) + + # Clear previous calls so we only test the explicit one below + MockOpenAI.reset_mock() + + # Now call init_sync_client explicitly to trigger the OpenAI call + _ = client.init_sync_client() + + # Assert OpenAI was called with correct parameters + MockOpenAI.assert_called_once_with( + api_key="fake_api_key", + base_url="https://api.openai.com/v1/", + organization=organization, + default_headers=headers, + ) + + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_init_async_client_with_headers_and_organization( + self, MockAsyncOpenAI + ): + headers = {"Custom-Header": "CustomValue"} + organization = "test-organization" + + # Manually assign an AsyncMock to the return value + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + client = OpenAIClient( + api_key="fake_api_key", + headers=headers, + organization=organization, + ) + + async_client = client.init_async_client() # Do NOT await here + + MockAsyncOpenAI.assert_called_once_with( + api_key="fake_api_key", + base_url="https://api.openai.com/v1/", + organization=organization, + default_headers=headers, + ) + self.assertEqual(async_client, mock_async_client) + + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_call_with_custom_headers_and_organization(self, MockOpenAI): + # Test that headers and organization are passed during a call + headers = {"Custom-Header": "CustomValue"} + organization = "test-organization" + mock_sync_client = Mock() + MockOpenAI.return_value = mock_sync_client + + client = OpenAIClient( + api_key="fake_api_key", + headers=headers, + organization=organization, + ) + client.sync_client = mock_sync_client + + # Mock the API call + mock_sync_client.responses.create = Mock(return_value=self.mock_response) + + # Call the method + result = client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.responses.create.assert_called_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_acall_with_custom_headers_and_organization(self, MockAsyncOpenAI): + # Test that headers and organization are passed during an async call + headers = {"Custom-Header": "CustomValue"} + organization = "test-organization" + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + client = OpenAIClient( + api_key="fake_api_key", + headers=headers, + organization=organization, + ) + client.async_client = mock_async_client + + # Mock the API call + mock_async_client.responses.create = AsyncMock(return_value=self.mock_response) + + # Call the method + result = await client.acall( + api_kwargs=self.api_kwargs, model_type=ModelType.LLM + ) + + # Assertions + mock_async_client.responses.create.assert_awaited_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + async def test_async_streaming(self): + """Test the async streaming method for OpenAIClient.""" + # Setup mock + mock_async_client = AsyncMock() + + # Create an async generator for the mock stream + async def mock_stream(): + for event in self.streaming_events: + yield event + await asyncio.sleep(0.01) + + mock_async_client.responses.create.return_value = mock_stream() + self.client.async_client = mock_async_client + + # Test API kwargs for streaming - Response API uses input as string + api_kwargs = { + "model": "gpt-4", + "input": "You are a helpful assistant. Tell me a short story.", + "stream": True, + "max_tokens": 200, + } + + # Call the async streaming method + stream = await self.client.acall(api_kwargs, ModelType.LLM) + + # Verify the streaming parser is set + self.assertEqual( + self.client.response_parser, + self.client.streaming_response_parser, + ) + + # Process the stream + full_response = "" + async for event in stream: + if hasattr(event, "delta"): # Mock ResponseTextDeltaEvent + full_response += event.delta + elif hasattr(event, "response") and hasattr( + event.response, "output_text" + ): # Mock ResponseCompletedEvent + full_response = event.response.output_text + + # Verify the response + self.assertIn("Once upon", full_response) + + # Verify the API was called correctly + mock_async_client.responses.create.assert_called_once_with(**api_kwargs) + + async def test_parser_switching(self): + """Test that parser switching works correctly.""" + # Initially should be non-streaming parser + self.assertEqual( + self.client.response_parser, + self.client.non_streaming_response_parser, + ) + + # Setup mock for streaming call + mock_async_client = AsyncMock() + + async def mock_stream(): + yield self.streaming_events[0] + + mock_async_client.responses.create.return_value = mock_stream() + self.client.async_client = mock_async_client + + # Test streaming call - should switch to streaming parser + await self.client.acall( + {"model": "gpt-4", "input": "Hello", "stream": True}, ModelType.LLM + ) + self.assertEqual( + self.client.response_parser, + self.client.streaming_response_parser, + ) + + # Test non-streaming call - should switch back to non-streaming parser + mock_async_client.responses.create.return_value = self.mock_response + await self.client.acall( + {"model": "gpt-4", "input": "Hello", "stream": False}, ModelType.LLM + ) + self.assertEqual( + self.client.response_parser, + self.client.non_streaming_response_parser, + ) + + def test_reasoning_model_response(self): + """Test parsing of reasoning model responses with reasoning field.""" + # Create a mock Response with reasoning output + mock_reasoning_response = Mock(spec=Response) + mock_reasoning_response.id = "resp-reasoning-123" + mock_reasoning_response.created_at = 1635820005.0 + mock_reasoning_response.model = "o1" + mock_reasoning_response.object = "response" + mock_reasoning_response.output_text = None # Reasoning models may not have output_text + + # Mock output array with reasoning and message + mock_reasoning_item = Mock() + mock_reasoning_item.type = "reasoning" + mock_reasoning_item.id = "rs_123" + mock_reasoning_item.summary = [ + Mock(type="summary_text", text="I'm thinking about the problem step by step...") + ] + + mock_message_item = Mock() + mock_message_item.type = "message" + mock_message_item.content = [ + Mock(type="output_text", text="The answer is 42.") + ] + + mock_reasoning_response.output = [mock_reasoning_item, mock_message_item] + mock_reasoning_response.usage = ResponseUsage( + input_tokens=50, + output_tokens=100, + total_tokens=150, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 80}, + ) + + # Parse the response + result = self.client.parse_chat_completion(mock_reasoning_response) + + # Assertions + self.assertIsInstance(result, GeneratorOutput) + self.assertEqual(result.data, "The answer is 42.") + self.assertIsNotNone(result.thinking) + self.assertIn("thinking about the problem", result.thinking) + # Check that usage was properly tracked + self.assertEqual(result.usage.output_tokens, 100) + self.assertEqual(result.usage.total_tokens, 150) + + def test_multimodal_input_with_images(self): + """Test multimodal input with images (vision models).""" + # Test with URL image + url_kwargs = self.client.convert_inputs_to_api_kwargs( + input="What's in this image?", + model_kwargs={ + "model": "gpt-4o", + "images": "https://example.com/image.jpg" + }, + model_type=ModelType.LLM + ) + + # Should format as message with content array + self.assertIn("input", url_kwargs) + self.assertIsInstance(url_kwargs["input"], list) + self.assertEqual(url_kwargs["input"][0]["role"], "user") + content = url_kwargs["input"][0]["content"] + self.assertIsInstance(content, list) + + # Check text content + text_content = next((c for c in content if c["type"] == "input_text"), None) + self.assertIsNotNone(text_content) + self.assertEqual(text_content["text"], "What's in this image?") + + # Check image content + image_content = next((c for c in content if c["type"] == "input_image"), None) + self.assertIsNotNone(image_content) + self.assertEqual(image_content["image_url"], "https://example.com/image.jpg") + + # Test with base64 image + base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + base64_kwargs = self.client.convert_inputs_to_api_kwargs( + input="Describe this image", + model_kwargs={ + "model": "gpt-4o", + "images": f"data:image/png;base64,{base64_image}" + }, + model_type=ModelType.LLM + ) + + # Check base64 image content + content = base64_kwargs["input"][0]["content"] + image_content = next((c for c in content if c["type"] == "input_image"), None) + self.assertIsNotNone(image_content) + self.assertTrue(image_content["image_url"].startswith("data:image/png;base64,")) + + def test_image_generation_response(self): + """Test parsing of image generation responses.""" + # Create a mock Response with image generation output + mock_image_response = Mock(spec=Response) + mock_image_response.id = "resp-img-gen-123" + mock_image_response.created_at = 1635820005.0 + mock_image_response.model = "gpt-4o" + mock_image_response.object = "response" + mock_image_response.output_text = None + + # Mock output array with image generation call + mock_image_item = Mock() + mock_image_item.type = "image_generation_call" + mock_image_item.result = "base64_encoded_image_data_here" + + mock_message_item = Mock() + mock_message_item.type = "message" + mock_message_item.content = [ + Mock(type="output_text", text="I've generated an image of a cat for you.") + ] + + mock_image_response.output = [mock_image_item, mock_message_item] + mock_image_response.usage = ResponseUsage( + input_tokens=30, + output_tokens=50, + total_tokens=80, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0}, + ) + + # Parse the response + result = self.client.parse_chat_completion(mock_image_response) + + # Assertions + self.assertIsInstance(result, GeneratorOutput) + self.assertEqual(result.data, "I've generated an image of a cat for you.") + self.assertIsNotNone(result.images) + self.assertEqual(result.images, ["base64_encoded_image_data_here"]) + + + def test_streaming_with_helper_function(self): + """Test streaming response with text extraction helper.""" + # Create streaming events with proper structure + event1 = Mock() + event1.type = "response.created" + + event2 = Mock() + event2.type = "response.output_text.delta" + event2.delta = "Hello " + + event3 = Mock() + event3.type = "response.output_text.delta" + event3.delta = "world!" + + event4 = Mock() + event4.type = "response.done" + + events = [event1, event2, event3, event4] + + # Test text extraction + extracted_text = [] + for event in events: + text = extract_text_from_response_stream(event) + if text: + extracted_text.append(text) + + # Assertions + self.assertEqual(extracted_text, ["Hello ", "world!"]) + self.assertEqual("".join(extracted_text), "Hello world!") + + async def test_reasoning_model_streaming(self): + """Test streaming with reasoning model responses.""" + # Setup mock + mock_async_client = AsyncMock() + + # Create reasoning streaming events with proper structure + async def mock_reasoning_stream(): + # Reasoning events + event1 = Mock() + event1.type = "reasoning.start" + yield event1 + + event2 = Mock() + event2.type = "reasoning.delta" + event2.delta = "Thinking..." + yield event2 + + # Text output events + event3 = Mock() + event3.type = "response.output_text.delta" + event3.delta = "The answer " + yield event3 + + event4 = Mock() + event4.type = "response.output_text.delta" + event4.delta = "is 42." + yield event4 + + event5 = Mock() + event5.type = "response.done" + yield event5 + + mock_async_client.responses.create.return_value = mock_reasoning_stream() + self.client.async_client = mock_async_client + + # Call with reasoning model + api_kwargs = { + "model": "o1", + "input": "What is the meaning of life?", + "stream": True, + "reasoning": {"effort": "medium", "summary": "auto"} + } + + stream = await self.client.acall(api_kwargs, ModelType.LLM_REASONING) + + # Process the stream + text_chunks = [] + async for event in stream: + text = extract_text_from_response_stream(event) + if text: + text_chunks.append(text) + + # Assertions + self.assertEqual("".join(text_chunks), "The answer is 42.") + + def test_convert_inputs_to_api_kwargs_reasoning_model_strips_unsupported_params(self): + """Test that unsupported Chat Completion parameters are removed for reasoning models.""" + model_kwargs = { + "model": "o3-mini", + "frequency_penalty": 0, + "presence_penalty": 0, + "temperature": 0.0, + "top_p": 0.99, + "reasoning": {"effort": "medium", "summary": "auto"}, + } + result = self.client.convert_inputs_to_api_kwargs( + input="Solve this problem", + model_kwargs=model_kwargs, + model_type=ModelType.LLM_REASONING, + ) + # Unsupported params should be removed + self.assertNotIn("frequency_penalty", result) + self.assertNotIn("presence_penalty", result) + self.assertNotIn("temperature", result) + self.assertNotIn("top_p", result) + # Supported params should remain + self.assertEqual(result["model"], "o3-mini") + self.assertIn("reasoning", result) + self.assertEqual(result["reasoning"]["effort"], "medium") + + def test_convert_inputs_to_api_kwargs_llm_keeps_all_params(self): + """Test that regular LLM model_type keeps all parameters including frequency_penalty.""" + model_kwargs = { + "model": "gpt-4o", + "frequency_penalty": 0, + "presence_penalty": 0, + "temperature": 0.0, + "top_p": 0.99, + } + result = self.client.convert_inputs_to_api_kwargs( + input="Hello", + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + # All params should be preserved for regular LLM + self.assertIn("frequency_penalty", result) + self.assertIn("presence_penalty", result) + self.assertIn("temperature", result) + self.assertIn("top_p", result) + + def test_multiple_images_input(self): + """Test multimodal input with multiple images.""" + # Test with multiple images + multi_image_kwargs = self.client.convert_inputs_to_api_kwargs( + input="Compare these images", + model_kwargs={ + "model": "gpt-4o", + "images": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + ] + }, + model_type=ModelType.LLM + ) + + # Check content array + content = multi_image_kwargs["input"][0]["content"] + + # Should have 1 text + 3 images = 4 items + self.assertEqual(len(content), 4) + + # Count image contents + image_contents = [c for c in content if c["type"] == "input_image"] + self.assertEqual(len(image_contents), 3) + + # Verify each image + self.assertEqual(image_contents[0]["image_url"], "https://example.com/image1.jpg") + self.assertEqual(image_contents[1]["image_url"], "https://example.com/image2.jpg") + self.assertTrue(image_contents[2]["image_url"].startswith("data:image/png;base64,")) + + +if __name__ == "__main__": + unittest.main()