Feature/web search pipeline (#293)
* web search pipeline

* web search pipeline

* re-merge logic
emrgnt-cmplxty authored Apr 15, 2024
1 parent 608d2f9 commit 95d68ca
Showing 4 changed files with 165 additions and 126 deletions.
224 changes: 112 additions & 112 deletions r2r/examples/clients/run_basic_client_old.py
@@ -23,115 +23,115 @@
print(f"Upsert entry response:\n{entry_response}\n\n")


-# entry_response = client.add_entry(
-#     generate_id_from_label("doc 1"),
-#     {"txt": "This is a test entry"},
-#     {"tags": ["example", "test"]},
-#     do_upsert=False,
-# )
-# print(f"Copy same entry response:\n{entry_response}\n\n")
-
-
-# print("Upserting entries to remote db...")
-# # Upsert multiple entries
-# entries = [
-#     {
-#         "document_id": generate_id_from_label("doc 2"),
-#         "blobs": {"txt": "Second test entry"},
-#         "metadata": {"tags": "bulk"},
-#     },
-#     {
-#         "document_id": generate_id_from_label("doc 3"),
-#         "blobs": {"txt": "Third test entry"},
-#         "metadata": {"tags": "example"},
-#     },
-# ]
-# bulk_upsert_response = client.add_entries(entries, do_upsert=True)
-# print(f"Upsert entries response:\n{bulk_upsert_response}\n\n")
-
-# # Perform a search
-# print("Searching remote db...")
-# search_response = client.search("test", 5)
-# print(f"Search response:\n{search_response}\n\n")
-
-# print("Searching remote db with filter...")
-# # Perform a search w/ filter
-# filtered_search_response = client.search("test", 5, filters={"tags": "bulk"})
-# print(f"Search response w/ filter:\n{filtered_search_response}\n\n")
-
-# print("Deleting sample document in remote db...")
-# # Delete a document
-# response = client.filtered_deletion(
-#     "document_id", generate_id_from_label("doc 2")
-# )
-# print(f"Deletion response:\n{response}\n\n")
-
-# print("Searching remote db with filter after deletion...")
-# # Perform a search w/ filter after deletion
-# post_deletion_filtered_search_response = client.search(
-#     "test", 5, filters={"tags": "bulk"}
-# )
-# print(
-#     f"Search response w/ filter+deletion:\n{post_deletion_filtered_search_response}\n\n"
-# )
-
-# # Example file path for upload
-# # get file directory
-# current_file_directory = os.path.dirname(os.path.realpath(__file__))
-
-# file_path = os.path.join(current_file_directory, "..", "data", "test.pdf")
-
-# print(f"Uploading and processing file: {file_path}...")
-# # # Upload and process a file
-# metadata = {"tags": ["example", "test"]}
-# upload_pdf_response = client.upload_and_process_file(
-#     generate_id_from_label("pdf 1"), file_path, metadata, None
-# )
-# print(f"Upload test pdf response:\n{upload_pdf_response}\n\n")
-
-# print("Searching remote db after upload...")
-# # Perform a search on this file
-# pdf_filtered_search_response = client.search(
-#     "what is a cool physics equation?",
-#     5,
-#     filters={"document_id": generate_id_from_label("pdf 1")},
-# )
-# print(
-#     f"Search response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
-# )
-
-
-# print("Performing RAG...")
-# # Perform a search on this file
-# pdf_filtered_search_response = client.rag_completion(
-#     "Are there any test documents?",
-#     5,
-#     filters={"document_id": generate_id_from_label("pdf 1")},
-# )
-# print(
-#     f"Search response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
-# )
-
-# print("Performing RAG with streaming...")
-
-
-# # Perform a RAG completion with streaming
-# async def stream_rag_completion():
-#     async for chunk in client.stream_rag_completion(
-#         "Are there any test documents?",
-#         5,
-#         filters={"document_id": generate_id_from_label("pdf 1")},
-#         generation_config={"stream": True},
-#     ):
-#         print(chunk, end="", flush=True)
-
-
-# asyncio.run(stream_rag_completion())
-
-# print("Fetching logs after all steps...")
-# logs_response = client.get_logs()
-# print(f"Logs response:\n{logs_response}\n")
-
-# print("Fetching logs summary after all steps...")
-# logs_summary_response = client.get_logs_summary()
-# print(f"Logs summary response:\n{logs_summary_response}\n")
+entry_response = client.add_entry(
+    generate_id_from_label("doc 1"),
+    {"txt": "This is a test entry"},
+    {"tags": ["example", "test"]},
+    do_upsert=False,
+)
+print(f"Copy same entry response:\n{entry_response}\n\n")
+
+
+print("Upserting entries to remote db...")
+# Upsert multiple entries
+entries = [
+    {
+        "document_id": generate_id_from_label("doc 2"),
+        "blobs": {"txt": "Second test entry"},
+        "metadata": {"tags": "bulk"},
+    },
+    {
+        "document_id": generate_id_from_label("doc 3"),
+        "blobs": {"txt": "Third test entry"},
+        "metadata": {"tags": "example"},
+    },
+]
+bulk_upsert_response = client.add_entries(entries, do_upsert=True)
+print(f"Upsert entries response:\n{bulk_upsert_response}\n\n")
+
+# Perform a search
+print("Searching remote db...")
+search_response = client.search("test", 5)
+print(f"Search response:\n{search_response}\n\n")
+
+print("Searching remote db with filter...")
+# Perform a search w/ filter
+filtered_search_response = client.search("test", 5, filters={"tags": "bulk"})
+print(f"Search response w/ filter:\n{filtered_search_response}\n\n")
+
+print("Deleting sample document in remote db...")
+# Delete a document
+response = client.filtered_deletion(
+    "document_id", generate_id_from_label("doc 2")
+)
+print(f"Deletion response:\n{response}\n\n")
+
+print("Searching remote db with filter after deletion...")
+# Perform a search w/ filter after deletion
+post_deletion_filtered_search_response = client.search(
+    "test", 5, filters={"tags": "bulk"}
+)
+print(
+    f"Search response w/ filter+deletion:\n{post_deletion_filtered_search_response}\n\n"
+)
+
+# Example file path for upload
+# Get the directory containing this script
+current_file_directory = os.path.dirname(os.path.realpath(__file__))
+
+file_path = os.path.join(current_file_directory, "..", "data", "test.pdf")
+
+print(f"Uploading and processing file: {file_path}...")
+# Upload and process a file
+metadata = {"tags": ["example", "test"]}
+upload_pdf_response = client.upload_and_process_file(
+    generate_id_from_label("pdf 1"), file_path, metadata, None
+)
+print(f"Upload test pdf response:\n{upload_pdf_response}\n\n")
+
+print("Searching remote db after upload...")
+# Perform a search on this file
+pdf_filtered_search_response = client.search(
+    "what is a cool physics equation?",
+    5,
+    filters={"document_id": generate_id_from_label("pdf 1")},
+)
+print(
+    f"Search response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
+)
+
+
+print("Performing RAG...")
+# Perform RAG over this file
+pdf_filtered_search_response = client.rag_completion(
+    "Are there any test documents?",
+    5,
+    filters={"document_id": generate_id_from_label("pdf 1")},
+)
+print(
+    f"RAG response w/ uploaded pdf filter:\n{pdf_filtered_search_response}\n"
+)
+
+print("Performing RAG with streaming...")
+
+
+# Perform a RAG completion with streaming
+async def stream_rag_completion():
+    async for chunk in client.stream_rag_completion(
+        "Are there any test documents?",
+        5,
+        filters={"document_id": generate_id_from_label("pdf 1")},
+        generation_config={"stream": True},
+    ):
+        print(chunk, end="", flush=True)
+
+
+asyncio.run(stream_rag_completion())
+
+print("Fetching logs after all steps...")
+logs_response = client.get_logs()
+print(f"Logs response:\n{logs_response}\n")
+
+print("Fetching logs summary after all steps...")
+logs_summary_response = client.get_logs_summary()
+print(f"Logs summary response:\n{logs_summary_response}\n")
10 changes: 5 additions & 5 deletions r2r/examples/servers/web_search_pipeline.py
@@ -1,16 +1,16 @@
"""A simple example to demonstrate the usage of `WebSearchRAGPipeline`."""
"""A simple example to demonstrate the usage of `WebRAGPipeline`."""

import uvicorn

from r2r.main import E2EPipelineFactory, R2RConfig
from r2r.pipelines import WebSearchRAGPipeline
from r2r.pipelines import WebRAGPipeline

# Creates a pipeline using the `WebSearchRAGPipeline` implementation
# Creates a pipeline using the `WebRAGPipeline` implementation
app = E2EPipelineFactory.create_pipeline(
config=R2RConfig.load_config(), rag_pipeline_impl=WebSearchRAGPipeline
config=R2RConfig.load_config(), rag_pipeline_impl=WebRAGPipeline
)


if __name__ == "__main__":
# Run the FastAPI application using Uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)
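
Note: because the module builds `app` at import time, the server can also be launched via the Uvicorn CLI rather than executing the script directly — a sketch, assuming the import path shown in the file header:

uvicorn r2r.examples.servers.web_search_pipeline:app --host 0.0.0.0 --port 8000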
4 changes: 2 additions & 2 deletions r2r/pipelines/__init__.py
@@ -3,7 +3,7 @@
from .basic.ingestion import BasicIngestionPipeline, IngestionType
from .basic.prompt_provider import BasicPromptProvider
from .basic.rag import BasicRAGPipeline
-from .web_search.rag import WebSearchRAGPipeline
+from .web_search.rag import WebRAGPipeline

__all__ = [
    "DocumentPage",
@@ -13,5 +13,5 @@
    "BasicIngestionPipeline",
    "BasicPromptProvider",
    "BasicRAGPipeline",
-    "WebSearchRAGPipeline",
+    "WebRAGPipeline",
]
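
Note: the rename is the only breaking change in this file; downstream callers update their import accordingly (a sketch of the expected migration, not part of this commit):

# Before this commit:
# from r2r.pipelines import WebSearchRAGPipeline
# After:
from r2r.pipelines import WebRAGPipeline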
53 changes: 46 additions & 7 deletions r2r/pipelines/web_search/rag.py
@@ -1,14 +1,15 @@
"""
A simple example to demonstrate the usage of `WebSearchRAGPipeline`.
"""

import json
import logging
from typing import Optional
from typing import Generator, Optional

from r2r.core import (
GenerationConfig,
LLMProvider,
LoggingDatabaseConnection,
PromptProvider,
RAGPipeline,
VectorDBProvider,
log_execution_to_db,
)
Expand All @@ -18,20 +19,39 @@

from ..basic.rag import BasicRAGPipeline

+WEB_RAG_SYSTEM_PROMPT = "You are a helpful assistant."
+WEB_RAG_TASK_PROMPT = """
+## Task:
+Answer the query given immediately below, using the context that follows. Use line-item references like [1], [2], ... to refer to specifically numbered items in the provided context. Pay close attention to the title of each given source to ensure it is consistent with the query.
+### Query:
+{query}
+### Context:
+{context}
+### Query:
+{query}
+REMINDER - Use line-item references like [1], [2], ... to refer to specifically numbered items in the provided context.
+## Response:
+"""
logger = logging.getLogger(__name__)


-class WebSearchRAGPipeline(BasicRAGPipeline):
+class WebRAGPipeline(BasicRAGPipeline):
    def __init__(
        self,
        llm: LLMProvider,
        db: VectorDBProvider,
        embedding_model: str,
        embeddings_provider: OpenAIEmbeddingProvider,
        logging_connection: Optional[LoggingDatabaseConnection] = None,
-        prompt_provider: Optional[PromptProvider] = BasicPromptProvider(),
+        prompt_provider: Optional[BasicPromptProvider] = BasicPromptProvider(
+            WEB_RAG_SYSTEM_PROMPT, WEB_RAG_TASK_PROMPT
+        ),
    ) -> None:
-        logger.debug(f"Initalizing `WebSearchRAGPipeline`.")
+        logger.debug(f"Initializing `WebRAGPipeline`.")
        super().__init__(
            llm=llm,
            logging_connection=logging_connection,
@@ -67,7 +87,7 @@ def search(
        )

        return results

    @log_execution_to_db
    def construct_context(self, results: list) -> str:
        local_context = super().construct_context(
@@ -77,3 +97,22 @@ def construct_context(self, results: list) -> str:
[ele["result"] for ele in results if ele["type"] == "external"]
)
return local_context + "\n\n" + web_context

+    def _stream_run(
+        self,
+        search_results: list,
+        context: str,
+        prompt: str,
+        generation_config: GenerationConfig,
+    ) -> Generator[str, None, None]:
+        yield f"<{RAGPipeline.SEARCH_STREAM_MARKER}>"
+        yield json.dumps(search_results)
+        yield f"</{RAGPipeline.SEARCH_STREAM_MARKER}>"
+
+        yield f"<{RAGPipeline.CONTEXT_STREAM_MARKER}>"
+        yield context
+        yield f"</{RAGPipeline.CONTEXT_STREAM_MARKER}>"
+        yield f"<{RAGPipeline.COMPLETION_STREAM_MARKER}>"
+        for chunk in self.generate_completion(prompt, generation_config):
+            yield chunk
+        yield f"</{RAGPipeline.COMPLETION_STREAM_MARKER}>"
