Skip to content

Upgrade autogen to 0.5.2 #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autogenui/datamodel/app.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,8 @@
class ModelConfig:
model: str
model_type: Literal["OpenAIChatCompletionClient"]
base_url: Optional[str] = None # Add base_url as it's also used
model_info: Optional[dict] = None # Add the missing model_info field


@dataclass
6 changes: 3 additions & 3 deletions autogenui/manager.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@
import json
from pathlib import Path
from .datamodel import TeamResult, TaskResult, TeamConfig
from autogen_agentchat.messages import AgentMessage, ChatMessage
from autogen_core.base import CancellationToken
from autogen_agentchat.messages import ChatMessage, BaseChatMessage, BaseAgentEvent
from autogen_core._cancellation_token import CancellationToken # Corrected import path
from .provider import Provider
from .datamodel import TeamConfig

@@ -38,7 +38,7 @@ async def run_stream(
task: str,
team_config: Optional[Union[TeamConfig, str, Path]] = None,
cancellation_token: Optional[CancellationToken] = None
) -> AsyncGenerator[Union[AgentMessage, ChatMessage, TaskResult], None]:
) -> AsyncGenerator[Union[ChatMessage, BaseAgentEvent, TaskResult], None]:
"""Stream the team's execution results with optional JSON config loading"""
start_time = time.time()

203 changes: 193 additions & 10 deletions autogenui/provider.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,163 @@

from .datamodel import AgentConfig, ModelConfig, ToolConfig, TerminationConfig, TeamConfig
from autogen_agentchat.agents import AssistantAgent, CodingAssistantAgent
from autogen_agentchat.agents import AssistantAgent # AssistantAgent is already imported
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_core.components.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.conditions._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_core.tools._function_tool import FunctionTool


AgentTypes = AssistantAgent | CodingAssistantAgent
AgentTypes = AssistantAgent
TeamTypes = RoundRobinGroupChat | SelectorGroupChat
ModelTypes = OpenAIChatCompletionClient | None
TerminationTypes = MaxMessageTermination | StopMessageTermination | TextMentionTermination


# Define the custom speaker selector function (Corrected signature)
# It only receives the message history
import logging # Add logging for debugging

# Define the custom speaker selector function (Corrected signature)
# It only receives the message history
def custom_speaker_selector(messages: list) -> str | None:
"""
Custom speaker selection based on message history:
- Finds the last valid agent message (qa_agent or writing_agent) to base decisions on.
- qa_agent speaks first and answers questions. If needed, it can request writing_agent
by including 'REQUEST_WRITING_AGENT' in its message. It signals completion by
including 'QA_COMPLETE'.
- writing_agent speaks ONLY when requested by qa_agent. It signals completion by
including 'STORY_COMPLETE'.
- final_responder speaks once after either 'QA_COMPLETE' or 'STORY_COMPLETE'
is found in the last valid agent message (case-insensitive).
- The chat ends after final_responder speaks.
"""
logging.debug(f"Selector called with {len(messages)} messages.")
if not messages:
logging.debug("Selector: No messages, selecting initial speaker: qa_agent")
return "qa_agent"

# --- Termination Check: Check if final_responder spoke anywhere ---
logging.debug("Selector: --- Starting Termination Check ---") # Add start marker
for idx, msg in enumerate(messages): # Add index for clarity
speaker_name = None
logging.debug(f"Selector: Checking message at index {idx}. Type: {type(msg)}, Content snippet: {str(msg)[:150]}...") # Log type and snippet

# CORRECTED LOGIC: Prioritize 'source' attribute for TextMessage objects
if hasattr(msg, 'source'):
speaker_name = getattr(msg, 'source', None)
logging.debug(f"Selector: Index {idx}: Found 'source' attribute: '{speaker_name}'")
# Fallback for dict-like messages
elif isinstance(msg, dict):
speaker_name = msg.get('source', msg.get('name')) # Prefer 'source', fallback to 'name'
logging.debug(f"Selector: Index {idx}: Message is dict. Extracted source/name: '{speaker_name}'")
else:
logging.debug(f"Selector: Index {idx}: Message is not object with 'source' or dict. Skipping name check.")

