|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +from typing import Optional, Type |
| 3 | + |
| 4 | +from autogen_core.models import ChatCompletionClient |
| 5 | +from autogen_core.tools import FunctionTool |
| 6 | +from autogen_agentchat.agents import AssistantAgent |
| 7 | +from autogen_agentchat.messages import ( |
| 8 | + TextMessage, |
| 9 | + ToolCallExecutionEvent, |
| 10 | + ToolCallRequestEvent, |
| 11 | + ModelClientStreamingChunkEvent, |
| 12 | +) |
| 13 | + |
| 14 | +from ..agents import Agent |
| 15 | +from ..schemas.context import Context |
| 16 | +from ..schemas.agent_schemas import ( |
| 17 | + Message, |
| 18 | + TextContent, |
| 19 | + DataContent, |
| 20 | + FunctionCall, |
| 21 | + FunctionCallOutput, |
| 22 | + MessageType, |
| 23 | + RunStatus, |
| 24 | +) |
| 25 | + |
| 26 | + |
class AutogenContextAdapter:
    """Adapt a framework ``Context`` into the pieces an autogen
    ``AssistantAgent`` needs: a model client, conversation memory, the new
    task message, and a toolkit of ``FunctionTool`` wrappers.

    Construction is cheap; all async adaptation work happens in
    :meth:`initialize`, which must be awaited before the adapted attributes
    are read.
    """

    def __init__(self, context: Context, attr: dict):
        self.context = context
        self.attr = attr

        # Adapted attributes — populated by initialize().
        self.toolkit = None
        self.model = None
        self.memory = None
        self.new_message = None

    async def initialize(self):
        """Populate model, memory, new_message, and toolkit from the context."""
        self.model = await self.adapt_model()
        self.memory = await self.adapt_memory()
        self.new_message = await self.adapt_new_message()
        self.toolkit = await self.adapt_tools()

    async def adapt_memory(self):
        """Convert prior session messages into autogen ``TextMessage``s.

        The last message is excluded here — it is delivered separately as
        the new task via :meth:`adapt_new_message`.
        """
        return [
            AutogenContextAdapter.converter(msg)
            for msg in self.context.session.messages[:-1]
        ]

    @staticmethod
    def converter(message: Message):
        """Convert a framework ``Message`` into an autogen ``TextMessage``.

        Only the first content item's text is carried over; an empty
        content list maps to an empty string.
        TODO: support more message types (tool calls, images, ...).
        """
        return TextMessage.load(
            {
                "id": message.id,
                "source": message.role,
                "content": message.content[0].text if message.content else "",
            },
        )

    async def adapt_new_message(self):
        """Return the newest session message converted for autogen.

        NOTE(review): assumes the session holds at least one message;
        raises IndexError otherwise — confirm callers guarantee this.
        """
        return AutogenContextAdapter.converter(
            self.context.session.messages[-1],
        )

    async def adapt_model(self):
        """Return the model client the agent was constructed with."""
        return self.attr["model"]

    async def adapt_tools(self):
        """Build the toolkit for autogen: the pre-configured tools plus
        ``FunctionTool`` wrappers for every activated sandbox tool.

        Returns a fresh list. The previous implementation appended into the
        list stored under ``agent_config["toolkit"]``, mutating shared
        config state so each run accumulated duplicate FunctionTool
        wrappers; copying the list fixes that.
        """
        toolkit = list(self.attr["agent_config"].get("toolkit") or [])
        tools = self.attr["tools"]

        # Nothing to wrap when tools is None or [].
        if not tools:
            return toolkit

        if self.context.activate_tools:
            # Only add tools already activated on the context.
            activated_tools = self.context.activate_tools
        else:
            # Local import to avoid a circular dependency with the sandbox
            # package at module-import time.
            from ...sandbox.tools.utils import setup_tools

            activated_tools = setup_tools(
                tools=tools,
                environment_manager=self.context.environment_manager,
                session_id=self.context.session.id,
                user_id=self.context.session.user_id,
                include_schemas=False,
            )

        for tool in activated_tools:
            toolkit.append(
                FunctionTool(
                    func=tool.make_function(),
                    description=tool.schema["function"]["description"],
                    name=tool.name,
                ),
            )

        return toolkit
