diff --git a/autogenui/datamodel/app.py b/autogenui/datamodel/app.py index 9a9cbda..3caf2a8 100644 --- a/autogenui/datamodel/app.py +++ b/autogenui/datamodel/app.py @@ -9,6 +9,8 @@ class ModelConfig: model: str model_type: Literal["OpenAIChatCompletionClient"] + base_url: Optional[str] = None # Add base_url as it's also used + model_info: Optional[dict] = None # Add the missing model_info field @dataclass diff --git a/autogenui/manager.py b/autogenui/manager.py index 347efe7..dfc6a3e 100644 --- a/autogenui/manager.py +++ b/autogenui/manager.py @@ -3,8 +3,8 @@ import json from pathlib import Path from .datamodel import TeamResult, TaskResult, TeamConfig -from autogen_agentchat.messages import AgentMessage, ChatMessage -from autogen_core.base import CancellationToken +from autogen_agentchat.messages import ChatMessage, BaseChatMessage, BaseAgentEvent +from autogen_core._cancellation_token import CancellationToken # Corrected import path from .provider import Provider from .datamodel import TeamConfig @@ -38,7 +38,7 @@ async def run_stream( task: str, team_config: Optional[Union[TeamConfig, str, Path]] = None, cancellation_token: Optional[CancellationToken] = None - ) -> AsyncGenerator[Union[AgentMessage, ChatMessage, TaskResult], None]: + ) -> AsyncGenerator[Union[ChatMessage, BaseAgentEvent, TaskResult], None]: """Stream the team's execution results with optional JSON config loading""" start_time = time.time() diff --git a/autogenui/provider.py b/autogenui/provider.py index be73ac4..c1ab06d 100644 --- a/autogenui/provider.py +++ b/autogenui/provider.py @@ -1,22 +1,163 @@ - from .datamodel import AgentConfig, ModelConfig, ToolConfig, TerminationConfig, TeamConfig -from autogen_agentchat.agents import AssistantAgent, CodingAssistantAgent +from autogen_agentchat.agents import AssistantAgent # AssistantAgent is already imported from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat -from autogen_ext.models import OpenAIChatCompletionClient -from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination -from autogen_core.components.tools import FunctionTool +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_agentchat.conditions._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination +from autogen_core.tools._function_tool import FunctionTool -AgentTypes = AssistantAgent | CodingAssistantAgent +AgentTypes = AssistantAgent TeamTypes = RoundRobinGroupChat | SelectorGroupChat ModelTypes = OpenAIChatCompletionClient | None TerminationTypes = MaxMessageTermination | StopMessageTermination | TextMentionTermination +# Define the custom speaker selector function (Corrected signature) +# It only receives the message history +import logging # Add logging for debugging + +# Define the custom speaker selector function (Corrected signature) +# It only receives the message history +def custom_speaker_selector(messages: list) -> str | None: + """ + Custom speaker selection based on message history: + - Finds the last valid agent message (qa_agent or writing_agent) to base decisions on. + - qa_agent speaks first and answers questions. If needed, it can request writing_agent + by including 'REQUEST_WRITING_AGENT' in its message. It signals completion by + including 'QA_COMPLETE'. + - writing_agent speaks ONLY when requested by qa_agent. It signals completion by + including 'STORY_COMPLETE'. + - final_responder speaks once after either 'QA_COMPLETE' or 'STORY_COMPLETE' + is found in the last valid agent message (case-insensitive). + - The chat ends after final_responder speaks. + """ + logging.debug(f"Selector called with {len(messages)} messages.") + if not messages: + logging.debug("Selector: No messages, selecting initial speaker: qa_agent") + return "qa_agent" + + # --- Termination Check: Check if final_responder spoke anywhere --- + logging.debug("Selector: --- Starting Termination Check ---") # Add start marker + for idx, msg in enumerate(messages): # Add index for clarity + speaker_name = None + logging.debug(f"Selector: Checking message at index {idx}. Type: {type(msg)}, Content snippet: {str(msg)[:150]}...") # Log type and snippet + + # CORRECTED LOGIC: Prioritize 'source' attribute for TextMessage objects + if hasattr(msg, 'source'): + speaker_name = getattr(msg, 'source', None) + logging.debug(f"Selector: Index {idx}: Found 'source' attribute: '{speaker_name}'") + # Fallback for dict-like messages + elif isinstance(msg, dict): + speaker_name = msg.get('source', msg.get('name')) # Prefer 'source', fallback to 'name' + logging.debug(f"Selector: Index {idx}: Message is dict. Extracted source/name: '{speaker_name}'") + else: + logging.debug(f"Selector: Index {idx}: Message is not object with 'source' or dict. Skipping name check.") + + # Check if the speaker is final_responder (case-insensitive) + if speaker_name and isinstance(speaker_name, str): + processed_name = speaker_name.strip().lower() + logging.debug(f"Selector: Index {idx}: Comparing processed name '{processed_name}' with 'final_responder'") + if processed_name == "final_responder": + logging.debug(f"Selector: MATCH FOUND! final_responder (name from source/dict: '{speaker_name}') found at index {idx}. Terminating.") + return None # Terminate the conversation + else: + logging.debug(f"Selector: Index {idx}: No valid speaker name found or not string. Speaker name was: '{speaker_name}'") + + logging.debug("Selector: --- Finished Termination Check (No termination) ---") # Add end marker + + # --- Find the last valid agent message --- + last_valid_agent_message = None + logging.debug("Selector: Searching for last valid agent message...") + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + logging.debug(f"Selector: Checking message at index {i}. Type: {type(msg)}") + + # Try accessing keys directly or via .get() if available + source_raw = None + content_raw = None + try: + # Prefer .get() if it's dict-like, otherwise try direct access if it's an object + if hasattr(msg, 'get'): + source_raw = msg.get("source") + content_raw = msg.get("content") + elif hasattr(msg, 'source') and hasattr(msg, 'content'): + source_raw = msg.source + content_raw = msg.content + except Exception as e: + logging.debug(f"Selector: Error accessing source/content at index {i}: {e}") + continue # Skip if keys/attributes don't exist + + # Check if both source and content were successfully retrieved + if source_raw is not None and content_raw is not None: + source_type = type(source_raw) + logging.debug(f"Selector: Index {i}: Found 'source' key/attr. Type: {source_type}, Raw value: '{source_raw}'") + + # Ensure source is string before processing + if isinstance(source_raw, str): + source_stripped = source_raw.strip() + source_lower = source_stripped.lower() + logging.debug(f"Selector: Index {i}: Stripped source: '{source_stripped}', Lowercased source: '{source_lower}'") + + # Check if source contains agent names (case-insensitive) + is_qa_agent = "qa_agent" in source_lower + is_writing_agent = "writing_agent" in source_lower + logging.debug(f"Selector: Index {i}: 'qa_agent' in source? {is_qa_agent}, 'writing_agent' in source? {is_writing_agent}") + + if is_qa_agent or is_writing_agent: + last_valid_agent_message = msg # Store the original msg object + logging.debug(f"Selector: Found last valid agent message at index {i} from source '{source_lower}' (original source: '{source_raw}')") + break # Found the most recent valid one + else: + logging.debug(f"Selector: Processed message source '{source_lower}' does not contain 'qa_agent' or 'writing_agent'. Skipping.") + else: + logging.debug(f"Selector: Message source at index {i} is not a string (Type: {source_type}). Skipping.") + else: + logging.debug(f"Selector: Message at index {i} lacks 'source' or 'content' key/attribute. Skipping. Msg: {msg}") + + + if last_valid_agent_message is None: + logging.warning("Selector: No valid agent message (containing 'qa_agent' or 'writing_agent' in source) found in history. Defaulting to qa_agent.") + return "qa_agent" # Default if no valid agent message found + + # --- Speaker Selection based on LAST VALID AGENT message --- + # Get source and content from the identified valid message (using the same safe access) + last_valid_source_raw = None + last_valid_content_raw = None + try: + if hasattr(last_valid_agent_message, 'get'): + last_valid_source_raw = last_valid_agent_message.get("source") + last_valid_content_raw = last_valid_agent_message.get("content") + elif hasattr(last_valid_agent_message, 'source') and hasattr(last_valid_agent_message, 'content'): + last_valid_source_raw = last_valid_agent_message.source + last_valid_content_raw = last_valid_agent_message.content + except Exception as e: + logging.error(f"Selector: Error accessing source/content from identified last_valid_agent_message: {e}. Defaulting to qa_agent.") + return "qa_agent" # Should not happen if found previously, but safety check + + last_valid_source = str(last_valid_source_raw).strip().lower() if isinstance(last_valid_source_raw, str) else "" # Processed source + last_valid_content = str(last_valid_content_raw).strip().lower() if isinstance(last_valid_content_raw, str) else "" + + # 1. Trigger writing_agent if requested by qa_agent in the last valid message + # Use processed source and content + if "qa_agent" in last_valid_source and "request_writing_agent" in last_valid_content: + logging.debug("Selector: Selecting writing_agent based on request in last valid message.") + return "writing_agent" + + # 2. Trigger final_responder if a completion signal is in the last valid message + # Use processed source and content + if ("qa_agent" in last_valid_source and "qa_complete" in last_valid_content) or \ + ("writing_agent" in last_valid_source and "story_complete" in last_valid_content): + logging.debug(f"Selector: Selecting final_responder based on completion signal in last valid message: '{last_valid_content}' from source '{last_valid_source_raw}'") # Log original source for clarity + return "final_responder" + + # 3. Default: qa_agent + logging.debug("Selector: No specific trigger in last valid message. Defaulting to qa_agent.") + return "qa_agent" + + class Provider(): def __init__(self): pass - def load_model(self, model_config: ModelConfig | dict) -> ModelTypes: if isinstance(model_config, dict): try: @@ -25,7 +166,30 @@ def load_model(self, model_config: ModelConfig | dict) -> ModelTypes: raise ValueError("Invalid model config") model = None if model_config.model_type == "OpenAIChatCompletionClient": - model = OpenAIChatCompletionClient(model=model_config.model) + # Prepare arguments for the client constructor + client_args = { + "model": model_config.model, + } + # Add optional arguments if they exist in the config + if hasattr(model_config, 'base_url') and model_config.base_url: + client_args['base_url'] = model_config.base_url + # Crucially, pass model_info if it exists, as required for non-standard model names + if hasattr(model_config, 'model_info') and model_config.model_info: + # Assuming model_info in config is already a dict or compatible structure + client_args['model_info'] = model_config.model_info + + # Instantiate the client using the prepared arguments + # The client typically handles API key from environment variables (e.g., OPENAI_API_KEY) + try: + model = OpenAIChatCompletionClient(**client_args) + except ValueError as e: + # Provide more context if the specific error occurs again + print(f"Error initializing OpenAIChatCompletionClient: {e}") + print(f"Arguments passed: {client_args}") + raise e # Re-raise the original error + except Exception as e: + print(f"An unexpected error occurred during client initialization: {e}") + raise e return model def _func_from_string(self, content: str) -> callable: @@ -88,7 +252,11 @@ def load_agent(self, agent_config: AgentConfig | dict) -> AgentTypes: if agent_config.agent_type == "AssistantAgent": model_client = self.load_model(agent_config.model_client) system_message = agent_config.system_message if agent_config.system_message else "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed." - tools = [self.load_tool(tool) for tool in agent_config.tools] + # Handle cases where tools might be None before iterating + tools = [] + if agent_config.tools is not None: + tools = [self.load_tool(tool) for tool in agent_config.tools] + agent = AssistantAgent( name=agent_config.name, model_client=model_client, tools=tools, system_message=system_message) @@ -127,6 +295,21 @@ def load_team(self, team_config: TeamConfig | dict) -> TeamTypes: team = RoundRobinGroupChat( agents, termination_condition=termination) elif team_config.team_type == "SelectorGroupChat": - team = SelectorGroupChat(agents, termination_condition=termination) + # Load the top-level model client ONLY if it's defined in the config + top_level_client = None + if team_config.model_client: + top_level_client = self.load_model(team_config.model_client) + else: + # Explicitly handle the case where no model_client is needed/provided + logging.debug("SelectorGroupChat: No top-level model_client configured. Relying solely on selector_func.") + + # Pass agents as the first positional argument + # Use the correct 'selector_func' parameter name + team = SelectorGroupChat( + agents, # Positional argument + termination_condition=termination, + model_client=top_level_client, # Pass the potentially None client + selector_func=custom_speaker_selector # Use correct parameter + ) return team diff --git a/autogenui/web/app.py b/autogenui/web/app.py index b328d5d..2de365c 100644 --- a/autogenui/web/app.py +++ b/autogenui/web/app.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware @@ -11,11 +11,11 @@ import traceback import asyncio from pathlib import Path +from fastapi.staticfiles import StaticFiles # Ensure StaticFiles is imported # Import your team manager components from autogen_agentchat import EVENT_LOGGER_NAME -from autogen_agentchat.messages import AgentMessage, ChatMessage, ToolCallMessage, ToolCallResultMessage -from autogen_core.base import CancellationToken +from autogen_agentchat.messages import ChatMessage, BaseChatMessage from ..manager import TeamManager from ..datamodel import TeamResult, TaskResult @@ -68,7 +68,7 @@ async def session_exists(self, session_id: str) -> bool: # CORS middleware setup app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:3000"], # Add your frontend URL + allow_origins=["http://localhost:3000", "http://localhost:8081", "http://127.0.0.1:8081"], # Allow frontend dev server and the served UI origin allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -82,7 +82,7 @@ async def session_exists(self, session_id: str) -> bool: team_manager = TeamManager() # Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) # Set level to DEBUG to see detailed logs logger = logging.getLogger(EVENT_LOGGER_NAME) logger.setLevel(logging.INFO) @@ -146,36 +146,78 @@ async def generate(req: GenerateWebRequest): cancellation_token=None ): try: - if isinstance(message, (AgentMessage, ChatMessage)): - content = message.content if hasattr( - message, 'content') else str(message) - if isinstance(message, ToolCallMessage) or isinstance(message, ToolCallResultMessage): - content = "".join([str(tool_call) - for tool_call in message.content]) + if isinstance(message, BaseChatMessage): + content = message.content if hasattr(message, 'content') else str(message) + # Determine the source name using message.source + source_name = "system" # Default fallback + if hasattr(message, 'source') and isinstance(message.source, str) and message.source: + # Use message.source if it's a non-empty string + source_name = message.source + else: + # Log if source is missing or not a string, then fallback + logger.warning(f"[Source Check] Message 'source' attribute missing, not a string, or empty. Falling back to 'system'. Message: {message}") + + # Removed detailed sender logging + await websocket.send_json({ "type": "message", "content": content, - "source": message.sender if hasattr(message, 'sender') else "system", + "source": source_name, # Use the determined source name (from message.source) "timestamp": str(datetime.datetime.now()) }) elif isinstance(message, TeamResult): + # Restore original content extraction logic (or the improved one without logs) + content_to_send = "[No final_responder message found]" # Default + final_responder_name = "final_responder" # Agreed-upon name for the final agent + try: + if hasattr(message.task_result, 'messages') and message.task_result.messages: + logger.info(f"Searching for last non-empty message from agent '{final_responder_name}' in: {message.task_result.messages}") + # Iterate backwards through messages + for msg in reversed(message.task_result.messages): + # Check if message is from the designated final responder and has non-empty content + if hasattr(msg, 'source') and msg.source == final_responder_name and hasattr(msg, 'content') and msg.content: + logger.info(f"Found suitable message from {final_responder_name}: {msg}") + content_to_send = str(msg.content) + break # Stop after finding the first suitable one + # Keep warning if no suitable message found + if content_to_send == "[No final_responder message found]": + logger.warning(f"Could not find a non-empty message from agent '{final_responder_name}' in the list.") + elif hasattr(message.task_result, 'messages'): + logger.warning("TeamResult.task_result.messages is empty.") + content_to_send = "[Messages list is empty]" + else: + logger.warning("TeamResult.task_result has no 'messages' attribute. Falling back to str(message).") + content_to_send = str(message) # Fallback + except Exception as log_err: + logger.error(f"Error during content extraction for TeamResult: {log_err}") + content_to_send = f"[Error during extraction: {log_err}]" + + # Send the result event as "TaskResultEvent" await websocket.send_json({ - "type": "result", - "content": str(message.task_result.messages[-1].content) if hasattr(message.task_result, 'messages') else str(message), + "type": "TaskResultEvent", # Changed type here + "content": content_to_send, "source": "task_result", "timestamp": str(datetime.datetime.now()) }) + + # Send TerminationEvent immediately after the result (without extra logs) + await websocket.send_json({ + "type": "TerminationEvent", + "content": "Stream completed after result", + "timestamp": str(datetime.datetime.now()) + }) + + # Break the loop (without extra logs) + break + except Exception as e: logger.error( f"Error sending message for session {req.session_id}: {str(e)}") + # If sending fails, we might still want to try sending an error event later + # but re-raise for now to be caught by the outer handler raise - # Send completion event - await websocket.send_json({ - "type": "TerminationEvent", - "content": "Stream completed", - "timestamp": str(datetime.datetime.now()) - }) + # REMOVED TerminationEvent send from here, as it's now sent immediately after TeamResult return { "status": True, @@ -208,10 +250,16 @@ async def generate(req: GenerateWebRequest): # Mount the API router app.mount("/api", api) +# Serve static files from the 'ui' directory +ui_path = Path(__file__).parent / "ui" +if ui_path.exists(): + app.mount("/", StaticFiles(directory=ui_path, html=True), name="ui") +else: + logger.warning(f"UI directory not found at {ui_path}, UI will not be served.") -@app.get("/") -async def root(): - return {"message": "API is running"} +# @app.get("/") # This route is now handled by StaticFiles if index.html exists +# async def root(): +# return {"message": "API is running"} if __name__ == "__main__": import uvicorn diff --git a/frontend/src/components/chat/chatview.tsx b/frontend/src/components/chat/chatview.tsx index 7e33279..f5b9985 100644 --- a/frontend/src/components/chat/chatview.tsx +++ b/frontend/src/components/chat/chatview.tsx @@ -90,6 +90,12 @@ export default function ChatView({ }; socket.onmessage = (event) => { + // Handle heartbeat pong message + if (event.data === "pong") { + // console.log("Received pong"); // Optional: log pong if needed + return; // Ignore pong messages, do not parse as JSON + } + try { const logEvent = JSON.parse(event.data); console.log("Received event:", logEvent); @@ -117,9 +123,11 @@ export default function ChatView({ setMessages((prev) => prev.map((msg) => { if (msg.sessionId === sessionId && msg.sender === "bot") { + // update finalResponse and status return { ...msg, finalResponse: logEvent.content, + status: "complete", }; } return msg; diff --git a/frontend/src/components/chat/messagelist.tsx b/frontend/src/components/chat/messagelist.tsx index be8a315..8fe9213 100644 --- a/frontend/src/components/chat/messagelist.tsx +++ b/frontend/src/components/chat/messagelist.tsx @@ -50,7 +50,11 @@ export default function MessageList({