Merge pull request #459 from SciPhi-AI/Nolan/ReaddBlast
Reapply changes from merge conflict
NolanTrem authored Jun 14, 2024
2 parents 2a387b9 + ab271fe commit 5dad883
Showing 1 changed file with 137 additions and 27 deletions.
164 changes: 137 additions & 27 deletions r2r/main/r2r_app.py
@@ -28,6 +28,7 @@
to_async_generator,
)
from r2r.pipes import R2REvalPipe
from r2r.telemetry.telemetry_decorator import telemetry_event

from .r2r_abstractions import R2RPipelines, R2RProviders
from .r2r_config import R2RConfig
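
Note: this hunk only adds the import for telemetry_event; the decorator itself lives in r2r/telemetry/telemetry_decorator.py and its body is not shown in this diff. A minimal, hypothetical sketch of a decorator with the call signature used below (@telemetry_event("EventName") wrapping async endpoint methods) might look like this -- the logging sink and event handling are assumptions, not the repository's actual implementation:

    import functools
    import logging

    logger = logging.getLogger(__name__)

    def telemetry_event(event_name: str):
        # Hypothetical sketch only; the real decorator is defined in
        # r2r/telemetry/telemetry_decorator.py and is not part of this diff.
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(self, *args, **kwargs):
                # Record a usage event, then delegate to the wrapped handler.
                logger.info(f"telemetry event: {event_name}")
                return await func(self, *args, **kwargs)
            return wrapper
        return decorator
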
@@ -258,6 +259,7 @@ class UpdatePromptRequest(BaseModel):
template: Optional[str] = None
input_types: Optional[dict[str, str]] = None

@telemetry_event("UpdatePrompt")
async def update_prompt_app(self, request: UpdatePromptRequest):
"""Update a prompt's template and/or input types."""
try:
@@ -289,7 +291,27 @@ async def aingest_documents(
)

document_infos = []
skipped_documents = []
processed_documents = []
existing_document_ids = [
str(doc_info.document_id)
for doc_info in self.providers.vector_db.get_documents_info()
]

for iteration, document in enumerate(documents):
if (
version is not None
and str(document.id) in existing_document_ids
):
logger.error(f"Document with ID {document.id} already exists.")
if len(documents) == 1:
raise HTTPException(
status_code=409,
detail=f"Document with ID {document.id} already exists.",
)
skipped_documents.append(document.title or str(document.id))
continue

document_metadata = (
metadatas[iteration] if metadatas else document.metadata
)
@@ -319,24 +341,62 @@ async def aingest_documents(
)
)

processed_documents.append(document.title or str(document.id))

if skipped_documents and len(skipped_documents) == len(documents):
logger.error("All provided documents already exist.")
raise HTTPException(
status_code=409,
detail="All provided documents already exist. Use the update endpoint to update these documents.",
)

if skipped_documents:
logger.warning(
f"Skipped ingestion for the following documents since they already exist: {', '.join(skipped_documents)}. Use the update endpoint to update these documents."
)

await self.ingestion_pipeline.run(
input=to_async_generator(documents),
versions=versions,
input=to_async_generator(
[
doc
for doc in documents
if str(doc.id) not in existing_document_ids
]
),
versions=[
info.version
for info in document_infos
if info.created_at == info.updated_at
],
run_manager=self.run_manager,
)

self.providers.vector_db.upsert_documents_info(document_infos)
return {"results": "Entries upserted successfully."}
return {
"processed_documents": [
f"Document '{title}' processed successfully."
for title in processed_documents
],
"skipped_documents": [
f"Document '{title}' skipped since it already exists."
for title in skipped_documents
],
}

class IngestDocumentsRequest(BaseModel):
documents: list[Document]

@telemetry_event("IngestDocuments")
async def ingest_documents_app(self, request: IngestDocumentsRequest):
async with manage_run(
self.run_manager, "ingest_documents_app"
) as run_id:
try:
return await self.aingest_documents(request.documents)

except HTTPException as he:
raise he

except Exception as e:
await self.logging_connection.log(
log_id=run_id,
@@ -423,6 +483,7 @@ async def aupdate_documents(
class UpdateDocumentsRequest(BaseModel):
documents: list[Document]

@telemetry_event("UpdateDocuments")
async def update_documents_app(self, request: UpdateDocumentsRequest):
async with manage_run(
self.run_manager, "update_documents_app"
@@ -445,10 +506,7 @@ async def update_documents_app(self, request: UpdateDocumentsRequest):
logger.error(
f"update_documents_app(documents={request.documents}) - \n\n{str(e)})"
)
logger.error(
f"update_documents_app(documents={request.documents}) - \n\n{str(e)})"
)
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

@syncable
async def aingest_files(
@@ -482,6 +540,12 @@ async def aingest_files(
try:
documents = []
document_infos = []
skipped_documents = []
processed_documents = []
existing_document_ids = [
str(doc_info.document_id)
for doc_info in self.providers.vector_db.get_documents_info()
]

for iteration, file in enumerate(files):
logger.info(f"Processing file: {file.filename}")
@@ -522,14 +586,27 @@ async def aingest_files(
detail=f"{file_extension} is explicitly excluded in the configuration file.",
)

file_content = await file.read()
logger.info(f"File read successfully: {file.filename}")

document_id = (
generate_id_from_label(file.filename)
if document_ids is None
else document_ids[iteration]
)
if (
version is not None
and str(document_id) in existing_document_ids
):
logger.error(f"File with ID {document_id} already exists.")
if len(files) == 1:
raise HTTPException(
status_code=409,
detail=f"File with ID {document_id} already exists.",
)
skipped_documents.append(file.filename)
continue

file_content = await file.read()
logger.info(f"File read successfully: {file.filename}")

document_metadata = metadatas[iteration] if metadatas else {}
document_title = (
document_metadata.get("title", None) or file.filename
@@ -567,7 +644,21 @@ async def aingest_files(
)
)

# Run the pipeline asynchronously with filtered documents
processed_documents.append(file.filename)

if skipped_documents and len(skipped_documents) == len(files):
logger.error("All uploaded documents already exist.")
raise HTTPException(
status_code=409,
detail="All uploaded documents already exist. Use the update endpoint to update these documents.",
)

if skipped_documents:
logger.warning(
f"Skipped ingestion for the following documents since they already exist: {', '.join(skipped_documents)}. Use the update endpoint to update these documents."
)

# Run the pipeline asynchronously
await self.ingestion_pipeline.run(
input=to_async_generator(documents),
versions=versions,
@@ -578,8 +669,14 @@ async def aingest_files(
self.providers.vector_db.upsert_documents_info(document_infos)

return {
"results": f"File '{file}' processed successfully."
for file in document_infos
"processed_documents": [
f"File '{filename}' processed successfully."
for filename in processed_documents
],
"skipped_documents": [
f"File '{filename}' skipped since it already exists."
for filename in skipped_documents
],
}
except Exception as e:
raise e
@@ -588,6 +685,7 @@ async def aingest_files(
for file in files:
file.file.close()

@telemetry_event("IngestFiles")
async def ingest_files_app(
self,
files: list[UploadFile] = File(...),
@@ -756,6 +854,7 @@ class UpdateFilesRequest(BaseModel):
metadatas: Optional[str] = Form(None)
ids: str = Form("")

@telemetry_event("UpdateFiles")
async def update_files_app(
self,
files: list[UploadFile] = File(...),
@@ -845,6 +944,7 @@ class SearchRequest(BaseModel):
search_limit: int = 10
do_hybrid_search: Optional[bool] = False

@telemetry_event("Search")
async def search_app(self, request: SearchRequest):
async with manage_run(self.run_manager, "search_app") as run_id:
try:
@@ -960,6 +1060,7 @@ class RAGRequest(BaseModel):
rag_generation_config: Optional[str] = None
streaming: Optional[bool] = None

@telemetry_event("RAG")
async def rag_app(self, request: RAGRequest):
async with manage_run(self.run_manager, "rag_app") as run_id:
try:
@@ -1069,6 +1170,7 @@ class EvalRequest(BaseModel):
context: str
completion: str

@telemetry_event("Evaluate")
async def evaluate_app(self, request: EvalRequest):
async with manage_run(self.run_manager, "evaluate_app") as run_id:
try:
@@ -1110,6 +1212,7 @@ class DeleteRequest(BaseModel):
keys: list[str]
values: list[Union[bool, int, str]]

@telemetry_event("Delete")
async def delete_app(self, request: DeleteRequest = Body(...)):
try:
return await self.adelete(request.keys, request.values)
@@ -1168,6 +1271,7 @@ async def alogs(

return {"results": aggregated_logs}

@telemetry_event("Logs")
async def logs_app(
self,
log_type_filter: Optional[str] = Query(None),
@@ -1236,27 +1340,27 @@ async def aanalytics(
analysis_type = analysis_config[0]
if analysis_type == "bar_chart":
extract_key = analysis_config[1]
results[
filter_key
] = AnalysisTypes.generate_bar_chart_data(
filtered_logs[filter_key], extract_key
results[filter_key] = (
AnalysisTypes.generate_bar_chart_data(
filtered_logs[filter_key], extract_key
)
)
elif analysis_type == "basic_statistics":
extract_key = analysis_config[1]
results[
filter_key
] = AnalysisTypes.calculate_basic_statistics(
filtered_logs[filter_key], extract_key
results[filter_key] = (
AnalysisTypes.calculate_basic_statistics(
filtered_logs[filter_key], extract_key
)
)
elif analysis_type == "percentile":
extract_key = analysis_config[1]
percentile = int(analysis_config[2])
results[
filter_key
] = AnalysisTypes.calculate_percentile(
filtered_logs[filter_key],
extract_key,
percentile,
results[filter_key] = (
AnalysisTypes.calculate_percentile(
filtered_logs[filter_key],
extract_key,
percentile,
)
)
else:
logger.warning(
@@ -1265,6 +1369,7 @@ async def aanalytics(

return {"results": results}

@telemetry_event("Analytics")
async def analytics_app(
self,
filter_criteria: FilterCriteria = Body(...),
@@ -1292,6 +1397,7 @@ async def aapp_settings(self, *args: Any, **kwargs: Any):
}
}

@telemetry_event("AppSettings")
async def app_settings_app(self):
"""Return the config.json and all prompts."""
try:
@@ -1306,6 +1412,7 @@ async def ausers_stats(self, user_ids: Optional[list[uuid.UUID]] = None):
[str(ele) for ele in user_ids]
)

@telemetry_event("UsersStats")
async def users_stats_app(
self, user_ids: Optional[list[uuid.UUID]] = Query(None)
):
@@ -1335,6 +1442,7 @@ async def adocuments_info(
),
)

@telemetry_event("DocumentsInfo")
async def documents_info_app(
self,
document_ids: Optional[list[str]] = Query(None),
@@ -1355,6 +1463,7 @@ async def documents_info_app(
async def adocument_chunks(self, document_id: str) -> list[str]:
return self.providers.vector_db.get_document_chunks(document_id)

@telemetry_event("DocumentChunks")
async def document_chunks_app(self, document_id: str):
try:
chunks = await self.adocument_chunks(document_id)
@@ -1365,6 +1474,7 @@ async def document_chunks_app(self, document_id: str):
)
raise HTTPException(status_code=500, detail=str(e)) from e

@telemetry_event("OpenAPI")
def openapi_spec_app(self):
from fastapi.openapi.utils import get_openapi

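
For reference, the reworked ingestion endpoints in this diff return per-document outcomes rather than the previous single "Entries upserted successfully." string. A minimal sketch of the new response shape, with illustrative filenames that are not taken from the diff:

    # Illustrative only: shape of the value returned by aingest_files when one
    # upload is new and another already exists. Per the diff, a 409 is raised
    # only when every provided document (or a single provided document) is a duplicate.
    example_response = {
        "processed_documents": [
            "File 'report.pdf' processed successfully.",
        ],
        "skipped_documents": [
            "File 'notes.txt' skipped since it already exists.",
        ],
    }
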
