[RFC] - ClientFactory: Client Management for Extensibility #1241

286 changes: 286 additions & 0 deletions RFCs/002/RFC-002-client-factory.ipynb
@@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Protocol, runtime_checkable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Messages\n",
"\n",
"From RFC-000: Eventually, we will move away from dictionaries and represent all messages with objects (dataclasses or Pydantic base models). For now, we will use existing implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Message = dict[str, Any]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration\n",
"\n",
"Eventually, we will move away from dictionaries and represent the configuration with objects, but for now we will use the existing dict implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Configuration = dict[str, Any]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Client messages\n",
"\n",
"We need to define a message protocol that all the clients will align to, we are using OpenAIs ChatCompletion at the moment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen.oai.oai_models import ChatCompletion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Client protocol\n",
"\n",
"The basic client protocol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@runtime_checkable\n",
"class ClientProtocol(Protocol):\n",
" def create(self, history: Message) -> ChatCompletion:\n",
" \"\"\"\n",
" Generates a response/message based on the conversation history.\n",
" \"\"\"\n",
" pass\n",
"\n",
" def accept(cls, config: Configuration) -> tuple[bool, int]:\n",
" \"\"\"\n",
" Returns a tuple:\n",
" - A boolean indicating if the client matches.\n",
" - An integer representing the specificity (higher is more specific).\n",
" \"\"\"\n",
" pass"
]
},
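{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustrative check (not part of the RFC proper): because the protocol is decorated with `runtime_checkable`, `isinstance` can verify that a class structurally conforms. Note that `isinstance` only checks that the method names exist, not their signatures. The `_StubClient` below is a hypothetical stand-in."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class _StubClient:\n",
" def create(self, history):\n",
" ...\n",
"\n",
" @classmethod\n",
" def accept(cls, config):\n",
" ...\n",
"\n",
"# runtime_checkable protocols check for method presence only\n",
"print(isinstance(_StubClient(), ClientProtocol)) # True\n",
"print(isinstance(object(), ClientProtocol)) # False"
]
},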
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Client Factory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ClientFactory:\n",
" _registry: list[ClientProtocol] = []\n",
"\n",
" @classmethod\n",
" def register(cls, client_class: ClientProtocol) -> ClientProtocol:\n",
" \"\"\"\n",
" Decorator to register a new client class.\n",
" \"\"\"\n",
" cls._registry.append(client_class)\n",
" return client_class\n",
"\n",
" @classmethod\n",
" def create_client(cls, config: Configuration) -> ClientProtocol:\n",
" # Collect matches and their priorities\n",
" matches = []\n",
" for client_class in cls._registry:\n",
" is_match, priority = client_class.accept(config)\n",
" if is_match:\n",
" matches.append((priority, client_class))\n",
"\n",
" if not matches:\n",
" raise ValueError(f\"No compatible client found for configuration: {config}\")\n",
"\n",
" # Sort by priority (highest priority first)\n",
" matches.sort(key=lambda x: x[0], reverse=True)\n",
"\n",
" # Log if there are multiple matches\n",
" if len(matches) > 1:\n",
" print(f\"Warning: Multiple clients matched. Using {matches[0][1].__name__} with priority {matches[0][0]}.\")\n",
"\n",
" # Instantiate the highest-priority client\n",
" return matches[0][1](config)"
]
},
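{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustrative error-path check: with no client registered that matches the configuration, the factory raises `ValueError`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" ClientFactory.create_client({\"type\": \"unknown\"})\n",
"except ValueError as err:\n",
" print(err)"
]
},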
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@ClientFactory.register\n",
"class ExampleClient:\n",
" priority = 0\n",
"\n",
" def __init__(self, config: Configuration):\n",
" self.name = \"ExampleClient\"\n",
"\n",
" def create(self, history: Message) -> ChatCompletion:\n",
" return \"This is a generated response from ExampleClient.\" # Should be chat completion object, str now before we have the model\n",
"\n",
" @classmethod\n",
" def accept(self, config: Configuration) -> bool:\n",
" return config.get(\"type\") == \"example\", self.priority"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration for the client\n",
"config = {\"type\": \"example\"}\n",
"\n",
"# Create the client using the factory\n",
"client = ClientFactory.create_client(config)\n",
"\n",
"# Generate a response\n",
"history = [{\"role\": \"user\", \"content\": \"Hello!\"}]\n",
"response = client.create(history)\n",
"\n",
"print(f\"Client: {client.name}\")\n",
"print(f\"Response: {response}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Structured output, function calling etc..\n",
"\n",
"The functionaltities that each client supports will be different. For example, one client might support structured output, while another might support function calling. The factory pattern allows us to create a client instance based on the provided configuration. The client instance can then be used to generate a response based on the conversation history.\n",
"\n",
"Existing clients can continue using the simpler accept method with a default priority of 0.\n",
"\n",
"This way, we can support previous implementations of clients with the lowest priority and slowly cannibalise them into more specific, higher priority clients."
]
},
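{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (hypothetical, not part of the RFC) of how a legacy client could keep the simpler behaviour: a catch-all `accept` with the default priority of 0, so any more specific client out-prioritises it. Registering it means later cells will also print the multiple-match warning, which demonstrates the priority mechanism."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@ClientFactory.register\n",
"class LegacyCatchAllClient:\n",
" priority = 0 # default, lowest specificity\n",
"\n",
" def __init__(self, config: Configuration):\n",
" self.name = \"LegacyCatchAllClient\"\n",
"\n",
" def create(self, history: list[Message]) -> ChatCompletion:\n",
" return \"This is a response from LegacyCatchAllClient.\"\n",
"\n",
" @classmethod\n",
" def accept(cls, config: Configuration) -> tuple[bool, int]:\n",
" # Matches any configuration, at the lowest priority\n",
" return True, cls.priority"
]
},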
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example clients\n",
"@ClientFactory.register\n",
"class DefaultOpenAIClient:\n",
" def __init__(self, config: Configuration):\n",
" self.name = \"DefaultOpenAIClient\"\n",
"\n",
" def create(self, history: Message) -> ChatCompletion:\n",
" return \"This is a response from DefaultOpenAIClient.\"\n",
"\n",
" @classmethod\n",
" def accept(cls, config: Configuration) -> tuple[bool, int]:\n",
" is_match = config.get(\"api_type\") == \"openai\"\n",
" priority = 1 # General OpenAI client has lower specificity\n",
" return is_match, priority\n",
"\n",
"\n",
"@ClientFactory.register\n",
"class OpenAIStructuredClient:\n",
" def __init__(self, config: Configuration):\n",
" self.name = \"OpenAIStructuredClient\"\n",
"\n",
" def create(self, history: Message) -> ChatCompletion:\n",
" return \"\"\"{\"response\": \"This is a response from OpenAIStructuredClient.\"}\"\"\"\n",
"\n",
" @classmethod\n",
" def accept(cls, config: Configuration) -> tuple[bool, int]:\n",
" is_match = config.get(\"api_type\") == \"openai\" and \"structured_output\" in config\n",
" priority = 10 # Structured client has higher specificity\n",
" return is_match, priority"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configurations\n",
"general_config = {\"api_type\": \"openai\"}\n",
"structured_config = {\"api_type\": \"openai\", \"structured_output\": True}\n",
"\n",
"# Create general OpenAI client\n",
"general_client = ClientFactory.create_client(general_config)\n",
"print(f\"Client: {general_client.name}\")\n",
"print(f\"Response: {general_client.create([])}\")\n",
"\n",
"# Create structured OpenAI client\n",
"structured_client = ClientFactory.create_client(structured_config)\n",
"print(f\"Client: {structured_client.name}\")\n",
"print(f\"Response: {structured_client.create([])}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}