
Commit 95d68ca

Feature/web search pipeline (#293)
* web search pipeline
* web search pipeline
* re-merge logic
1 parent 608d2f9 commit 95d68ca

File tree

4 files changed: +165 -126 lines changed


r2r/examples/clients/run_basic_client_old.py

Lines changed: 112 additions & 112 deletions

The single hunk re-enables the whole example flow: the 112 deleted lines are the 112 added lines shown below, each previously commented out with a leading `# `.

@@ -23,115 +23,115 @@
 print(f"Upsert entry response:\n{entry_response}\n\n")
 
 
+entry_response = client.add_entry(
+    generate_id_from_label("doc 1"),
+    {"txt": "This is a test entry"},
+    {"tags": ["example", "test"]},
+    do_upsert=False,
+)
+print(f"Copy same entry response:\n{entry_response}\n\n")
+
+
+print("Upserting entries to remote db...")
+# Upsert multiple entries
+entries = [
+    {
+        "document_id": generate_id_from_label("doc 2"),
+        "blobs": {"txt": "Second test entry"},
+        "metadata": {"tags": "bulk"},
+    },
+    {
+        "document_id": generate_id_from_label("doc 3"),
+        "blobs": {"txt": "Third test entry"},
+        "metadata": {"tags": "example"},
+    },
+]
+bulk_upsert_response = client.add_entries(entries, do_upsert=True)
+print(f"Upsert entries response:\n{bulk_upsert_response}\n\n")
+
+# Perform a search
+print("Searching remote db...")
+search_response = client.search("test", 5)
+print(f"Search response:\n{search_response}\n\n")
+
+print("Searching remote db with filter...")
+# Perform a search w/ filter
+filtered_search_response = client.search("test", 5, filters={"tags": "bulk"})
+print(f"Search response w/ filter:\n{filtered_search_response}\n\n")
+
+print("Deleting sample document in remote db...")
+# Delete a document
+response = client.filtered_deletion(
+    "document_id", generate_id_from_label("doc 2")
+)
+print(f"Deletion response:\n{response}\n\n")
+
+print("Searching remote db with filter after deletion...")
+# Perform a search w/ filter after deletion
+post_deletion_filtered_search_response = client.search(
+    "test", 5, filters={"tags": "bulk"}
+)
+print(
+    f"Search response w/ filter+deletion:\n{post_deletion_filtered_search_response}\n\n"
+)
+
+# Example file path for upload
+# get file directory
+current_file_directory = os.path.dirname(os.path.realpath(__file__))
+
+file_path = os.path.join(current_file_directory, "..", "data", "test.pdf")
+
+print(f"Uploading and processing file: {file_path}...")
+# Upload and process a file
+metadata = {"tags": ["example", "test"]}
+upload_pdf_response = client.upload_and_process_file(
+    generate_id_from_label("pdf 1"), file_path, metadata, None
+)
+print(f"Upload test pdf response:\n{upload_pdf_response}\n\n")
+
+print("Searching remote db after upload...")
+# Perform a search on this file
+pdf_filtered_search_response = client.search(
+    "what is a cool physics equation?",
+    5,
+    filters={"document_id": generate_id_from_label("pdf 1")},
+)
+print(
+    f"Search response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
+)
+
+
+print("Performing RAG...")
+# Perform a RAG completion on this file
+pdf_filtered_search_response = client.rag_completion(
+    "Are there any test documents?",
+    5,
+    filters={"document_id": generate_id_from_label("pdf 1")},
+)
+print(
+    f"Search response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
+)
+
+print("Performing RAG with streaming...")
+
+
+# Perform a RAG completion with streaming
+async def stream_rag_completion():
+    async for chunk in client.stream_rag_completion(
+        "Are there any test documents?",
+        5,
+        filters={"document_id": generate_id_from_label("pdf 1")},
+        generation_config={"stream": True},
+    ):
+        print(chunk, end="", flush=True)
+
+
+asyncio.run(stream_rag_completion())
+
+print("Fetching logs after all steps...")
+logs_response = client.get_logs()
+print(f"Logs response:\n{logs_response}\n")
+
+print("Fetching logs summary after all steps...")
+logs_summary_response = client.get_logs_summary()
+print(f"Logs summary response:\n{logs_summary_response}\n")
Lines changed: 5 additions & 5 deletions

@@ -1,16 +1,16 @@
-"""A simple example to demonstrate the usage of `WebSearchRAGPipeline`."""
+"""A simple example to demonstrate the usage of `WebRAGPipeline`."""
 
 import uvicorn
 
 from r2r.main import E2EPipelineFactory, R2RConfig
-from r2r.pipelines import WebSearchRAGPipeline
+from r2r.pipelines import WebRAGPipeline
 
-# Creates a pipeline using the `WebSearchRAGPipeline` implementation
+# Creates a pipeline using the `WebRAGPipeline` implementation
 app = E2EPipelineFactory.create_pipeline(
-    config=R2RConfig.load_config(), rag_pipeline_impl=WebSearchRAGPipeline
+    config=R2RConfig.load_config(), rag_pipeline_impl=WebRAGPipeline
 )
 
 
 if __name__ == "__main__":
     # Run the FastAPI application using Uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
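With the server running, the new pipeline can be smoke-tested over HTTP. A rough sketch, noting that the route name and payload shape are assumptions (the commit does not show which FastAPI routes the factory registers):

    # Hypothetical request against the example server above; the endpoint
    # path and JSON payload are assumptions, not defined in this commit.
    import requests

    response = requests.post(
        "http://localhost:8000/rag_completion/",  # assumed route
        json={"query": "Who won the latest World Cup?", "limit": 5},
    )
    print(response.json())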

r2r/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from .basic.ingestion import BasicIngestionPipeline, IngestionType
 from .basic.prompt_provider import BasicPromptProvider
 from .basic.rag import BasicRAGPipeline
-from .web_search.rag import WebSearchRAGPipeline
+from .web_search.rag import WebRAGPipeline
 
 __all__ = [
     "DocumentPage",
@@ -13,5 +13,5 @@
     "BasicIngestionPipeline",
     "BasicPromptProvider",
     "BasicRAGPipeline",
-    "WebSearchRAGPipeline",
+    "WebRAGPipeline",
 ]

r2r/pipelines/web_search/rag.py

Lines changed: 46 additions & 7 deletions

@@ -1,14 +1,15 @@
 """
 A simple example to demonstrate the usage of `WebSearchRAGPipeline`.
 """
-
+import json
 import logging
-from typing import Optional
+from typing import Generator, Optional
 
 from r2r.core import (
+    GenerationConfig,
     LLMProvider,
     LoggingDatabaseConnection,
-    PromptProvider,
+    RAGPipeline,
     VectorDBProvider,
     log_execution_to_db,
 )
@@ -18,20 +19,39 @@
 
 from ..basic.rag import BasicRAGPipeline
 
+WEB_RAG_SYSTEM_PROMPT = "You are a helpful assistant."
+WEB_RAG_TASK_PROMPT = """
+## Task:
+Answer the query given immediately below, using the context which follows later. Use line item references like [1], [2], ... to refer to specifically numbered items in the provided context. Pay close attention to the title of each given source to ensure it is consistent with the query.
+
+### Query:
+{query}
+
+### Context:
+{context}
+
+### Query:
+{query}
+
+REMINDER - Use line item references like [1], [2], ... to refer to specifically numbered items in the provided context.
+## Response:
+"""
 logger = logging.getLogger(__name__)
 
 
-class WebSearchRAGPipeline(BasicRAGPipeline):
+class WebRAGPipeline(BasicRAGPipeline):
     def __init__(
         self,
         llm: LLMProvider,
         db: VectorDBProvider,
         embedding_model: str,
         embeddings_provider: OpenAIEmbeddingProvider,
         logging_connection: Optional[LoggingDatabaseConnection] = None,
-        prompt_provider: Optional[PromptProvider] = BasicPromptProvider(),
+        prompt_provider: Optional[BasicPromptProvider] = BasicPromptProvider(
+            WEB_RAG_SYSTEM_PROMPT, WEB_RAG_TASK_PROMPT
+        ),
     ) -> None:
-        logger.debug(f"Initalizing `WebSearchRAGPipeline`.")
+        logger.debug("Initializing `WebRAGPipeline`.")
         super().__init__(
             llm=llm,
             logging_connection=logging_connection,
@@ -67,7 +87,7 @@ def search(
         )
 
         return results
-
+
     @log_execution_to_db
     def construct_context(self, results: list) -> str:
         local_context = super().construct_context(
@@ -77,3 +97,22 @@ def construct_context(self, results: list) -> str:
             [ele["result"] for ele in results if ele["type"] == "external"]
         )
         return local_context + "\n\n" + web_context
+
+    def _stream_run(
+        self,
+        search_results: list,
+        context: str,
+        prompt: str,
+        generation_config: GenerationConfig,
+    ) -> Generator[str, None, None]:
+        yield f"<{RAGPipeline.SEARCH_STREAM_MARKER}>"
+        yield json.dumps(search_results)
+        yield f"</{RAGPipeline.SEARCH_STREAM_MARKER}>"
+
+        yield f"<{RAGPipeline.CONTEXT_STREAM_MARKER}>"
+        yield context
+        yield f"</{RAGPipeline.CONTEXT_STREAM_MARKER}>"
+        yield f"<{RAGPipeline.COMPLETION_STREAM_MARKER}>"
+        for chunk in self.generate_completion(prompt, generation_config):
+            yield chunk
+        yield f"</{RAGPipeline.COMPLETION_STREAM_MARKER}>"
