diff --git a/backend/requirements.txt b/backend/requirements.txt
index 91f058f..d661a87 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -45,4 +45,5 @@ tiktoken
 tqdm~=4.65.0
 types-requests==0.1.13
 typing-inspect==0.8.0
-typing_extensions==4.5.0
\ No newline at end of file
+typing_extensions==4.5.0
+PyPDF2~=3.0
\ No newline at end of file
diff --git a/backend/utils/utils.py b/backend/utils/utils.py
index 0fad57c..2fd2903 100644
--- a/backend/utils/utils.py
+++ b/backend/utils/utils.py
@@ -10,6 +10,7 @@ from sqlalchemy import create_engine
 from PIL import Image
 from loguru import logger
+from PyPDF2 import PdfReader
 
 from real_agents.adapters.data_model import (
     DatabaseDataModel,
@@ -17,6 +18,7 @@ ImageDataModel,
     TableDataModel,
     KaggleDataModel,
+    DocumentDataModel,
 )
 from real_agents.data_agent import (
     DataSummaryExecutor,
@@ -32,7 +34,8 @@ DOCUMENT_EXTENSIONS = {"pdf", "doc", "docx", "txt"}
 DATABASE_EXTENSIONS = {"sqlite", "db"}
 IMAGE_EXTENSIONS = {"jpg", "png", "jpeg"}
 
-ALLOW_EXTENSIONS = TABLE_EXTENSIONS | DOCUMENT_EXTENSIONS | DATABASE_EXTENSIONS | IMAGE_EXTENSIONS
+PDF_EXTENSIONS = {"pdf"}
+ALLOW_EXTENSIONS = TABLE_EXTENSIONS | DOCUMENT_EXTENSIONS | DATABASE_EXTENSIONS | IMAGE_EXTENSIONS | PDF_EXTENSIONS
 
 LOCAL = "local"
 REDIS = "redis"
@@ -127,6 +130,21 @@ def load_grounding_source(file_path: str) -> Any:
             "size": img.size,
             "mode": img.mode,
         }
+    elif suffix == ".pdf":
+        brut_doc = PdfReader(file_path)
+        # PdfReader.metadata is None for PDFs without an info dictionary,
+        # and creation_date can itself be None — guard both before access.
+        doc_meta = brut_doc.metadata
+        grounding_source = {
+            "plain_text": "".join(f'-PAGE_{i}-{page.extract_text()}' for i, page in enumerate(brut_doc.pages)),
+            "num_pages": len(brut_doc.pages),
+            "metadata": {
+                'author': doc_meta.author if doc_meta else None,
+                'year': doc_meta.creation_date.year if doc_meta and doc_meta.creation_date else None,
+                'title': doc_meta.title if doc_meta else None,
+                'subject': doc_meta.subject if doc_meta else None,
+            }
+        }
     else:
         raise ValueError("File type not allowed to be set as grounding source")
     return grounding_source
@@ -146,6 +164,8 @@ def get_data_model_cls(file_path: str) -> DataModel:
     data_model_cls = 
DatabaseDataModel
     elif suffix == ".jpeg" or suffix == ".png" or suffix == ".jpg":
         data_model_cls = ImageDataModel
+    elif suffix == ".pdf":
+        data_model_cls = DocumentDataModel
     else:
         raise ValueError("File type not allowed to be set as grounding source")
     return data_model_cls
