Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions multi_agents_pipeline/agents/Planning_Agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from .custom_messages import TextMessage, TSTaskMessage, TSMessage
from typing import Optional, List
from autogen_core import RoutedAgent, default_subscription, message_handler, MessageContext, TopicId
from autogen_core.models import ChatCompletionClient, SystemMessage
from pydantic import ValidationError
from pydantic import BaseModel


@default_subscription
class PlanningAgent(RoutedAgent):
"""A planning agent that uses OpenAI API to generate tasks for a Time Series Agent and QA Agent.

Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The ChatCompletion client.
"""
def __init__(self, name: str, model_client: ChatCompletionClient) -> None:
super().__init__("planning_agent")
self.name = name
self._model_client = model_client
self._system_messages = [SystemMessage(content="You are a helpful AI assistant.")]

async def send_message_to_openai(self, messages: List[SystemMessage], ctx: MessageContext, json_output: Optional[bool | BaseModel] = False) -> str:
"""Sends messages to OpenAI and returns the response content.

Args:
messages (List[SystemMessage]): The list of messages to send to OpenAI.

Returns:
str: The response content from OpenAI.
"""
response = await self._model_client.create(
messages=self._system_messages + messages,
cancellation_token=ctx.cancellation_token,
json_output=json_output)
if isinstance(response.content, str):
return response.content
else:
raise ValueError("Response content is not a valid JSON string")

async def generate_ts_task(self, original_message: TSTaskMessage, ctx: MessageContext) -> TSMessage:
"""Generates a time series task message based on the original message.

Args:
original_message (TSTaskMessage): The original message containing the task description and filepath.

Returns:
TSMessage: A new TSMessage with the task type and description.
"""
ts_message = SystemMessage(
source="user",
content=f"""The task for the time series analysis is: {original_message.description}.
The time-series data is stored at {original_message.filepath}. Provide a detailed description of the data
based on the task description. Also, provide what type of analysis would be required to complete the task among
the following types: ["statistical forecasting", "anomaly detection"].
"""
)

response_content = await self.send_message_to_openai([ts_message], ctx, json_output=TSMessage)

try:
ts_task = TSMessage.model_validate_json(response_content)
ts_task.source = "planner" # Set the source to the Planning Agent
ts_task.filepath = original_message.filepath # Ensure the filepath is preserved
# Send the generated task to the QA Agent
return ts_task
except ValidationError as e:
raise ValueError(f"Response content is not a valid TextMessage: {e}") from e

async def generate_qa_task(self, original_message: TSTaskMessage, ctx: MessageContext) -> TextMessage:
"""Generates a QA task message based on the original message.

Args:
original_message (TSTaskMessage): The original message containing the task description and filepath.

Returns:
TextMessage: A new TextMessage with the task description.
"""
task_message = SystemMessage(
source="user",
content=f"""Write a descriptive task for the following prompt: {original_message.description}.
The time-series data is stored at {original_message.filepath}.
"""
)

response_content = await self.send_message_to_openai([task_message], ctx, json_output=TextMessage)

try:
qa_task = TextMessage.model_validate_json(response_content)
qa_task.source = "planner" # Set the source to the Planning Agent
# Send the generated task to the QA Agent
return qa_task
except ValidationError as e:
raise ValueError(f"Response content is not a valid TextMessage: {e}") from e

@message_handler
async def handle_ts_task_message(self, message: TSTaskMessage, ctx: MessageContext) -> None:
"""Handles incoming time series task messages and generates a response using the OpenAI Assistant API.

Args:
message (TSTaskMessage): The incoming message containing the user's query.
"""
ts_task = await self.generate_ts_task(message, ctx)
print(f"[{self.name}] Sending TS task to TS Agent...")
await self.publish_message(
ts_task,
TopicId(type="Planner-TS", source=self.id.key)
)
#await self.send_message(ts_task, AgentId("ts_agent", "default"))

qa_task = await self.generate_qa_task(message, ctx)
print(f"[{self.name}] Sending QA task to QA Agent...")
await self.publish_message(
qa_task,
TopicId(type="Planner-QA", source=self.id.key)
)
#await self.send_message(qa_task, AgentId("qa_agent", "default"))
2 changes: 1 addition & 1 deletion multi_agents_pipeline/agents/QA_Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def handle_TS(self, message: TSMessage, ctx: MessageContext) -> None:
# below is the prompt that combine the task and the TS Info.
# TODO : Modify according to the task type and task description. Currently just a placeholder
prompt = f"""
You are a Time Serise Expert.
You are a Time Series Expert.

Here is a task given by the planner:
{self._last_plan or "(no plan received)"}
Expand Down
11 changes: 10 additions & 1 deletion multi_agents_pipeline/agents/custom_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,13 @@ class TSMessage(BaseModel):
source: str
filepath: str
task_type:Optional[str] = None
description: Optional[str] = None
description: Optional[str] = None

class TSTaskMessage(BaseModel):
"""
passed to Planner

