diff --git a/prepline_general/api/general.py b/prepline_general/api/general.py
index 30b91796..55133572 100644
--- a/prepline_general/api/general.py
+++ b/prepline_general/api/general.py
@@ -19,11 +19,11 @@
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 import pypdf
-from pypdf import PdfReader, PdfWriter
+from pypdf import PdfReader, PdfWriter, PageObject
 import psutil
 import requests
 import backoff
-from typing import Optional, Mapping
+from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, Union
 from fastapi import (
     status,
     FastAPI,
@@ -40,6 +40,7 @@
 import secrets
 
 # Unstructured Imports
+from unstructured.documents.elements import Element
 from unstructured.partition.auto import partition
 from unstructured.staging.base import (
     convert_to_isd,
@@ -53,18 +54,7 @@
 app = FastAPI()
 router = APIRouter()
 
-
-def is_expected_response_type(media_type, response_type):
-    if media_type == "application/json" and response_type not in [dict, list]:
-        return True
-    elif media_type == "text/csv" and response_type != str:
-        return True
-    else:
-        return False
-
-
-logger = logging.getLogger("unstructured_api")
-
+Is_Chipper_Processing = False
 
 DEFAULT_MIMETYPES = (
     "application/pdf,application/msword,image/jpeg,image/png,text/markdown,"
@@ -90,7 +80,21 @@
     os.environ["UNSTRUCTURED_ALLOWED_MIMETYPES"] = DEFAULT_MIMETYPES
 
 
-def get_pdf_splits(pdf_pages, split_size=1):
+def is_expected_response_type(
+    media_type: str, response_type: Union[str, Dict[Any, Any], List[Any]]
+) -> bool:
+    if media_type == "application/json" and response_type not in [dict, list]:  # type: ignore
+        return True
+    elif media_type == "text/csv" and response_type != str:
+        return True
+    else:
+        return False
+
+
+logger = logging.getLogger("unstructured_api")
+
+
+def get_pdf_splits(pdf_pages: List[PageObject], split_size: int = 1):
     """
     Given a pdf (PdfReader) with n pages, split it into pdfs each with split_size # of pages
     Return the files with their page offset in the form [( BytesIO, int)]
@@ -105,7 +109,7 @@
         for page in pdf_pages[offset:end]:
             new_pdf.add_page(page)
 
-        new_pdf.write(pdf_buffer)
+        new_pdf.write(pdf_buffer)  # type: ignore
         pdf_buffer.seek(0)
         yield (pdf_buffer.read(), offset)
 
@@ -113,8 +117,8 @@
 
 
 # Do not retry with these status codes
-def is_non_retryable(e):
-    return 400 <= e.status_code < 500
+def is_non_retryable(e: Exception) -> bool:
+    return 400 <= e.status_code < 500  # type: ignore
 
 
 @backoff.on_exception(
@@ -124,7 +128,14 @@
     giveup=is_non_retryable,
     logger=logger,
 )
-def call_api(request_url, api_key, filename, file, content_type, **partition_kwargs):
+def call_api(
+    request_url: str,
+    api_key: str,
+    filename: str,
+    file: IO[bytes],
+    content_type: str,
+    **partition_kwargs: Dict[str, Any],
+):
     """
     Call the api with the given request_url.
     """
@@ -144,7 +155,13 @@
     return response.text
 
 
-def partition_file_via_api(file_tuple, request, filename, content_type, **partition_kwargs):
+def partition_file_via_api(
+    file_tuple: Tuple[Any, Any],
+    request: Request,
+    filename: str,
+    content_type: str,
+    **partition_kwargs: Dict[str, Any],
+):
     """
     Send the given file to be partitioned remotely with retry logic,
     where the remote url is set by env var.
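
Reviewer note on the hunks above: `get_pdf_splits` is the contract the parallel mode relies on — each yielded chunk is a standalone PDF plus the page offset needed to renumber its elements later. Here is a minimal, runnable sketch of that contract outside the service; the three-page test PDF and the offset bookkeeping after the `yield` are assumptions, since the hunks only show part of the function body:

```python
import io

from pypdf import PdfReader, PdfWriter


def get_pdf_splits(pdf_pages, split_size=1):
    # Same shape as the patched function: yield (pdf_bytes, page_offset).
    # The loop bookkeeping below is assumed; the hunk only shows the loop body.
    offset = 0
    while offset < len(pdf_pages):
        new_pdf = PdfWriter()
        pdf_buffer = io.BytesIO()
        for page in pdf_pages[offset : offset + split_size]:
            new_pdf.add_page(page)
        new_pdf.write(pdf_buffer)
        pdf_buffer.seek(0)
        yield (pdf_buffer.read(), offset)
        offset += split_size


# Build a throwaway three-page PDF in memory to exercise the generator.
writer = PdfWriter()
for _ in range(3):
    writer.add_blank_page(width=72, height=72)
source = io.BytesIO()
writer.write(source)
source.seek(0)

for chunk_bytes, page_offset in get_pdf_splits(PdfReader(source).pages, split_size=2):
    # Each chunk parses as its own PDF; the offset maps pages back to the original.
    print(page_offset, len(PdfReader(io.BytesIO(chunk_bytes)).pages))
# prints: "0 2" then "2 1"
```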
@@ -163,7 +180,7 @@
 
     api_key = request.headers.get("unstructured-api-key")
 
-    result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)
+    result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)  # type: ignore
     elements = elements_from_json(text=result)
 
     # We need to account for the original page numbers
@@ -176,8 +193,14 @@
 
 
 def partition_pdf_splits(
-    request, pdf_pages, file, metadata_filename, content_type, coordinates, **partition_kwargs
-):
+    request: Request,
+    pdf_pages: List[PageObject],
+    file: IO[bytes],
+    metadata_filename: Optional[str],
+    content_type: str,
+    coordinates: bool,
+    **partition_kwargs: Dict[str, Any],
+) -> List[Element]:
     """
     Split a pdf into chunks and process in parallel with more api calls, or partition
     locally if the chunk is small enough. As soon as any remote call fails, bubble up
@@ -204,7 +227,7 @@
         **partition_kwargs,
     )
 
-    results = []
+    results: List[Element] = []
     page_iterator = get_pdf_splits(pdf_pages, split_size=pages_per_pdf)
 
     partition_func = partial(
@@ -224,9 +247,6 @@
 
     return results
 
 
-IS_CHIPPER_PROCESSING = False
-
-
 class ChipperMemoryProtection:
     """
     Chipper calls are expensive, and right now we can only do one call at a time.
@@ -235,50 +255,50 @@ class ChipperMemoryProtection:
     If the model is in use, return a 503 error.
     """
 
     def __enter__(self):
-        global IS_CHIPPER_PROCESSING
-        if IS_CHIPPER_PROCESSING:
+        global Is_Chipper_Processing
+        if Is_Chipper_Processing:
             # Log here so we can track how often it happens
             logger.error("Chipper is already is use")
             raise HTTPException(
                 status_code=503, detail="Server is under heavy load. Please try again later."
             )
-        IS_CHIPPER_PROCESSING = True
+        Is_Chipper_Processing = True
 
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        global IS_CHIPPER_PROCESSING
-        IS_CHIPPER_PROCESSING = False
+    def __exit__(self, exc_type, exc_value, exc_tb):  # type: ignore
+        global Is_Chipper_Processing
+        Is_Chipper_Processing = False
 
 
 def pipeline_api(
-    file,
-    request=None,
-    filename="",
-    file_content_type=None,
-    response_type="application/json",
-    m_coordinates=[],
-    m_encoding=[],
-    m_hi_res_model_name=[],
-    m_include_page_breaks=[],
-    m_ocr_languages=None,
-    m_pdf_infer_table_structure=[],
-    m_skip_infer_table_types=[],
-    m_strategy=[],
-    m_xml_keep_tags=[],
-    languages=None,
-    m_chunking_strategy=[],
-    m_multipage_sections=[],
-    m_combine_under_n_chars=[],
-    m_new_after_n_chars=[],
-    m_max_characters=[],
+    file: Optional[IO[bytes]],
+    request: Request,
+    filename: Union[str, None] = "",
+    file_content_type: Union[str, None] = None,
+    response_type: str = "application/json",
+    m_coordinates: List[str] = [],
+    m_encoding: List[str] = [],
+    m_hi_res_model_name: List[str] = [],
+    m_include_page_breaks: List[str] = [],
+    m_ocr_languages: Union[List[str], None] = None,
+    m_pdf_infer_table_structure: List[str] = [],
+    m_skip_infer_table_types: List[str] = [],
+    m_strategy: List[str] = [],
+    m_xml_keep_tags: List[str] = [],
+    languages: Union[List[str], None] = None,
+    m_chunking_strategy: List[str] = [],
+    m_multipage_sections: List[str] = [],
+    m_combine_under_n_chars: List[str] = [],
+    m_new_after_n_chars: List[str] = [],
+    m_max_characters: List[str] = [],
 ):
-    if filename.endswith(".msg"):
+    if filename and filename.endswith(".msg"):
         # Note(yuming): convert file type for msg files
         # since fast api might sent the wrong one.
         file_content_type = "application/x-ole-storage"
 
     # We don't want to keep logging the same params for every parallel call
-    origin_ip = request.headers.get("X-Forwarded-For") or request.client.host
+    origin_ip = request.headers.get("X-Forwarded-For") or request.client.host  # type: ignore
     is_internal_request = origin_ip.startswith("10.")
 
     if not is_internal_request:
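
The `ChipperMemoryProtection` changes above are easier to review as the bare pattern they implement: a module-level flag flipped in `__enter__`/`__exit__` so at most one expensive call is in flight, with `__exit__` guaranteeing the flag resets even when the guarded block raises. A minimal sketch — the class name, flag name, and `RuntimeError` are stand-ins for the real `HTTPException(status_code=503)`:

```python
Is_Busy = False  # module-level flag, mirrors Is_Chipper_Processing


class SingleFlightGate:
    def __enter__(self):
        global Is_Busy
        if Is_Busy:
            # The patched class logs and raises HTTPException(status_code=503) here.
            raise RuntimeError("Server is under heavy load. Please try again later.")
        Is_Busy = True

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Runs on success *and* on exception, so the flag can never stick.
        global Is_Busy
        Is_Busy = False


with SingleFlightGate():
    pass  # one expensive model call at a time goes here
```

Worth noting in review: a bare check-then-set on a global is not atomic, so two threads could still slip through the gate in a narrow window; a `threading.Lock` or `Semaphore(1)` would close it if that ever matters.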
@@ -313,11 +333,10 @@
 
     _check_free_memory()
 
-    if file_content_type == "application/pdf":
-        pdf = _check_pdf(file)
+    pdf = _check_pdf(file) if file and file_content_type == "application/pdf" else None
 
     show_coordinates_str = (m_coordinates[0] if len(m_coordinates) else "false").lower()
-    show_coordinates = show_coordinates_str == "true"
+    show_coordinates: bool = show_coordinates_str == "true"
 
     hi_res_model_name = _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates)
     strategy = _validate_strategy(m_strategy)
@@ -394,7 +413,7 @@
         )
     )
 
-    partition_kwargs = {
+    partition_kwargs: Dict[str, Any] = {
         "file": file,
         "metadata_filename": filename,
         "content_type": file_content_type,
@@ -413,8 +432,9 @@
         "new_after_n_chars": new_after_n_chars,
         "max_characters": max_characters,
     }
+    elements: List[Element]
 
-    if file_content_type == "application/pdf" and pdf_parallel_mode_enabled:
+    if file_content_type == "application/pdf" and pdf_parallel_mode_enabled and pdf:
         # Be careful of naming differences in api params vs partition params!
         # These kwargs are going back into the api, not into partition
         # They need to be switched back in partition_pdf_splits
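
The `m_*` parameters typed as `List[str]` in the signature hunk follow one convention used throughout `pipeline_api`: each form field can arrive multiple times, the first value wins, and an empty list falls back to a default. A hedged sketch of that idiom (`first_or_default` is a hypothetical helper, not in the patch):

```python
from typing import List


def first_or_default(values: List[str], default: str) -> str:
    # Mirrors expressions like (m_coordinates[0] if len(m_coordinates) else "false").lower()
    return (values[0] if len(values) else default).lower()


assert first_or_default([], "false") == "false"
assert first_or_default(["TRUE"], "false") == "true"

show_coordinates = first_or_default([], "false") == "true"  # -> False, the default
```

One nit the new annotations surface: `List[str] = []` is a mutable default argument; it is harmless here because the lists are only read, but a `None` default would be the safer idiom.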
@@ -480,7 +500,7 @@
     # Clean up returned elements
     # Note(austin): pydantic should control this sort of thing for us
     for i, element in enumerate(elements):
-        elements[i].metadata.filename = os.path.basename(filename)
+        elements[i].metadata.filename = os.path.basename(filename)  # type: ignore
 
         if not show_coordinates and element.metadata.coordinates:
             elements[i].metadata.coordinates = None
@@ -516,25 +536,25 @@
     )
 
 
-def _check_pdf(file):
+def _check_pdf(file: Union[str, IO[Any]]) -> PdfReader:
     """Check if the PDF file is encrypted, otherwise assume it is not a valid PDF."""
     try:
-        pdf = PdfReader(file)
+        pdf = PdfReader(stream=file)  # StrByteType can be str or IO[Any]
 
         # This will raise if the file is encrypted
         pdf.metadata
         return pdf
-    except pypdf.errors.FileNotDecryptedError:
+    except pypdf.errors.FileNotDecryptedError:  # type: ignore
         raise HTTPException(
             status_code=400,
             detail="File is encrypted. Please decrypt it with password.",
         )
-    except pypdf.errors.PdfReadError:
+    except pypdf.errors.PdfReadError:  # type: ignore
         raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF")
 
 
-def _validate_strategy(m_strategy):
-    strategy = (m_strategy[0] if len(m_strategy) else "auto").lower()
+def _validate_strategy(m_strategy: List[str]) -> str:
+    strategy: str = (m_strategy[0] if len(m_strategy) else "auto").lower()
     strategies = ["fast", "hi_res", "auto", "ocr_only"]
     if strategy not in strategies:
         raise HTTPException(
@@ -543,7 +563,9 @@
     return strategy
 
 
-def _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates):
+def _validate_hi_res_model_name(
+    m_hi_res_model_name: List[str], show_coordinates: bool
+) -> Union[str, None]:
     hi_res_model_name = m_hi_res_model_name[0] if len(m_hi_res_model_name) else None
 
     # Make sure chipper aliases to the latest model
@@ -558,7 +580,7 @@
     return hi_res_model_name
 
 
-def _validate_chunking_strategy(m_chunking_strategy):
+def _validate_chunking_strategy(m_chunking_strategy: List[str]) -> Union[str, None]:
     chunking_strategy = m_chunking_strategy[0].lower() if len(m_chunking_strategy) else None
     chunk_strategies = ["by_title"]
     if chunking_strategy and (chunking_strategy not in chunk_strategies):
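
The validators above share a shape with `_validate_strategy`: normalize the first form value, then reject unknown choices with an `HTTPException` before they ever reach `partition()`. A generic sketch of that shape — `validate_choice` and the 400 status are illustrative, since the hunks cut off before the real status codes:

```python
from typing import List

from fastapi import HTTPException


def validate_choice(values: List[str], allowed: List[str], default: str) -> str:
    # Normalize like the real validators, then fail fast on unknown options.
    choice = (values[0] if len(values) else default).lower()
    if choice not in allowed:
        raise HTTPException(
            status_code=400,
            detail=f"{choice} is not a valid option. Choose one of {allowed}",
        )
    return choice


strategy = validate_choice(["hi_res"], ["fast", "hi_res", "auto", "ocr_only"], "auto")
assert strategy == "hi_res"
```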
@@ -569,18 +591,17 @@
     return chunking_strategy
 
 
-def _set_pdf_infer_table_structure(m_pdf_infer_table_structure, strategy):
+def _set_pdf_infer_table_structure(m_pdf_infer_table_structure: List[str], strategy: str) -> bool:
     pdf_infer_table_structure = (
         m_pdf_infer_table_structure[0] if len(m_pdf_infer_table_structure) else "false"
     ).lower()
     if strategy == "hi_res" and pdf_infer_table_structure == "true":
-        pdf_infer_table_structure = True
+        return True
     else:
-        pdf_infer_table_structure = False
-    return pdf_infer_table_structure
+        return False
 
 
-def get_validated_mimetype(file):
+def get_validated_mimetype(file: UploadFile):
     """
     Return a file's mimetype, either via the file.content_type or the mimetypes lib if that's too generic.
     If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
@@ -591,7 +612,7 @@
         content_type = mimetypes.guess_type(str(file.filename))[0]
 
         # Some filetypes missing for this library, just hardcode them for now
-        if not content_type:
+        if not content_type and file.filename:
             if file.filename.endswith(".md"):
                 content_type = "text/markdown"
             elif file.filename.endswith(".msg"):
@@ -613,7 +634,12 @@
 class MultipartMixedResponse(StreamingResponse):
     CRLF = b"\r\n"
 
-    def __init__(self, *args, content_type: str = None, **kwargs):
+    def __init__(
+        self,
+        *args: Any,
+        content_type: Union[str, None] = None,
+        **kwargs,  # type: ignore
+    ):
         super().__init__(*args, **kwargs)
         self.content_type = content_type
 
@@ -627,7 +653,7 @@ def init_headers(self, headers: Optional[Mapping[str, str]] = None) -> None:
     def boundary(self):
         return b"--" + self.boundary_value.encode()
 
-    def _build_part_headers(self, headers: dict) -> bytes:
+    def _build_part_headers(self, headers: Dict[str, str]) -> bytes:
         header_bytes = b""
         for header, value in headers.items():
             header_bytes += f"{header}: {value}".encode() + self.CRLF
@@ -635,7 +661,10 @@ def _build_part_headers(self, headers: dict) -> bytes:
 
     def build_part(self, chunk: bytes) -> bytes:
         part = self.boundary + self.CRLF
-        part_headers = {"Content-Length": len(chunk), "Content-Transfer-Encoding": "base64"}
+        part_headers: Dict[str, Any] = {
+            "Content-Length": len(chunk),
+            "Content-Transfer-Encoding": "base64",
+        }
         if self.content_type is not None:
             part_headers["Content-Type"] = self.content_type
         part += self._build_part_headers(part_headers)
@@ -661,8 +690,10 @@ async def stream_response(self, send: Send) -> None:
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
 
-def ungz_file(file: UploadFile, gz_uncompressed_content_type=None) -> UploadFile:
-    def return_content_type(filename):
+def ungz_file(
+    file: UploadFile, gz_uncompressed_content_type: Union[str, None] = None
+) -> UploadFile:
+    def return_content_type(filename: str):
         if gz_uncompressed_content_type:
             return gz_uncompressed_content_type
         else:
@@ -740,7 +771,7 @@ def pipeline_1(
                 status_code=status.HTTP_406_NOT_ACCEPTABLE,
             )
 
-    def response_generator(is_multipart):
+    def response_generator(is_multipart: bool):
         for file in files:
             file_content_type = get_validated_mimetype(file)
 
@@ -768,8 +799,7 @@ def response_generator(is_multipart):
                 m_new_after_n_chars=new_after_n_chars,
                 m_max_characters=max_characters,
             )
-
-            if is_expected_response_type(media_type, type(response)):
+            if is_expected_response_type(media_type, type(response)):  # type: ignore
                 raise HTTPException(
                     detail=(
                         f"Conflict in media type {media_type}"
@@ -792,14 +822,14 @@
                     status_code=status.HTTP_406_NOT_ACCEPTABLE,
                 )
 
-    def join_responses(responses):
+    def join_responses(responses: List[Any]):
         if media_type != "text/csv":
             return responses
-        data = pd.read_csv(io.BytesIO(responses[0].body))
+        data = pd.read_csv(io.BytesIO(responses[0].body))  # type: ignore
         if len(responses) > 1:
             for resp in responses[1:]:
-                resp_data = pd.read_csv(io.BytesIO(resp.body))
-                data = data.merge(resp_data, how="outer")
+                resp_data = pd.read_csv(io.BytesIO(resp.body))  # type: ignore
+                data = data.merge(resp_data, how="outer")  # type: ignore
         return PlainTextResponse(data.to_csv())
 
     if content_type == "multipart/mixed":
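
For reference, the `join_responses` change at the end is easiest to sanity-check in isolation: each per-file CSV response body is read back into a DataFrame and outer-merged into one table. A standalone sketch with made-up CSV bodies standing in for the real `.body` payloads:

```python
import io

import pandas as pd

# Stand-ins for two responses' .body payloads (each a CSV of elements).
csv_a = b"type,text\nTitle,Hello\n"
csv_b = b"type,text\nNarrativeText,World\n"

data = pd.read_csv(io.BytesIO(csv_a))
resp_data = pd.read_csv(io.BytesIO(csv_b))

# With no join keys given, merge aligns on the shared columns; how="outer"
# keeps the union of rows, which is what concatenating per-file results needs.
data = data.merge(resp_data, how="outer")
print(data.to_csv(index=False))
# type,text
# Title,Hello
# NarrativeText,World
```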