Skip to content

Support for both Tree of Thought (ToT) and Chain of Thought (CoT) #1454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions superagi/agent/agent_iteration_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def execute_step(self):
agent_config=agent_config,
agent_execution_config=agent_execution_config,
prompt=iteration_workflow_step.prompt,
agent_tools=agent_tools)
agent_tools=agent_tools,
branching_enabled=iteration_workflow_step.branching_enabled)

messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \
.build_agent_messages(prompt, agent_feeds, history_enabled=iteration_workflow_step.history_enabled,
completion_prompt=iteration_workflow_step.completion_prompt)
completion_prompt=iteration_workflow_step.completion_prompt,
branching_enabled=iteration_workflow_step.branching_enabled)

logger.debug("Prompt messages:", messages)
current_tokens = TokenCounter.count_message_tokens(messages = messages, model = self.llm.get_model())
Expand Down Expand Up @@ -134,20 +136,22 @@ def _update_agent_execution_next_step(self, execution, next_step_id, step_respon

def _build_agent_prompt(self, iteration_workflow: IterationWorkflow, agent_config: dict,
agent_execution_config: dict,
prompt: str, agent_tools: list):
prompt: str, agent_tools: list, branching_enabled: bool = False):
max_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 600))
prompt = AgentPromptBuilder.replace_main_variables(prompt, agent_execution_config["goal"],
agent_execution_config["instruction"],
agent_config["constraints"], agent_tools,
(not iteration_workflow.has_task_queue))
(not iteration_workflow.has_task_queue),
branching_enabled=branching_enabled)
if iteration_workflow.has_task_queue:
response = self.task_queue.get_last_task_details()
last_task, last_task_result = (response["task"], response["response"]) if response is not None else ("", "")
current_task = self.task_queue.get_first_task() or ""
token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit() - max_token_limit
prompt = AgentPromptBuilder.replace_task_based_variables(prompt, current_task, last_task, last_task_result,
self.task_queue.get_tasks(),
self.task_queue.get_completed_tasks(), token_limit)
self.task_queue.get_completed_tasks(), token_limit,
branching_enabled=branching_enabled)
return prompt

def _build_tools(self, agent_config: dict, agent_execution_config: dict):
Expand Down
19 changes: 16 additions & 3 deletions superagi/agent/agent_message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def __init__(self, session, llm, llm_model: str, agent_id: int, agent_execution_
self.organisation = Agent.find_org_by_agent_id(self.session, self.agent_id)

def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=False,
completion_prompt: str = None):
completion_prompt: str = None, branching_enabled=False):
""" Build agent messages for LLM agent.

Args:
prompt (str): The prompt to be used for generating the agent messages.
agent_feeds (list): The list of agent feeds.
history_enabled (bool): Whether to use history or not.
completion_prompt (str): The completion prompt to be used for generating the agent messages.
branching_enabled (bool): Whether to use branching logic for Tree of Thought (ToT).
"""
token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm_model)
max_output_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 800))
Expand All @@ -42,7 +43,8 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F
full_message_history = [{'role': agent_feed.role, 'content': agent_feed.feed, 'chat_id': agent_feed.id}
for agent_feed in agent_feeds]
past_messages, current_messages = self._split_history(full_message_history,
((token_limit - base_token_limit - max_output_token_limit) // 4) * 3)
((token_limit - base_token_limit - max_output_token_limit) // 4) * 3,
branching_enabled)
if past_messages:
ltm_summary = self._build_ltm_summary(past_messages=past_messages,
output_token_limit=(token_limit - base_token_limit - max_output_token_limit) // 4)
Expand All @@ -56,7 +58,7 @@ def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=F
self._add_initial_feeds(agent_feeds, messages)
return messages

def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[BaseMessage], List[BaseMessage]]:
def _split_history(self, history: List, pending_token_limit: int, branching_enabled: bool) -> Tuple[List[BaseMessage], List[BaseMessage]]:
hist_token_count = 0
i = len(history)
for message in reversed(history):
Expand All @@ -65,10 +67,21 @@ def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[
hist_token_count += token_count
if hist_token_count > pending_token_limit:
self._add_or_update_last_agent_feed_ltm_summary_id(str(history[i-1]['chat_id']))
if branching_enabled:
return self._split_history_with_branching(history[:i], history[i:])
return history[:i], history[i:]
i -= 1
return [], history

def _split_history_with_branching(self, past_messages: List[BaseMessage], current_messages: List[BaseMessage]) -> Tuple[List[BaseMessage], List[BaseMessage]]:
# Implement branching logic for Tree of Thought (ToT)
# This is a placeholder implementation, you can customize it based on your requirements
branches = []
for message in past_messages:
if "branch" in message["content"]:
branches.append(message)
return branches, current_messages

def _add_initial_feeds(self, agent_feeds: list, messages: list):
if agent_feeds:
return
Expand Down
14 changes: 12 additions & 2 deletions superagi/agent/agent_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def clean_prompt(cls, prompt):

@classmethod
def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instructions: List[str], constraints: List[str],
tools: List[BaseTool], add_finish_tool: bool = True):
tools: List[BaseTool], add_finish_tool: bool = True, branching_enabled: bool = False):
"""Replace the main variables in the super agi prompt.

Args:
Expand All @@ -74,6 +74,7 @@ def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instruc
constraints (List[str]): The list of constraints.
tools (List[BaseTool]): The list of tools.
add_finish_tool (bool): Whether to add finish tool or not.
branching_enabled (bool): Whether to enable branching logic for Tree of Thought (ToT).
"""
super_agi_prompt = super_agi_prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(goals))
if len(instructions) > 0 and len(instructions[0]) > 0:
Expand All @@ -90,11 +91,15 @@ def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instruc
# logger.info(tools)
tools_string = AgentPromptBuilder.add_tools_to_prompt(tools, add_finish_tool)
super_agi_prompt = super_agi_prompt.replace("{tools}", tools_string)