# Check if the speaker is final_responder (case-insensitive)
if speaker_name and isinstance(speaker_name, str):
processed_name = speaker_name.strip().lower()
logging.debug(f"Selector: Index {idx}: Comparing processed name '{processed_name}' with 'final_responder'")
if processed_name == "final_responder":
logging.debug(f"Selector: MATCH FOUND! final_responder (name from source/dict: '{speaker_name}') found at index {idx}. Terminating.")
return None # Terminate the conversation
else:
logging.debug(f"Selector: Index {idx}: No valid speaker name found or not string. Speaker name was: '{speaker_name}'")

logging.debug("Selector: --- Finished Termination Check (No termination) ---") # Add end marker

# --- Find the last valid agent message ---
last_valid_agent_message = None
logging.debug("Selector: Searching for last valid agent message...")
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
logging.debug(f"Selector: Checking message at index {i}. Type: {type(msg)}")

# Try accessing keys directly or via .get() if available
source_raw = None
content_raw = None
try:
# Prefer .get() if it's dict-like, otherwise try direct access if it's an object
if hasattr(msg, 'get'):
source_raw = msg.get("source")
content_raw = msg.get("content")
elif hasattr(msg, 'source') and hasattr(msg, 'content'):
source_raw = msg.source
content_raw = msg.content
except Exception as e:
logging.debug(f"Selector: Error accessing source/content at index {i}: {e}")
continue # Skip if keys/attributes don't exist

# Check if both source and content were successfully retrieved
if source_raw is not None and content_raw is not None:
source_type = type(source_raw)
logging.debug(f"Selector: Index {i}: Found 'source' key/attr. Type: {source_type}, Raw value: '{source_raw}'")

# Ensure source is string before processing
if isinstance(source_raw, str):
source_stripped = source_raw.strip()
source_lower = source_stripped.lower()
logging.debug(f"Selector: Index {i}: Stripped source: '{source_stripped}', Lowercased source: '{source_lower}'")

# Check if source contains agent names (case-insensitive)
is_qa_agent = "qa_agent" in source_lower
is_writing_agent = "writing_agent" in source_lower
logging.debug(f"Selector: Index {i}: 'qa_agent' in source? {is_qa_agent}, 'writing_agent' in source? {is_writing_agent}")

if is_qa_agent or is_writing_agent:
last_valid_agent_message = msg # Store the original msg object
logging.debug(f"Selector: Found last valid agent message at index {i} from source '{source_lower}' (original source: '{source_raw}')")
break # Found the most recent valid one
else:
logging.debug(f"Selector: Processed message source '{source_lower}' does not contain 'qa_agent' or 'writing_agent'. Skipping.")
else:
logging.debug(f"Selector: Message source at index {i} is not a string (Type: {source_type}). Skipping.")
else:
logging.debug(f"Selector: Message at index {i} lacks 'source' or 'content' key/attribute. Skipping. Msg: {msg}")


if last_valid_agent_message is None:
logging.warning("Selector: No valid agent message (containing 'qa_agent' or 'writing_agent' in source) found in history. Defaulting to qa_agent.")
return "qa_agent" # Default if no valid agent message found

# --- Speaker Selection based on LAST VALID AGENT message ---
# Get source and content from the identified valid message (using the same safe access)
last_valid_source_raw = None
last_valid_content_raw = None
try:
if hasattr(last_valid_agent_message, 'get'):
last_valid_source_raw = last_valid_agent_message.get("source")
last_valid_content_raw = last_valid_agent_message.get("content")
elif hasattr(last_valid_agent_message, 'source') and hasattr(last_valid_agent_message, 'content'):
last_valid_source_raw = last_valid_agent_message.source
last_valid_content_raw = last_valid_agent_message.content
except Exception as e:
logging.error(f"Selector: Error accessing source/content from identified last_valid_agent_message: {e}. Defaulting to qa_agent.")
return "qa_agent" # Should not happen if found previously, but safety check