This message contains a text prompt and the filepath to the data file.
"""
description: str
filepath: str
18 changes: 16 additions & 2 deletions multi_agents_pipeline/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
from autogen_core import AgentId, SingleThreadedAgentRuntime, TopicId
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_core.models import (
ChatCompletionClient,
LLMMessage,
Expand All @@ -10,8 +9,9 @@
)
from agents.QA_Agent import QAAgent
from agents.TS_Agent import TSAgent
from agents.Planning_Agent import PlanningAgent
from agents.Reward_Agent import RewardAgent
from agents.custom_messages import TextMessage, TSMessage
from agents.custom_messages import TextMessage, TSMessage, TSTaskMessage
from autogen_core import TRACE_LOGGER_NAME
import aiofiles
import yaml
Expand All @@ -29,6 +29,12 @@ async def main() -> None:

model_client = await get_model_client(QA_MODEL_CONFIG_PATH)

await PlanningAgent.register(
runtime,
"Planning_Agent",
lambda: PlanningAgent(name="Planning_Agent", model_client=model_client),
)

await QAAgent.register(
runtime,
"QA_Agent",
Expand Down Expand Up @@ -85,6 +91,14 @@ async def main() -> None:
)


# mock a TSTaskMessage from user
# ts_task_message = TSTaskMessage(
# description="The file contains time series data of the hand motion of an actor raising their arm. From this data alone, tell me if the actor is raising a gun or pointing their finger.",
# filepath="../datasets/GunPointAgeSpan/GunPointAgeSpan_TRAIN.tsv"
# )
# await runtime.send_message(ts_task_message, AgentId("Planning_Agent", "default"))


await runtime.stop_when_idle()

if __name__ == "__main__":
Expand Down
71 changes: 57 additions & 14 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,36 +1,56 @@
accelerate==0.32.0
aiofiles==24.1.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
annotated-types==0.7.0
anyio==4.6.2
asttokens==3.0.0
async-timeout==5.0.1
attrs==24.3.0
autogen-agentchat==0.5.1
autogen-core==0.5.1
aiofiles==24.1.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.6.2
asttokens==3.0.0
async-timeout==5.0.1
attrs==24.3.0
autogen-agentchat==0.5.1
autogen-core==0.5.1
autogen-ext==0.5.1
certifi==2024.8.30
charset-normalizer==3.1.0
click==8.1.8
cmake==3.26.3
contourpy==1.0.7
coverage==7.8.0
cycler==0.11.0
decorator==5.2.1
Deprecated==1.2.18
distro==1.9.0
einops==0.6.0
exceptiongroup==1.2.2
executing==2.2.0
fastapi==0.112.2
filelock==3.12.0
fonttools==4.39.3
huggingface-hub
frozenlist==1.5.0
fsspec==2025.3.2
h11==0.14.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.4
importlib-resources==5.12.0
importlib_metadata==8.4.0
iniconfig==2.0.0
ipdb==0.13.13
ipython==8.35.0
jedi==0.19.2
Jinja2==3.1.2
jiter==0.9.0
joblib==1.2.0
jsonref==1.1.0
kiwisolver==1.4.4
lit==16.0.1
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.7
mpmath==1.3.0
multidict==6.4.3
networkx==3.1
numpy==1.24.2
nvidia-cublas-cu11==11.10.3.66
Expand All @@ -44,11 +64,24 @@ nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
openai==1.73.0
opentelemetry-api==1.32.0
packaging==23.1
pandas==2.0.0
Pillow>=11.0.0
parso==0.8.4
peft==0.10.0
pexpect==4.9.0
pillow==11.1.0
pluggy==1.5.0
prompt_toolkit==3.0.50
propcache==0.3.1
protobuf==5.29.4
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
pydantic==2.11.3
pydantic_core==2.33.1
Pygments==2.19.1
pyparsing==3.0.9
pytest==8.1.1
pytest-cov==4.1.0
Expand All @@ -58,22 +91,32 @@ pytz==2023.3
PyYAML==6.0
regex==2023.3.23
requests==2.28.2
safetensors==0.5.3
scikit-learn==1.2.2
scipy==1.10.1
six==1.16.0
sniffio==1.3.1
stack-data==0.6.3
starlette==0.38.6
sympy==1.11.1
taos-ws-py==0.3.3
threadpoolctl==3.1.0
tokenizers>=0.19
tiktoken==0.9.0
tokenizers==0.19.1
tomli==2.0.2
torch==2.0.0
tqdm==4.65.0
traitlets==5.14.3
transformers==4.40.0
triton==2.0.0
typing-inspection==0.4.0
typing_extensions==4.12.2
tzdata==2023.3
uvicorn==0.34.0
urllib3==1.26.15
uvicorn==0.34.0
wcwidth==0.2.13
wrapt==1.17.2
yarl==1.19.0
zipp==3.15.0
ipdb
peft==0.10.0