if branching_enabled:
super_agi_prompt += "\nNote: Branching logic is enabled for this agent."

return super_agi_prompt

@classmethod
def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str, last_task: str,
last_task_result: str, pending_tasks: List[str], completed_tasks: list, token_limit: int):
last_task_result: str, pending_tasks: List[str], completed_tasks: list, token_limit: int, branching_enabled: bool = False):
"""Replace the task based variables in the super agi prompt.

Args:
Expand All @@ -105,6 +110,7 @@ def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str,
pending_tasks (List[str]): The list of pending tasks.
completed_tasks (list): The list of completed tasks.
token_limit (int): The token limit.
branching_enabled (bool): Whether to enable branching logic for Tree of Thought (ToT).
"""
if "{current_task}" in super_agi_prompt:
super_agi_prompt = super_agi_prompt.replace("{current_task}", current_task)
Expand Down Expand Up @@ -133,4 +139,8 @@ def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str,
if token_count > min(600, pending_tokens):
break
super_agi_prompt = super_agi_prompt.replace("{task_history}", "\n" + final_output + "\n")

if branching_enabled:
super_agi_prompt += "\nNote: Branching logic is enabled for this agent."

return super_agi_prompt
3 changes: 2 additions & 1 deletion superagi/agent/agent_tool_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def _process_input_instruction(self, agent_config, agent_execution_config, step_
agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \
.build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled,
completion_prompt=step_tool.completion_prompt)
completion_prompt=step_tool.completion_prompt,
branching_enabled=workflow_step.branching_enabled)
# print(messages)
current_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
response = self.llm.chat_completion(messages, TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm.get_model()) - current_tokens)
Expand Down
3 changes: 0 additions & 3 deletions superagi/agent/queue_step_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import time

import numpy as np

from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.agent.task_queue import TaskQueue
from superagi.helper.error_handler import ErrorHandler
Expand All @@ -16,7 +14,6 @@
from superagi.models.agent import Agent
from superagi.types.queue_status import QueueStatus


