-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Expand file tree
/
Copy pathworkflow.py
More file actions
99 lines (79 loc) · 3.51 KB
/
workflow.py
File metadata and controls
99 lines (79 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import nest_asyncio
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.settings import Settings
from llama_index.core.workflow import Event, Context, Workflow, StartEvent, StopEvent, step
from llama_index.core.schema import NodeWithScore
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.response_synthesizers import CompactAndRefine
class RetrieverEvent(Event):
    """Result of running retrieval"""

    # Nodes returned by the retriever, each paired with a relevance score;
    # consumed by RAGWorkflow.synthesize.
    nodes: list[NodeWithScore]
class RAGWorkflow(Workflow):
    """Two-phase RAG workflow: ingest documents into a vector index, then
    retrieve relevant nodes and synthesize a streaming answer.

    Args:
        llm_choice: Friendly name of the Cerebras-hosted LLM to use.
            Unknown names fall back to the default Llama 4 model.
        embedding_model: FastEmbed model name used to embed documents.
    """

    # Map user-facing model choices to Cerebras model identifiers.
    # Bug fix: the original accepted `llm_choice` but silently ignored it
    # and always used the hard-coded Llama 4 model.
    _MODEL_NAMES = {
        "Llama 4": "meta-llama/llama-4-scout-17b-16e-instruct",
    }
    _DEFAULT_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

    def __init__(self, llm_choice="Llama 4", embedding_model="BAAI/bge-large-en-v1.5"):
        super().__init__()
        # Resolve the model name from the selection, falling back to the
        # default for unrecognized choices (backward compatible).
        model_name = self._MODEL_NAMES.get(llm_choice, self._DEFAULT_MODEL)
        # Initialize LLM and embedding model.
        self.llm = Cerebras(model=model_name, api_key=os.getenv("CEREBRAS_API_KEY"))
        self.embed_model = FastEmbedEmbedding(model_name=embedding_model)
        # Configure global settings so downstream llama_index components
        # (retriever, synthesizer) pick up the same models.
        Settings.llm = self.llm
        Settings.embed_model = self.embed_model
        # Populated by `ingest` / `ingest_documents`; None until then.
        self.index = None

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Entry point to ingest documents from a directory.

        Returns None (step not applicable) when no `dirname` is supplied,
        e.g. when the workflow is run for a query instead of ingestion.
        """
        dirname = ev.get("dirname")
        if not dirname:
            return None
        documents = SimpleDirectoryReader(dirname).load_data()
        self.index = VectorStoreIndex.from_documents(documents=documents)
        return StopEvent(result=self.index)

    @step
    async def retrieve(self, ctx: Context, ev: StartEvent) -> RetrieverEvent | None:
        """Entry point for RAG retrieval: fetch the top-k nodes for the query."""
        query = ev.get("query")
        # Prefer an explicitly supplied index, else the one built by `ingest`.
        index = ev.get("index") or self.index
        if not query:
            return None
        if index is None:
            print("Index is empty, load some documents before querying!")
            return None
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = await retriever.aretrieve(query)
        # Stash the query in the workflow context for the synthesize step.
        await ctx.set("query", query)
        return RetrieverEvent(nodes=nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:
        """Generate a streaming response from the retrieved nodes."""
        summarizer = CompactAndRefine(streaming=True, verbose=True)
        query = await ctx.get("query", default=None)
        response = await summarizer.asynthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)

    async def query(self, query_text: str):
        """Helper method to perform a complete RAG query.

        Raises:
            ValueError: if no documents have been ingested yet.
        """
        if self.index is None:
            raise ValueError("No documents have been ingested. Call ingest_documents first.")
        result = await self.run(query=query_text, index=self.index)
        return result

    async def ingest_documents(self, directory: str):
        """Helper method to ingest documents from `directory` and cache the index."""
        result = await self.run(dirname=directory)
        self.index = result
        return result
# Example usage
async def main():
    """Demo: index the ./data directory, ask one question, stream the answer."""
    rag = RAGWorkflow(llm_choice="Llama 4")

    # Build the vector index over the local documents, then run one query.
    await rag.ingest_documents("data")
    streaming_response = await rag.query("How was DeepSeekR1 trained?")

    # Stream the generated answer chunk-by-chunk to stdout.
    async for token in streaming_response.async_response_gen():
        print(token, end="", flush=True)
if __name__ == "__main__":
    # Local import: asyncio is only needed when running as a script,
    # so it stays out of module scope when imported as a library.
    import asyncio

    asyncio.run(main())