-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathllamaindex.py
41 lines (33 loc) · 1.24 KB
/
llamaindex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Optional
from llama_index.core.workflow import Workflow
from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
from agent_workflow_server.services.message import Message
from agent_workflow_server.storage.models import Run
class LlamaIndexAdapter(BaseAdapter):
    """Adapter that recognises LlamaIndex ``Workflow`` objects as agents."""

    def load_agent(self, agent: object) -> Optional[BaseAgent]:
        """Wrap ``agent`` in a :class:`LlamaIndexAgent` if it is, or builds, a Workflow.

        Two shapes are accepted:

        * a zero-argument callable (factory) whose return value is a
          ``Workflow`` instance -- the factory is invoked once here;
        * a ``Workflow`` instance itself.

        Args:
            agent: The candidate agent object to inspect.

        Returns:
            A ``LlamaIndexAgent`` wrapping the workflow, or ``None`` when
            ``agent`` is neither of the accepted shapes.
        """
        if callable(agent):
            try:
                takes_no_args = len(inspect.signature(agent).parameters) == 0
            except (TypeError, ValueError):
                # inspect.signature raises for callables it cannot introspect
                # (e.g. some builtins / C extensions). Such objects are not
                # treated as workflow factories; fall through so this adapter
                # reports "not mine" (None) instead of crashing.
                takes_no_args = False
            if takes_no_args:
                result = agent()
                if isinstance(result, Workflow):
                    return LlamaIndexAgent(result)
        if isinstance(agent, Workflow):
            return LlamaIndexAgent(agent)
        return None
class LlamaIndexAgent(BaseAgent):
    """Agent implementation that executes a LlamaIndex ``Workflow``."""

    def __init__(self, agent: Workflow):
        """Store the workflow to execute.

        Args:
            agent: The ``Workflow`` instance run by :meth:`astream`.
        """
        self.agent = agent

    async def astream(self, run: Run):
        """Run the workflow and stream its output as ``Message`` objects.

        The run's ``"input"`` entry is passed to ``Workflow.run`` as keyword
        arguments. Every event produced by the handler's ``stream_events()``
        is yielded as a ``Message(type="message", data=event)``. Once
        streaming ends, the handler is awaited and its final result is
        yielded as one last ``Message`` of the same type.

        Args:
            run: The run record; ``run["input"]`` must be a mapping of
                keyword arguments for the workflow, or ``None`` for no
                arguments.

        Yields:
            ``Message`` objects: one per streamed event, then one holding the
            final result.
        """
        # Renamed from ``input`` to avoid shadowing the builtin.
        run_input = run["input"]
        if run_input is None:
            # ``**None`` raises an opaque TypeError; treat a missing input as
            # "start the workflow with no keyword arguments".
            run_input = {}
        handler = self.agent.run(**run_input)
        async for event in handler.stream_events():
            yield Message(
                type="message",
                data=event,
            )
        # Awaiting the handler yields the workflow's final result.
        final_result = await handler
        yield Message(
            type="message",
            data=final_result,
        )