class QueueStepHandler:
"""Handles the queue step of the agent workflow"""
def __init__(self, session, llm, agent_id: int, agent_execution_id: int):
Expand Down
15 changes: 15 additions & 0 deletions superagi/agent/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,18 @@ def set_status(self, status):
def get_status(self):
return self.db.get(self.queue_name + "_status")

def add_branching_task(self, task: str, branch_id: int):
    """Push *task* onto the pending-task queue of the given ToT branch."""
    branch_queue = f"{self.queue_name}_branch_{branch_id}"
    self.db.lpush(branch_queue, task)

def complete_branching_task(self, response, branch_id: int):
    """Pop the most recent pending task of *branch_id* and record it, together
    with its *response*, on the branch's completed-task list.

    No-op when the branch has no pending tasks.
    """
    if not self.get_branching_tasks(branch_id):
        return
    suffix = f"_branch_{branch_id}"
    task = self.db.lpop(self.queue_name + suffix)
    completed_entry = str({"task": task, "response": response})
    self.db.lpush(self.completed_tasks + suffix, completed_entry)

def get_branching_tasks(self, branch_id: int):
    """Return every task currently pending on the given ToT branch."""
    branch_queue = f"{self.queue_name}_branch_{branch_id}"
    return self.db.lrange(branch_queue, 0, -1)

def get_completed_branching_tasks(self, branch_id: int):
    """Return the completed-task records of the given ToT branch.

    Each record was stored as ``str({"task": ..., "response": ...})`` by
    ``complete_branching_task``; it is parsed back with ``ast.literal_eval``
    instead of ``eval`` so that stored data can never execute arbitrary code.

    Returns:
        List of ``{"task": ..., "response": ...}`` dicts, newest first.
    """
    import ast  # local import: the file's import header is outside this view

    raw_entries = self.db.lrange(self.completed_tasks + f"_branch_{branch_id}", 0, -1)
    records = []
    for raw in raw_entries:
        # Redis clients return bytes unless decode_responses is enabled — TODO confirm.
        text = raw.decode() if isinstance(raw, bytes) else raw
        records.append(ast.literal_eval(text))
    return records
20 changes: 20 additions & 0 deletions superagi/agent/workflow_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ def build_fixed_task_based_agent(cls, session):
AgentWorkflowStep.add_next_workflow_step(session, step2.id, step2.id)
AgentWorkflowStep.add_next_workflow_step(session, step2.id, -1, "COMPLETE")

@classmethod
def build_tree_of_thought_agent(cls, session):
    """Seed the single-step "Tree of Thought Workflow" agent workflow.

    Its trigger step runs the "Tree of Thought Agent-I" iteration workflow,
    loops back onto itself, and terminates on COMPLETE.
    """
    workflow_name = "Tree of Thought Workflow"
    workflow = AgentWorkflow.find_or_create_by_name(session, workflow_name, workflow_name)
    trigger_step = AgentWorkflowStep.find_or_create_iteration_workflow_step(
        session, workflow.id, f"{workflow.id}_step1",
        "Tree of Thought Agent-I", step_type="TRIGGER")
    # Self-loop keeps iterating the step; COMPLETE (-1) ends the workflow.
    AgentWorkflowStep.add_next_workflow_step(session, trigger_step.id, trigger_step.id)
    AgentWorkflowStep.add_next_workflow_step(session, trigger_step.id, -1, "COMPLETE")


class IterationWorkflowSeed:
@classmethod
Expand Down Expand Up @@ -267,3 +276,14 @@ def build_action_based_agents(cls, session):
output = AgentPromptTemplate.analyse_task()
IterationWorkflowStep.find_or_create_step(session, iteration_workflow.id, "ab1",
output["prompt"], str(output["variables"]), "TRIGGER", "tools")

@classmethod
def build_tree_of_thought_agent(cls, session):
    """Seed the "Tree of Thought Agent-I" iteration workflow with one
    branching-enabled trigger step built from the single-prompt template."""
    workflow = IterationWorkflow.find_or_create_by_name(
        session, "Tree of Thought Agent-I", "Tree of Thought Agent")
    template = AgentPromptTemplate.get_super_agi_single_prompt()
    IterationWorkflowStep.find_or_create_step(
        session, workflow.id, "tot1",
        template["prompt"], str(template["variables"]),
        "TRIGGER", "tools",
        history_enabled=True,
        completion_prompt="Determine which next tool to use, and respond using the format specified above:",
        branching_enabled=True)
