-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllamaindex_workflow.py
116 lines (91 loc) · 3.3 KB
/
llamaindex_workflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Union

from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from pydantic import BaseModel, Field

from agntcy_iomapper import IOMappingAgent, IOMappingAgentMetadata
from agntcy_iomapper.llamaindex import (
    IOMappingInputEvent,
    IOMappingOutputEvent,
    LLamaIndexIOMapperConfig,
)

from examples.llm import Framework, get_azure
from examples.models import Campaign, Statistics, User
from examples.models.data import users
class PickUsersEvent(Event):
    # Workflow event emitted by prompt_step and consumed by pick_users_step.
    # prompt: the free-text request that was passed to Workflow.run(prompt=...).
    prompt: str = Field(...)
class CreateCampaignEvent(Event):
    # Workflow event emitted by pick_users_step and consumed by create_campaign.
    # list_users: the audience chosen for the campaign.
    list_users: List[User] = Field(...)
class CampaignCreatedEvent(Event):
    # Workflow event wrapping a finished Campaign.
    # NOTE(review): no step visible in this file emits or consumes it.
    campaign: Campaign = Field(...)
class OverallState(BaseModel):
    # State handed to the IO mapper as IOMappingInputEvent.data.
    # create_campaign configures the mapper to read selected_users and
    # campaign_details.name and to target stats. stats defaults to None and is
    # presumably filled in by the mapper (confirm against agntcy_iomapper).
    # Field order is kept as-is so the generated JSON schema is unchanged.
    campaign_details: Campaign = Field(...)
    stats: Optional[Statistics] = None
    selected_users: List[User] = Field(...)
class CampaignWorkflow(Workflow):
    """Example workflow that drafts a campaign and hands its state to the IO mapper.

    Step chain, by event type:
    ``StartEvent`` -> ``PickUsersEvent`` -> ``CreateCampaignEvent`` ->
    ``IOMappingInputEvent`` -> ... -> ``IOMappingOutputEvent`` -> ``StopEvent``.

    No step defined in this class consumes ``IOMappingInputEvent``. Presumably
    the step registered by ``IOMappingAgent.as_worfklow_step`` handles it and
    emits the ``IOMappingOutputEvent`` that ``after_translation`` receives.
    Confirm this against agntcy_iomapper.
    """

    @step
    async def prompt_step(self, ctx: Context, ev: StartEvent) -> PickUsersEvent:
        """Store the run's ``llm`` and ``prompt`` in the context and forward the prompt.

        Both values come from the keyword arguments passed to ``Workflow.run``.
        """
        prompt = ev.get("prompt")
        await ctx.set("llm", ev.get("llm"))
        # CreateCampaignEvent does not carry the prompt, but create_campaign
        # needs it for the LLM request, so keep it in the shared context.
        await ctx.set("prompt", prompt)
        return PickUsersEvent(prompt=prompt)

    @step
    async def pick_users_step(
        self, ctx: Context, ev: PickUsersEvent
    ) -> CreateCampaignEvent:
        """Select the campaign audience.

        Currently ignores ``ev.prompt`` and returns every user from
        ``examples.models.data.users``.
        """
        return CreateCampaignEvent(list_users=users)

    @step
    async def create_campaign(
        self, ctx: Context, ev: CreateCampaignEvent
    ) -> Union[IOMappingInputEvent, StopEvent]:
        """Ask the LLM for a campaign and wrap the result for the IO mapper.

        Args:
            ctx: Workflow context. It must hold ``llm`` (set by ``prompt_step``)
                and may hold ``prompt``.
            ev: Event carrying the selected users.

        Returns:
            An ``IOMappingInputEvent`` whose data is an ``OverallState``. If the
            LLM reply cannot be parsed into a ``Campaign``, returns instead a
            ``StopEvent`` whose result is the parse error message.

        Raises:
            ValueError: If no LLM was supplied to ``Workflow.run``.
        """
        llm = await ctx.get("llm", default=None)
        if llm is None:
            raise ValueError(
                "create_campaign requires an 'llm' argument to Workflow.run()"
            )
        # Use the caller's request. Fall back to the previously hard-coded text
        # when no prompt was provided.
        user_prompt = (
            await ctx.get("prompt", default=None)
        ) or "Create a campaign for all users"
        prompt = f"""
You are a campaign builder for company XYZ. Given a list of selected users and a user prompt, create an engaging campaign.
Return the campaign details as a JSON object with the following structure:
{{
"name": "Campaign Name",
"content": "Campaign Content",
"is_urgent": yes/no
}}
Selected Users: {ev.list_users}
User Prompt: {user_prompt}
"""
        # The async completion call keeps the event loop free while the LLM runs.
        llm_response = await llm.acomplete(prompt)
        parser = PydanticOutputParser(output_cls=Campaign)
        try:
            campaign_details = parser.parse(str(llm_response))
        except ValueError as e:
            # Malformed or missing JSON and pydantic validation errors (which
            # subclass ValueError) end the run with the error as its result.
            print(f"Error parsing campaign details: {e}")
            return StopEvent(result=f"{e}")

        # Tell the IO mapper which state fields to read and which one to fill.
        metadata = IOMappingAgentMetadata(
            input_fields=["selected_users", "campaign_details.name"],
            output_fields=["stats"],
        )
        config = LLamaIndexIOMapperConfig(llm=llm)
        return IOMappingInputEvent(
            metadata=metadata,
            config=config,
            data=OverallState(
                campaign_details=campaign_details,
                selected_users=ev.list_users,
            ),
        )

    @step
    async def after_translation(self, evt: IOMappingOutputEvent) -> StopEvent:
        """Finish the run once the IO mapper reports back.

        The mapped output in ``evt`` is not inspected. The run always ends with
        the result ``"Done"``.
        """
        return StopEvent(result="Done")
async def main():
    """Build the campaign workflow, run it once, and print what it returns."""
    azure_llm = get_azure(framework=Framework.LLAMA_INDEX)
    campaign_flow = CampaignWorkflow()
    # NOTE(review): presumably attaches the IO mapper as a workflow step that
    # consumes IOMappingInputEvent, since no step in CampaignWorkflow does.
    # The method name is spelled "worfklow" (sic), exactly as called here.
    IOMappingAgent.as_worfklow_step(workflow=campaign_flow)
    outcome = await campaign_flow.run(
        prompt="Create a campaign for all users",
        llm=azure_llm,
    )
    print(outcome)


if __name__ == "__main__":
    # Script entry point: drive the async main() to completion.
    import asyncio

    asyncio.run(main())