diff --git a/frontend/components/Chatbar/components/ChatbarSettings.tsx b/frontend/components/Chatbar/components/ChatbarSettings.tsx
index 2fe9da3..0d13e33 100644
--- a/frontend/components/Chatbar/components/ChatbarSettings.tsx
+++ b/frontend/components/Chatbar/components/ChatbarSettings.tsx
@@ -233,7 +233,7 @@ export const ChatbarSettings = ({
           className="sr-only"
           tabIndex={-1}
           type="file"
-          accept=".csv, .tsv, .xslx, .db, .sqlite, .png, .jpg, .jpeg"
+          accept=".csv, .tsv, .xlsx, .db, .sqlite, .png, .jpg, .jpeg, .pdf"
           ref={fileInputRef}
           onChange={handleUpload}
         />
diff --git a/real_agents/adapters/data_model/__init__.py b/real_agents/adapters/data_model/__init__.py
index 75b57ec..e7fe6fe 100644
--- a/real_agents/adapters/data_model/__init__.py
+++ b/real_agents/adapters/data_model/__init__.py
@@ -6,6 +6,7 @@ from real_agents.adapters.data_model.kaggle import KaggleDataModel
 from real_agents.adapters.data_model.plugin import APIYamlModel, SpecModel
 from real_agents.adapters.data_model.table import TableDataModel
+from real_agents.adapters.data_model.document import DocumentDataModel
 
 __all__ = [
     "DataModel",
diff --git a/real_agents/adapters/data_model/document.py b/real_agents/adapters/data_model/document.py
new file mode 100644
index 0000000..96b0c67
--- /dev/null
+++ b/real_agents/adapters/data_model/document.py
@@ -0,0 +1,17 @@
+from typing import Any
+
+from real_agents.adapters.data_model.base import DataModel
+
+
+class DocumentDataModel(DataModel):
+    """A data model for a document (can contain text, images, tables, other data)."""
+
+    def get_raw_data(self) -> Any:
+        return self.raw_data
+
+    def get_llm_side_data(self,
+                          max_tokens: int = 2000,
+                          chunk_size: int = 1000,
+                          
chunk_overlap: int = 200
+                          ) -> Any:
+        return self.raw_data['plain_text'][:max_tokens]
diff --git a/real_agents/data_agent/executors/data_summary_executor.py b/real_agents/data_agent/executors/data_summary_executor.py
index 64ad808..d7dabcb 100644
--- a/real_agents/data_agent/executors/data_summary_executor.py
+++ b/real_agents/data_agent/executors/data_summary_executor.py
@@ -5,7 +5,12 @@ from langchain import PromptTemplate
 
 from real_agents.adapters.callbacks.executor_streaming import ExecutorStreamingChainHandler
 
-from real_agents.adapters.data_model import DatabaseDataModel, TableDataModel, ImageDataModel
+from real_agents.adapters.data_model import (
+    DatabaseDataModel,
+    TableDataModel,
+    ImageDataModel,
+    DocumentDataModel
+)
 from real_agents.adapters.llm import LLMChain
 
@@ -195,3 +200,96 @@ def _parse_output(self, content: str) -> Tuple[str, str]:
             bullet_points.append(f"{bullet_point_id}. " + line[1:].strip().strip('"'))
             bullet_point_id += 1
         return table_summary, "\n".join(bullet_points)
+
+
+class DocumentSummaryExecutor(DataSummaryExecutor):
+    SUMMARY_PROMPT_TEMPLATE = """
+{img_info}
+
+Provide a succinct summary of the uploaded file with less than 20 words. Please ensure your summary is a complete sentence and include it within <summary></summary> tags."
+Then provide {num_insights} very simple and basic suggestions in natural language about further processing with the data. The final results should be markdown '+' bullet point list, e.g.,
++ The first suggestion."
+
+Begin.
+ """ + stream_handler = ExecutorStreamingChainHandler() + + def run( + self, + grounding_source: DocumentDataModel, + llm: BaseLanguageModel, + use_intelligent_summary: bool = True, + num_insights: int = 3, + ) -> Dict[str, Any]: + summary = "" + if isinstance(grounding_source, DocumentDataModel): + # Basic summary + summary += ( + f"Your document **{grounding_source.raw_data['metadata']['title']}** created by " + f"{grounding_source.raw_data['metadata']['author']} at " + f"{grounding_source.raw_data['metadata']['year']} year. \n" + ) + + # Intelligent summary + if use_intelligent_summary: + intelligent_summary = self._intelligent_summary( + grounding_source, + num_insights=num_insights, + llm=llm, + ) + _, suggestions = self._parse_output(intelligent_summary) + summary += "\n" + suggestions + + for stream_token in summary.split(" "): + self.stream_handler.on_llm_new_token(stream_token) + else: + raise ValueError(f"Unsupported data summary for grounding source type: {type(grounding_source)}") + return summary + + def _intelligent_summary(self, grounding_source: DocumentDataModel, num_insights: int, llm: BaseLanguageModel) -> str: + """Use LLM to generate data summary.""" + summary_prompt_template = PromptTemplate( + input_variables=["img_info", "num_insights"], + template=self.SUMMARY_PROMPT_TEMPLATE, + ) + method = LLMChain(llm=llm, prompt=summary_prompt_template) + result = method.run({"img_info": grounding_source.get_llm_side_data(), "num_insights": num_insights}) + return result + + @staticmethod + def text_summary(llm: BaseLanguageModel, reduce_template: str) -> str: + reduce_prompt = PromptTemplate.from_template(reduce_template) + map_prompt = PromptTemplate.from_template(reduce_prompt) + map_chain = LLMChain(llm=llm, prompt=map_prompt) + # Run chain + reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt) + + return reduce_chain, map_chain + + + def _parse_output(self, content: str) -> Tuple[str, str]: + """Parse the output of the LLM to get the data 
summary.""" + from bs4 import BeautifulSoup + + # Using 'html.parser' to parse the content + soup = BeautifulSoup(content, "html.parser") + # Parsing the tag and summary contents + try: + table_summary = soup.find("summary").text + except Exception: + import traceback + + traceback.print_exc() + table_summary = "" + + lines = content.split("\n") + # Initialize an empty list to hold the parsed bullet points + bullet_points = [] + # Loop through each line + bullet_point_id = 1 + for line in lines: + # If the line starts with '+', it is a bullet point + if line.startswith("+"): + # Remove the '+ ' from the start of the line and add it to the list + bullet_points.append(f"{bullet_point_id}. " + line[1:].strip().strip('"')) + bullet_point_id += 1 + return table_summary, "\n".join(bullet_points) \ No newline at end of file