2 changes: 1 addition & 1 deletion superagi/models/workflows/iteration_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ def find_or_create_by_name(cls, session, name: str, description: str, has_task_q
@classmethod
def find_by_id(cls, session, id: int):
    """Return the IterationWorkflow with the given primary key, or None."""
    # NOTE: `id` shadows the builtin but is kept — callers may pass it by keyword.
    workflows = session.query(IterationWorkflow)
    return workflows.filter(IterationWorkflow.id == id).first()
return session.query(IterationWorkflow).filter(IterationWorkflow.id == id).first()
8 changes: 4 additions & 4 deletions superagi/models/workflows/iteration_workflow_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class IterationWorkflowStep(DBBaseModel):
next_step_id (int): The ID of the next step in the workflow.
history_enabled (bool): Indicates whether history is enabled for the step.
completion_prompt (str): The completion prompt for the step.
branching_enabled (bool): Indicates whether branching is enabled for the step.
"""

__tablename__ = 'iteration_workflow_steps'
Expand All @@ -35,6 +36,7 @@ class IterationWorkflowStep(DBBaseModel):
next_step_id = Column(Integer)
history_enabled = Column(Boolean)
completion_prompt = Column(Text)
branching_enabled = Column(Boolean, default=False)

def __repr__(self):
"""
Expand Down Expand Up @@ -99,7 +101,7 @@ def find_by_id(cls, session, step_id: int):
@classmethod
def find_or_create_step(self, session, iteration_workflow_id: int, unique_id: str,
prompt: str, variables: str, step_type: str, output_type: str,
completion_prompt: str = "", history_enabled: bool = False):
completion_prompt: str = "", history_enabled: bool = False, branching_enabled: bool = False):
workflow_step = session.query(IterationWorkflowStep).filter(IterationWorkflowStep.unique_id == unique_id).first()
if workflow_step is None:
workflow_step = IterationWorkflowStep(unique_id=unique_id)
Expand All @@ -113,10 +115,8 @@ def find_or_create_step(self, session, iteration_workflow_id: int, unique_id: st
workflow_step.iteration_workflow_id = iteration_workflow_id
workflow_step.next_step_id = -1
workflow_step.history_enabled = history_enabled
workflow_step.branching_enabled = branching_enabled
if completion_prompt:
workflow_step.completion_prompt = completion_prompt
session.commit()
return workflow_step



11 changes: 9 additions & 2 deletions superagi/tools/thinking/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class ThinkingSchema(BaseModel):
...,
description="Task description which needs reasoning.",
)
branching_enabled: bool = Field(
default=False,
description="Flag to enable branching logic for Tree of Thought (ToT).",
)

class ThinkingTool(BaseTool):
"""
Expand Down Expand Up @@ -45,12 +49,13 @@ class Config:
arbitrary_types_allowed = True


def _execute(self, task_description: str):
def _execute(self, task_description: str, branching_enabled: bool = False):
"""
Execute the Thinking tool.

Args:
task_description : The task description.
branching_enabled : Flag to enable branching logic for Tree of Thought (ToT).

Returns:
Thought process of llm for the task
Expand All @@ -64,6 +69,8 @@ def _execute(self, task_description: str):
metadata = {"agent_execution_id":self.agent_execution_id}
relevant_tool_response = self.tool_response_manager.get_relevant_response(query=task_description,metadata=metadata)
prompt = prompt.replace("{relevant_tool_response}",relevant_tool_response)
if branching_enabled:
prompt += "\nNote: Branching logic is enabled for this task."
messages = [{"role": "system", "content": prompt}]
result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)

Expand All @@ -72,4 +79,4 @@ def _execute(self, task_description: str):
return result["content"]
except Exception as e:
logger.error(e)
return f"Error generating text: {e}"
return f"Error generating text: {e}"