diff --git a/superagi/agent/agent_iteration_step_handler.py b/superagi/agent/agent_iteration_step_handler.py
index 0e2c4fcec..c92bb83ed 100644
--- a/superagi/agent/agent_iteration_step_handler.py
+++ b/superagi/agent/agent_iteration_step_handler.py
@@ -64,11 +64,13 @@ def execute_step(self):
                                           agent_config=agent_config, agent_execution_config=agent_execution_config,
                                           prompt=iteration_workflow_step.prompt,
-                                          agent_tools=agent_tools)
+                                          agent_tools=agent_tools,
+                                          branching_enabled=iteration_workflow_step.branching_enabled)
 
         messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \
             .build_agent_messages(prompt, agent_feeds, history_enabled=iteration_workflow_step.history_enabled,
-                                  completion_prompt=iteration_workflow_step.completion_prompt)
+                                  completion_prompt=iteration_workflow_step.completion_prompt,
+                                  branching_enabled=iteration_workflow_step.branching_enabled)
 
         logger.debug("Prompt messages:", messages)
         current_tokens = TokenCounter.count_message_tokens(messages = messages, model = self.llm.get_model())
 
@@ -134,12 +136,13 @@ def _update_agent_execution_next_step(self, execution, next_step_id, step_respon
 
     def _build_agent_prompt(self, iteration_workflow: IterationWorkflow, agent_config: dict, agent_execution_config: dict,
-                            prompt: str, agent_tools: list):
+                            prompt: str, agent_tools: list, branching_enabled: bool = False):
         max_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 600))
         prompt = AgentPromptBuilder.replace_main_variables(prompt, agent_execution_config["goal"],
                                                            agent_execution_config["instruction"], agent_config["constraints"],
                                                            agent_tools,
-                                                           (not iteration_workflow.has_task_queue))
+                                                           (not iteration_workflow.has_task_queue),
+                                                           branching_enabled=branching_enabled)
 
         if iteration_workflow.has_task_queue:
             response = self.task_queue.get_last_task_details()
             last_task, last_task_result = (response["task"], response["response"]) if response is not None else ("", "")
@@ -147,7 +150,8 @@ def _build_agent_prompt(self, iteration_workflow: IterationWorkflow, agent_confi
             token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit() - max_token_limit
             prompt = AgentPromptBuilder.replace_task_based_variables(prompt, current_task, last_task,
                                                                      last_task_result, self.task_queue.get_tasks(),
-                                                                     self.task_queue.get_completed_tasks(), token_limit)
+                                                                     self.task_queue.get_completed_tasks(), token_limit,
+                                                                     branching_enabled=branching_enabled)
         return prompt
 
     def _build_tools(self, agent_config: dict, agent_execution_config: dict):
diff --git a/superagi/agent/agent_message_builder.py b/superagi/agent/agent_message_builder.py
index 63698b874..3931a6fd6 100644
--- a/superagi/agent/agent_message_builder.py
+++ b/superagi/agent/agent_message_builder.py
@@ -24,7 +24,7 @@ def __init__(self, session, llm, llm_model: str, agent_id: int, agent_execution_
         self.organisation = Agent.find_org_by_agent_id(self.session, self.agent_id)
 
     def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=False,
-                             completion_prompt: str = None):
+                             completion_prompt: str = None, branching_enabled=False):
         """ Build agent messages for LLM agent.
 
         Args:
@@ -32,6 +32,7 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F
             agent_feeds (list): The list of agent feeds.
             history_enabled (bool): Whether to use history or not.
             completion_prompt (str): The completion prompt to be used for generating the agent messages.
+            branching_enabled (bool): Whether to use branching logic for Tree of Thought (ToT).
         """
         token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm_model)
         max_output_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 800))
@@ -42,7 +43,8 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F
         full_message_history = [{'role': agent_feed.role, 'content': agent_feed.feed, 'chat_id': agent_feed.id}
                                 for agent_feed in agent_feeds]
         past_messages, current_messages = self._split_history(full_message_history,
-                                                              ((token_limit - base_token_limit - max_output_token_limit) // 4) * 3)
+                                                              ((token_limit - base_token_limit - max_output_token_limit) // 4) * 3,
+                                                              branching_enabled)
         if past_messages:
             ltm_summary = self._build_ltm_summary(past_messages=past_messages,
                                                   output_token_limit=(token_limit - base_token_limit - max_output_token_limit) // 4)
@@ -56,7 +58,7 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F
             self._add_initial_feeds(agent_feeds, messages)
         return messages
 
-    def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[BaseMessage], List[BaseMessage]]:
+    def _split_history(self, history: List, pending_token_limit: int, branching_enabled: bool) -> Tuple[List[BaseMessage], List[BaseMessage]]:
         hist_token_count = 0
         i = len(history)
         for message in reversed(history):
@@ -65,10 +67,21 @@ def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[
             hist_token_count += token_count
             if hist_token_count > pending_token_limit:
                 self._add_or_update_last_agent_feed_ltm_summary_id(str(history[i-1]['chat_id']))
+                if branching_enabled:
+                    return self._split_history_with_branching(history[:i], history[i:])
                 return history[:i], history[i:]
             i -= 1
         return [], history
 
+    def _split_history_with_branching(self, past_messages: List[BaseMessage], current_messages: List[BaseMessage]) -> Tuple[List[BaseMessage], List[BaseMessage]]:
+        # Implement branching logic for Tree of Thought (ToT)
+        # This is a placeholder implementation; customize it based on your requirements
+        branches = []
+        for message in past_messages:
+            if "branch" in message["content"]:
+                branches.append(message)
+        return branches, current_messages
+
     def _add_initial_feeds(self, agent_feeds: list, messages: list):
         if agent_feeds:
             return
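The `_split_history_with_branching` placeholder above keeps only past messages whose content contains the literal substring "branch". A minimal standalone sketch of that filter (the sample feed entries below are invented) shows what survives the split when `branching_enabled` is set:

```python
from typing import Dict, List, Tuple

def split_history_with_branching(
        past_messages: List[Dict], current_messages: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    # Keep only the past messages that mention a branch; everything else is
    # dropped from the returned history rather than carried forward.
    branches = [message for message in past_messages if "branch" in message["content"]]
    return branches, current_messages

past = [
    {"role": "assistant", "content": "branch A: search the docs first", "chat_id": 1},
    {"role": "assistant", "content": "unrelated chatter", "chat_id": 2},
]
current = [{"role": "user", "content": "pick the most promising branch", "chat_id": 3}]

kept, tail = split_history_with_branching(past, current)
assert [m["chat_id"] for m in kept] == [1]
```

Note that in `build_agent_messages` the returned branch list replaces `past_messages`, so anything this filter drops is also excluded from the input to `_build_ltm_summary`.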
""" token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm_model) max_output_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 800)) @@ -42,7 +43,8 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F full_message_history = [{'role': agent_feed.role, 'content': agent_feed.feed, 'chat_id': agent_feed.id} for agent_feed in agent_feeds] past_messages, current_messages = self._split_history(full_message_history, - ((token_limit - base_token_limit - max_output_token_limit) // 4) * 3) + ((token_limit - base_token_limit - max_output_token_limit) // 4) * 3, + branching_enabled) if past_messages: ltm_summary = self._build_ltm_summary(past_messages=past_messages, output_token_limit=(token_limit - base_token_limit - max_output_token_limit) // 4) @@ -56,7 +58,7 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F self._add_initial_feeds(agent_feeds, messages) return messages - def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[BaseMessage], List[BaseMessage]]: + def _split_history(self, history: List, pending_token_limit: int, branching_enabled: bool) -> Tuple[List[BaseMessage], List[BaseMessage]]: hist_token_count = 0 i = len(history) for message in reversed(history): @@ -65,10 +67,21 @@ def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[ hist_token_count += token_count if hist_token_count > pending_token_limit: self._add_or_update_last_agent_feed_ltm_summary_id(str(history[i-1]['chat_id'])) + if branching_enabled: + return self._split_history_with_branching(history[:i], history[i:]) return history[:i], history[i:] i -= 1 return [], history + def _split_history_with_branching(self, past_messages: List[BaseMessage], current_messages: List[BaseMessage]) -> Tuple[List[BaseMessage], List[BaseMessage]]: + # Implement branching logic for Tree of Thought (ToT) + # This is a placeholder implementation, you can customize it based on your requirements + branches = [] + for message in past_messages: + if "branch" in message["content"]: + branches.append(message) + return branches, current_messages + def _add_initial_feeds(self, agent_feeds: list, messages: list): if agent_feeds: return diff --git a/superagi/agent/agent_prompt_builder.py b/superagi/agent/agent_prompt_builder.py index 4b9bce554..36ed8eada 100644 --- a/superagi/agent/agent_prompt_builder.py +++ b/superagi/agent/agent_prompt_builder.py @@ -64,7 +64,7 @@ def clean_prompt(cls, prompt): @classmethod def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instructions: List[str], constraints: List[str], - tools: List[BaseTool], add_finish_tool: bool = True): + tools: List[BaseTool], add_finish_tool: bool = True, branching_enabled: bool = False): """Replace the main variables in the super agi prompt. Args: @@ -74,6 +74,7 @@ def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instruc constraints (List[str]): The list of constraints. tools (List[BaseTool]): The list of tools. add_finish_tool (bool): Whether to add finish tool or not. + branching_enabled (bool): Whether to enable branching logic for Tree of Thought (ToT). 
""" super_agi_prompt = super_agi_prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(goals)) if len(instructions) > 0 and len(instructions[0]) > 0: @@ -90,11 +91,15 @@ def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instruc # logger.info(tools) tools_string = AgentPromptBuilder.add_tools_to_prompt(tools, add_finish_tool) super_agi_prompt = super_agi_prompt.replace("{tools}", tools_string) + + if branching_enabled: + super_agi_prompt += "\nNote: Branching logic is enabled for this agent." + return super_agi_prompt @classmethod def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str, last_task: str, - last_task_result: str, pending_tasks: List[str], completed_tasks: list, token_limit: int): + last_task_result: str, pending_tasks: List[str], completed_tasks: list, token_limit: int, branching_enabled: bool = False): """Replace the task based variables in the super agi prompt. Args: @@ -105,6 +110,7 @@ def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str, pending_tasks (List[str]): The list of pending tasks. completed_tasks (list): The list of completed tasks. token_limit (int): The token limit. + branching_enabled (bool): Whether to enable branching logic for Tree of Thought (ToT). """ if "{current_task}" in super_agi_prompt: super_agi_prompt = super_agi_prompt.replace("{current_task}", current_task) @@ -133,4 +139,8 @@ def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str, if token_count > min(600, pending_tokens): break super_agi_prompt = super_agi_prompt.replace("{task_history}", "\n" + final_output + "\n") + + if branching_enabled: + super_agi_prompt += "\nNote: Branching logic is enabled for this agent." + return super_agi_prompt diff --git a/superagi/agent/agent_tool_step_handler.py b/superagi/agent/agent_tool_step_handler.py index 5b8c1c127..c1bbdd94e 100644 --- a/superagi/agent/agent_tool_step_handler.py +++ b/superagi/agent/agent_tool_step_handler.py @@ -102,7 +102,8 @@ def _process_input_instruction(self, agent_config, agent_execution_config, step_ agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id) messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \ .build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled, - completion_prompt=step_tool.completion_prompt) + completion_prompt=step_tool.completion_prompt, + branching_enabled=workflow_step.branching_enabled) # print(messages) current_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) response = self.llm.chat_completion(messages, TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm.get_model()) - current_tokens) diff --git a/superagi/agent/queue_step_handler.py b/superagi/agent/queue_step_handler.py index fcd9baf1f..00ba3348d 100644 --- a/superagi/agent/queue_step_handler.py +++ b/superagi/agent/queue_step_handler.py @@ -1,7 +1,5 @@ import time - import numpy as np - from superagi.agent.agent_message_builder import AgentLlmMessageBuilder from superagi.agent.task_queue import TaskQueue from superagi.helper.error_handler import ErrorHandler @@ -16,7 +14,6 @@ from superagi.models.agent import Agent from superagi.types.queue_status import QueueStatus - class QueueStepHandler: """Handles the queue step of the agent workflow""" def __init__(self, session, llm, agent_id: int, agent_execution_id: int): diff --git 
diff --git a/superagi/agent/workflow_seed.py b/superagi/agent/workflow_seed.py
index 137776cba..bea35ac4c 100644
--- a/superagi/agent/workflow_seed.py
+++ b/superagi/agent/workflow_seed.py
@@ -214,6 +214,15 @@ def build_fixed_task_based_agent(cls, session):
         AgentWorkflowStep.add_next_workflow_step(session, step2.id, step2.id)
         AgentWorkflowStep.add_next_workflow_step(session, step2.id, -1, "COMPLETE")
 
+    @classmethod
+    def build_tree_of_thought_agent(cls, session):
+        agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Tree of Thought Workflow", "Tree of Thought Workflow")
+        step1 = AgentWorkflowStep.find_or_create_iteration_workflow_step(session, agent_workflow.id,
+                                                                         str(agent_workflow.id) + "_step1",
+                                                                         "Tree of Thought Agent-I", step_type="TRIGGER")
+        AgentWorkflowStep.add_next_workflow_step(session, step1.id, step1.id)
+        AgentWorkflowStep.add_next_workflow_step(session, step1.id, -1, "COMPLETE")
+
 class IterationWorkflowSeed:
 
     @classmethod
@@ -267,3 +276,14 @@ def build_action_based_agents(cls, session):
         output = AgentPromptTemplate.analyse_task()
         IterationWorkflowStep.find_or_create_step(session, iteration_workflow.id, "ab1", output["prompt"],
                                                   str(output["variables"]), "TRIGGER", "tools")
+
+    @classmethod
+    def build_tree_of_thought_agent(cls, session):
+        iteration_workflow = IterationWorkflow.find_or_create_by_name(session, "Tree of Thought Agent-I", "Tree of Thought Agent")
+        output = AgentPromptTemplate.get_super_agi_single_prompt()
+        IterationWorkflowStep.find_or_create_step(session, iteration_workflow.id, "tot1",
+                                                  output["prompt"],
+                                                  str(output["variables"]), "TRIGGER", "tools",
+                                                  history_enabled=True,
+                                                  completion_prompt="Determine which next tool to use, and respond using the format specified above:",
+                                                  branching_enabled=True)
diff --git a/superagi/models/workflows/iteration_workflow.py b/superagi/models/workflows/iteration_workflow.py
index e00a80d0e..62f3091ff 100644
--- a/superagi/models/workflows/iteration_workflow.py
+++ b/superagi/models/workflows/iteration_workflow.py
@@ -133,4 +133,4 @@ def find_or_create_by_name(cls, session, name: str, description: str, has_task_q
     @classmethod
     def find_by_id(cls, session, id: int):
         """ Find the workflow step by id"""
-        return session.query(IterationWorkflow).filter(IterationWorkflow.id == id).first()
\ No newline at end of file
+        return session.query(IterationWorkflow).filter(IterationWorkflow.id == id).first()
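A hypothetical wiring sketch for the two new seed helpers; SuperAGI runs its other `build_*` seeds from startup code, and this assumes the same pattern. The `connect_db`/`sessionmaker` setup mirrors the repo's usual session creation and is an assumption here, not part of this diff:

```python
from sqlalchemy.orm import sessionmaker

from superagi.agent.workflow_seed import AgentWorkflowSeed, IterationWorkflowSeed
from superagi.models.db import connect_db

session = sessionmaker(bind=connect_db())()

# The agent workflow's trigger step references the iteration workflow by
# name ("Tree of Thought Agent-I"), so it seems safest to seed the
# iteration workflow before the agent workflow that points at it.
IterationWorkflowSeed.build_tree_of_thought_agent(session)
AgentWorkflowSeed.build_tree_of_thought_agent(session)
session.commit()
```

The seeded workflow is a single trigger step that loops back to itself and exits on "COMPLETE", matching the shape of the existing fixed-task seeds.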
diff --git a/superagi/models/workflows/iteration_workflow_step.py b/superagi/models/workflows/iteration_workflow_step.py
index 15bd7ab22..4257c68cb 100644
--- a/superagi/models/workflows/iteration_workflow_step.py
+++ b/superagi/models/workflows/iteration_workflow_step.py
@@ -21,6 +21,7 @@ class IterationWorkflowStep(DBBaseModel):
         next_step_id (int): The ID of the next step in the workflow.
         history_enabled (bool): Indicates whether history is enabled for the step.
         completion_prompt (str): The completion prompt for the step.
+        branching_enabled (bool): Indicates whether branching is enabled for the step.
     """
 
     __tablename__ = 'iteration_workflow_steps'
@@ -35,6 +36,7 @@ class IterationWorkflowStep(DBBaseModel):
     next_step_id = Column(Integer)
     history_enabled = Column(Boolean)
     completion_prompt = Column(Text)
+    branching_enabled = Column(Boolean, default=False)
 
     def __repr__(self):
         """
@@ -99,7 +101,7 @@ def find_by_id(cls, session, step_id: int):
     @classmethod
     def find_or_create_step(self, session, iteration_workflow_id: int, unique_id: str, prompt: str,
                             variables: str, step_type: str, output_type: str,
-                            completion_prompt: str = "", history_enabled: bool = False):
+                            completion_prompt: str = "", history_enabled: bool = False, branching_enabled: bool = False):
         workflow_step = session.query(IterationWorkflowStep).filter(IterationWorkflowStep.unique_id == unique_id).first()
         if workflow_step is None:
             workflow_step = IterationWorkflowStep(unique_id=unique_id)
@@ -113,10 +115,8 @@ def find_or_create_step(self, session, iteration_workflow_id: int, unique_id: st
         workflow_step.iteration_workflow_id = iteration_workflow_id
         workflow_step.next_step_id = -1
         workflow_step.history_enabled = history_enabled
+        workflow_step.branching_enabled = branching_enabled
         if completion_prompt:
             workflow_step.completion_prompt = completion_prompt
         session.commit()
         return workflow_step
-
-
-
diff --git a/superagi/tools/thinking/tools.py b/superagi/tools/thinking/tools.py
index 370eeee4f..6d1488693 100644
--- a/superagi/tools/thinking/tools.py
+++ b/superagi/tools/thinking/tools.py
@@ -18,6 +18,10 @@ class ThinkingSchema(BaseModel):
         ...,
         description="Task description which needs reasoning.",
     )
+    branching_enabled: bool = Field(
+        default=False,
+        description="Flag to enable branching logic for Tree of Thought (ToT).",
+    )
 
 class ThinkingTool(BaseTool):
     """
@@ -45,12 +49,13 @@ class Config:
         arbitrary_types_allowed = True
 
-    def _execute(self, task_description: str):
+    def _execute(self, task_description: str, branching_enabled: bool = False):
         """
         Execute the Thinking tool.
 
         Args:
             task_description : The task description.
+            branching_enabled : Flag to enable branching logic for Tree of Thought (ToT).
 
         Returns:
             Thought process of llm for the task
         """
@@ -64,6 +69,8 @@ def _execute(self, task_description: str):
             metadata = {"agent_execution_id":self.agent_execution_id}
             relevant_tool_response = self.tool_response_manager.get_relevant_response(query=task_description,metadata=metadata)
             prompt = prompt.replace("{relevant_tool_response}",relevant_tool_response)
+            if branching_enabled:
+                prompt += "\nNote: Branching logic is enabled for this task."
 
             messages = [{"role": "system", "content": prompt}]
             result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
@@ -72,4 +79,4 @@ def _execute(self, task_description: str):
             return result["content"]
         except Exception as e:
             logger.error(e)
-            return f"Error generating text: {e}"
\ No newline at end of file
+            return f"Error generating text: {e}"
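At the tool level, the schema change means `branching_enabled` arrives as an ordinary parsed argument with a safe default, so existing callers are unaffected. A small sketch of the schema side alone; the tool's LLM and resource dependencies are normally injected by the tool framework, so they are skipped here, and the task text is invented:

```python
from superagi.tools.thinking.tools import ThinkingSchema

# Parse tool arguments the way the framework would before calling _execute;
# branching_enabled=True asks _execute to append the ToT note to the prompt.
args = ThinkingSchema(
    task_description="Compare caching strategies for the API layer",
    branching_enabled=True,
)
print(args.dict())  # pydantic v1 style, matching the repo's BaseModel usage
```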