last_valid_source = str(last_valid_source_raw).strip().lower() if isinstance(last_valid_source_raw, str) else "" # Processed source
last_valid_content = str(last_valid_content_raw).strip().lower() if isinstance(last_valid_content_raw, str) else ""

# 1. Trigger writing_agent if requested by qa_agent in the last valid message
# Use processed source and content
if "qa_agent" in last_valid_source and "request_writing_agent" in last_valid_content:
logging.debug("Selector: Selecting writing_agent based on request in last valid message.")
return "writing_agent"

# 2. Trigger final_responder if a completion signal is in the last valid message
# Use processed source and content
if ("qa_agent" in last_valid_source and "qa_complete" in last_valid_content) or \
("writing_agent" in last_valid_source and "story_complete" in last_valid_content):
logging.debug(f"Selector: Selecting final_responder based on completion signal in last valid message: '{last_valid_content}' from source '{last_valid_source_raw}'") # Log original source for clarity
return "final_responder"

# 3. Default: qa_agent
logging.debug("Selector: No specific trigger in last valid message. Defaulting to qa_agent.")
return "qa_agent"


class Provider():
def __init__(self):
pass

def load_model(self, model_config: ModelConfig | dict) -> ModelTypes:
if isinstance(model_config, dict):
try:
@@ -25,7 +166,30 @@ def load_model(self, model_config: ModelConfig | dict) -> ModelTypes:
raise ValueError("Invalid model config")
model = None
if model_config.model_type == "OpenAIChatCompletionClient":
model = OpenAIChatCompletionClient(model=model_config.model)
# Prepare arguments for the client constructor
client_args = {
"model": model_config.model,
}
# Add optional arguments if they exist in the config
if hasattr(model_config, 'base_url') and model_config.base_url:
client_args['base_url'] = model_config.base_url
# Crucially, pass model_info if it exists, as required for non-standard model names
if hasattr(model_config, 'model_info') and model_config.model_info:
# Assuming model_info in config is already a dict or compatible structure
client_args['model_info'] = model_config.model_info

# Instantiate the client using the prepared arguments
# The client typically handles API key from environment variables (e.g., OPENAI_API_KEY)
try:
model = OpenAIChatCompletionClient(**client_args)
except ValueError as e:
# Provide more context if the specific error occurs again
print(f"Error initializing OpenAIChatCompletionClient: {e}")
print(f"Arguments passed: {client_args}")
raise e # Re-raise the original error
except Exception as e:
print(f"An unexpected error occurred during client initialization: {e}")
raise e
return model

def _func_from_string(self, content: str) -> callable:
@@ -88,7 +252,11 @@ def load_agent(self, agent_config: AgentConfig | dict) -> AgentTypes:
if agent_config.agent_type == "AssistantAgent":
model_client = self.load_model(agent_config.model_client)
system_message = agent_config.system_message if agent_config.system_message else "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed."
tools = [self.load_tool(tool) for tool in agent_config.tools]
# Handle cases where tools might be None before iterating
tools = []
if agent_config.tools is not None:
tools = [self.load_tool(tool) for tool in agent_config.tools]

agent = AssistantAgent(
name=agent_config.name, model_client=model_client, tools=tools, system_message=system_message)

@@ -127,6 +295,21 @@ def load_team(self, team_config: TeamConfig | dict) -> TeamTypes:
team = RoundRobinGroupChat(
agents, termination_condition=termination)
elif team_config.team_type == "SelectorGroupChat":
team = SelectorGroupChat(agents, termination_condition=termination)
# Load the top-level model client ONLY if it's defined in the config
top_level_client = None
if team_config.model_client:
top_level_client = self.load_model(team_config.model_client)
else:
# Explicitly handle the case where no model_client is needed/provided
logging.debug("SelectorGroupChat: No top-level model_client configured. Relying solely on selector_func.")

# Pass agents as the first positional argument
# Use the correct 'selector_func' parameter name
team = SelectorGroupChat(
agents, # Positional argument
termination_condition=termination,
model_client=top_level_client, # Pass the potentially None client
selector_func=custom_speaker_selector # Use correct parameter
)

return team
Loading