| 103 | + |
| 104 | + |
class AutogenAgent(Agent):
    """Agent implementation backed by autogen's ``AssistantAgent``.

    A fresh ``AssistantAgent`` is built for every run because conversation
    state is managed outside the agent (in the framework ``Context``).
    """

    def __init__(
        self,
        name: str,
        model: ChatCompletionClient,
        tools=None,
        agent_config=None,
        agent_builder: Optional[Type[AssistantAgent]] = AssistantAgent,
    ):
        """Create an autogen-backed agent.

        Args:
            name: Agent name; used as the default ``agent_config["name"]``.
            model: An autogen ``ChatCompletionClient`` instance.
            tools: Optional sandbox tool definitions to expose to the agent.
            agent_config: Extra keyword arguments forwarded to the builder.
            agent_builder: ``AssistantAgent`` or a subclass; ``None`` falls
                back to ``AssistantAgent``.
        """
        super().__init__(name=name, agent_config=agent_config)

        assert isinstance(
            model,
            ChatCompletionClient,
        ), "model must be a subclass of ChatCompletionClient in autogen"

        # Set default agent_builder
        if agent_builder is None:
            agent_builder = AssistantAgent

        assert issubclass(
            agent_builder,
            AssistantAgent,
        ), "agent_builder must be a subclass of AssistantAgent in autogen"

        # Replace name if not exists
        self.agent_config["name"] = self.agent_config.get("name") or name

        # _attr mirrors the constructor arguments so copy() can re-invoke
        # __init__. BUG FIX: "name" was previously missing, which made
        # copy() raise TypeError (missing required argument 'name').
        self._attr = {
            "name": name,
            "model": model,
            "tools": tools,
            "agent_config": self.agent_config,
            "agent_builder": agent_builder,
        }
        self._agent = None
        self.tools = tools

    def copy(self) -> "AutogenAgent":
        """Return a new AutogenAgent built from this one's constructor args."""
        return AutogenAgent(**self._attr)

    def build(self, as_context):
        """Construct the underlying AssistantAgent from an adapted context."""
        self._agent = self._attr["agent_builder"](
            **self._attr["agent_config"],
            model_client=as_context.model,
            tools=as_context.toolkit,
        )

        return self._agent

    async def run(self, context):
        """Run the agent against ``context``, yielding framework messages.

        Streams an in-progress assistant Message first, then text deltas,
        tool-call and tool-output messages as autogen events arrive, and
        finally completion markers.
        """
        ag_context = AutogenContextAdapter(context=context, attr=self._attr)
        await ag_context.initialize()

        # We should always build a new agent since the state is managed
        # outside the agent.
        self._agent = self.build(ag_context)

        resp = self._agent.run_stream(
            task=ag_context.memory + [ag_context.new_message],
        )

        # Announce the assistant message before any content arrives.
        text_message = Message(
            type=MessageType.MESSAGE,
            role="assistant",
            status=RunStatus.InProgress,
        )
        yield text_message

        text_delta_content = TextContent(delta=True)
        is_text_delta = False
        stream_mode = False  # becomes True once streaming chunks are seen
        async for event in resp:
            # Skip the echoed user input events.
            if getattr(event, "source", "user") == "user":
                continue

            if isinstance(event, TextMessage):
                # When the model streams chunks, the final TextMessage is a
                # duplicate of the already-emitted deltas — drop it.
                if stream_mode:
                    continue
                is_text_delta = True
                text_delta_content.text = event.content
                text_delta_content = text_message.add_delta_content(
                    new_content=text_delta_content,
                )
                yield text_delta_content
            elif isinstance(event, ModelClientStreamingChunkEvent):
                stream_mode = True
                is_text_delta = True
                text_delta_content.text = event.content
                text_delta_content = text_message.add_delta_content(
                    new_content=text_delta_content,
                )
                yield text_delta_content
            elif isinstance(event, ToolCallRequestEvent):
                # NOTE(review): only the first requested call is surfaced —
                # confirm parallel tool calls are not expected here.
                data = DataContent(
                    data=FunctionCall(
                        call_id=event.id,
                        name=event.content[0].name,
                        arguments=event.content[0].arguments,
                    ).model_dump(),
                )
                message = Message(
                    type=MessageType.PLUGIN_CALL,
                    role="assistant",
                    status=RunStatus.Completed,
                    content=[data],
                )
                yield message
            elif isinstance(event, ToolCallExecutionEvent):
                data = DataContent(
                    data=FunctionCallOutput(
                        call_id=event.id,
                        output=event.content[0].content,
                    ).model_dump(),
                )
                message = Message(
                    type=MessageType.PLUGIN_CALL_OUTPUT,
                    role="assistant",
                    status=RunStatus.Completed,
                    content=[data],
                )
                yield message

                # Also fold the tool output into the assistant text stream.
                is_text_delta = True
                text_delta_content.text = event.content[0].content
                text_delta_content = text_message.add_delta_content(
                    new_content=text_delta_content,
                )
                yield text_delta_content

        # Close out the streamed content and the message itself.
        if is_text_delta:
            yield text_message.content_completed(text_delta_content.index)
        yield text_message.completed()

    async def run_async(
        self,
        context,
        **kwargs,
    ):
        """Async-iteration entry point; delegates to :meth:`run`."""
        async for event in self.run(context):
            yield event
0 commit comments