From 994b5801f4411bd5e6c43ed2f79e85b9b1a4b833 Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 9 May 2024 11:14:17 +0200 Subject: [PATCH 1/7] refactor: move endpoints to separate file --- prepline_general/api/app.py | 7 +- prepline_general/api/endpoints.py | 167 ++++++++++++++++++++++++++++++ prepline_general/api/general.py | 158 +--------------------------- 3 files changed, 173 insertions(+), 159 deletions(-) create mode 100644 prepline_general/api/endpoints.py diff --git a/prepline_general/api/app.py b/prepline_general/api/app.py index 8a96a84d..302629c3 100644 --- a/prepline_general/api/app.py +++ b/prepline_general/api/app.py @@ -1,10 +1,9 @@ -from fastapi import FastAPI, Request, status, HTTPException +from fastapi import FastAPI, APIRouter, Request, status, HTTPException from fastapi.responses import JSONResponse -from fastapi.security import APIKeyHeader import logging import os -from .general import router as general_router +from .endpoints import router as general_router from .openapi import set_custom_openapi logger = logging.getLogger("unstructured_api") @@ -30,6 +29,8 @@ openapi_tags=[{"name": "general"}], ) +router = APIRouter() + # Note(austin) - This logger just dumps exceptions # We'd rather handle those below, so disable this in deployments uvicorn_logger = logging.getLogger("uvicorn.error") diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py new file mode 100644 index 00000000..f98e240e --- /dev/null +++ b/prepline_general/api/endpoints.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import io +import json +import os +from typing import List, Sequence, Dict, Any, cast, Union + +import pandas as pd +from fastapi import FastAPI, APIRouter, UploadFile, Depends, HTTPException +from starlette import status +from starlette.requests import Request +from starlette.responses import PlainTextResponse + +from prepline_general.api.general import ( + _validate_chunking_strategy, + ungz_file, + 
get_validated_mimetype, + pipeline_api, + MultipartMixedResponse, +) +from prepline_general.api.models.form_params import GeneralFormParams + +app = FastAPI() +router = APIRouter() +app.include_router(router) + + +@router.post( + "/general/v0/general", + openapi_extra={"x-speakeasy-name-override": "partition"}, + tags=["general"], + summary="Summary", + description="Description", + operation_id="partition_parameters", +) +@router.post("/general/v0.0.67/general", include_in_schema=False) +def general_partition( + request: Request, + # cannot use annotated type here because of a bug described here: + # https://github.com/tiangolo/fastapi/discussions/10280 + # The openapi metadata must be added separately in openapi.py file. + # TODO: Check if the bug is fixed and change the declaration to use Annoteted[List[UploadFile], File(...)] + # For new parameters - add them in models/form_params.py + files: List[UploadFile], + form_params: GeneralFormParams = Depends(GeneralFormParams.as_form), +): + # -- must have a valid API key -- + if api_key_env := os.environ.get("UNSTRUCTURED_API_KEY"): + api_key = request.headers.get("unstructured-api-key") + if api_key != api_key_env: + raise HTTPException( + detail=f"API key {api_key} is invalid", status_code=status.HTTP_401_UNAUTHORIZED + ) + + content_type = request.headers.get("Accept") + + # -- detect response content-type conflict when multiple files are uploaded -- + if ( + len(files) > 1 + and content_type + and content_type + not in [ + "*/*", + "multipart/mixed", + "application/json", + "text/csv", + ] + ): + raise HTTPException( + detail=f"Conflict in media type {content_type} with response type 'multipart/mixed'.\n", + status_code=status.HTTP_406_NOT_ACCEPTABLE, + ) + + # -- validate other arguments -- + chunking_strategy = _validate_chunking_strategy(form_params.chunking_strategy) + + # -- unzip any uploaded files that need it -- + for idx, file in enumerate(files): + is_content_type_gz = file.content_type == 
"application/gzip" + is_extension_gz = file.filename and file.filename.endswith(".gz") + if is_content_type_gz or is_extension_gz: + files[idx] = ungz_file(file, form_params.gz_uncompressed_content_type) + + def response_generator(is_multipart: bool): + for file in files: + file_content_type = get_validated_mimetype(file) + + _file = file.file + + response = pipeline_api( + _file, + request=request, + coordinates=form_params.coordinates, + encoding=form_params.encoding, + hi_res_model_name=form_params.hi_res_model_name, + include_page_breaks=form_params.include_page_breaks, + ocr_languages=form_params.ocr_languages, + pdf_infer_table_structure=form_params.pdf_infer_table_structure, + skip_infer_table_types=form_params.skip_infer_table_types, + strategy=form_params.strategy, + xml_keep_tags=form_params.xml_keep_tags, + response_type=form_params.output_format, + filename=str(file.filename), + file_content_type=file_content_type, + languages=form_params.languages, + extract_image_block_types=form_params.extract_image_block_types, + unique_element_ids=form_params.unique_element_ids, + # -- chunking options -- + chunking_strategy=chunking_strategy, + combine_under_n_chars=form_params.combine_under_n_chars, + max_characters=form_params.max_characters, + multipage_sections=form_params.multipage_sections, + new_after_n_chars=form_params.new_after_n_chars, + overlap=form_params.overlap, + overlap_all=form_params.overlap_all, + starting_page_number=form_params.starting_page_number, + ) + + yield ( + json.dumps(response) + if is_multipart and type(response) not in [str, bytes] + else ( + PlainTextResponse(response) + if not is_multipart and form_params.output_format == "text/csv" + else response + ) + ) + + def join_responses( + responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse] + ) -> List[str | List[Dict[str, Any]]] | PlainTextResponse: + """Consolidate partitionings from multiple documents into single response payload.""" + if form_params.output_format != 
"text/csv": + return cast(List[Union[str, List[Dict[str, Any]]]], responses) + responses = cast(List[PlainTextResponse], responses) + data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] + io.BytesIO(responses[0].body) + ) + if len(responses) > 1: + for resp in responses[1:]: + resp_data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] + io.BytesIO(resp.body) + ) + data = data.merge( # pyright: ignore[reportUnknownMemberType] + resp_data, how="outer" + ) + return PlainTextResponse(data.to_csv()) + + return ( + MultipartMixedResponse( + response_generator(is_multipart=True), content_type=form_params.output_format + ) + if content_type == "multipart/mixed" + else ( + list(response_generator(is_multipart=False))[0] + if len(files) == 1 + else join_responses(list(response_generator(is_multipart=False))) + ) + ) + + +@router.get("/general/v0/general", include_in_schema=False) +@router.get("/general/v0.0.67/general", include_in_schema=False) +async def handle_invalid_get_request(): + raise HTTPException( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported." 
+ ) diff --git a/prepline_general/api/general.py b/prepline_general/api/general.py index e0c86b82..7a27045d 100644 --- a/prepline_general/api/general.py +++ b/prepline_general/api/general.py @@ -12,28 +12,22 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from types import TracebackType -from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple import backoff -import pandas as pd import psutil import requests from fastapi import ( - APIRouter, - Depends, - FastAPI, HTTPException, Request, UploadFile, - status, ) -from fastapi.responses import PlainTextResponse, StreamingResponse +from fastapi.responses import StreamingResponse from pypdf import PageObject, PdfReader, PdfWriter from pypdf.errors import FileNotDecryptedError, PdfReadError from starlette.datastructures import Headers from starlette.types import Send -from prepline_general.api.models.form_params import GeneralFormParams from unstructured.documents.elements import Element from unstructured.partition.auto import partition from unstructured.staging.base import ( @@ -44,9 +38,6 @@ from unstructured_inference.models.base import UnknownModelException from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES -app = FastAPI() -router = APIRouter() - def is_compatible_response_type(media_type: str, response_type: type) -> bool: """True when `response_type` can be converted to `media_type` for HTTP Response.""" @@ -701,148 +692,3 @@ def return_content_type(filename: str): filename=filename, headers=Headers({"content-type": return_content_type(filename)}), ) - - -@router.get("/general/v0/general", include_in_schema=False) -@router.get("/general/v0.0.67/general", include_in_schema=False) -async def handle_invalid_get_request(): - raise HTTPException( - status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported." 
- ) - - -@router.post( - "/general/v0/general", - openapi_extra={"x-speakeasy-name-override": "partition"}, - tags=["general"], - summary="Summary", - description="Description", - operation_id="partition_parameters", -) -@router.post("/general/v0.0.67/general", include_in_schema=False) -def general_partition( - request: Request, - # cannot use annotated type here because of a bug described here: - # https://github.com/tiangolo/fastapi/discussions/10280 - # The openapi metadata must be added separately in openapi.py file. - # TODO: Check if the bug is fixed and change the declaration to use Annoteted[List[UploadFile], File(...)] - # For new parameters - add them in models/form_params.py - files: List[UploadFile], - form_params: GeneralFormParams = Depends(GeneralFormParams.as_form), -): - # -- must have a valid API key -- - if api_key_env := os.environ.get("UNSTRUCTURED_API_KEY"): - api_key = request.headers.get("unstructured-api-key") - if api_key != api_key_env: - raise HTTPException( - detail=f"API key {api_key} is invalid", status_code=status.HTTP_401_UNAUTHORIZED - ) - - content_type = request.headers.get("Accept") - - # -- detect response content-type conflict when multiple files are uploaded -- - if ( - len(files) > 1 - and content_type - and content_type - not in [ - "*/*", - "multipart/mixed", - "application/json", - "text/csv", - ] - ): - raise HTTPException( - detail=f"Conflict in media type {content_type} with response type 'multipart/mixed'.\n", - status_code=status.HTTP_406_NOT_ACCEPTABLE, - ) - - # -- validate other arguments -- - chunking_strategy = _validate_chunking_strategy(form_params.chunking_strategy) - - # -- unzip any uploaded files that need it -- - for idx, file in enumerate(files): - is_content_type_gz = file.content_type == "application/gzip" - is_extension_gz = file.filename and file.filename.endswith(".gz") - if is_content_type_gz or is_extension_gz: - files[idx] = ungz_file(file, form_params.gz_uncompressed_content_type) - - def 
response_generator(is_multipart: bool): - for file in files: - file_content_type = get_validated_mimetype(file) - - _file = file.file - - response = pipeline_api( - _file, - request=request, - coordinates=form_params.coordinates, - encoding=form_params.encoding, - hi_res_model_name=form_params.hi_res_model_name, - include_page_breaks=form_params.include_page_breaks, - ocr_languages=form_params.ocr_languages, - pdf_infer_table_structure=form_params.pdf_infer_table_structure, - skip_infer_table_types=form_params.skip_infer_table_types, - strategy=form_params.strategy, - xml_keep_tags=form_params.xml_keep_tags, - response_type=form_params.output_format, - filename=str(file.filename), - file_content_type=file_content_type, - languages=form_params.languages, - extract_image_block_types=form_params.extract_image_block_types, - unique_element_ids=form_params.unique_element_ids, - # -- chunking options -- - chunking_strategy=chunking_strategy, - combine_under_n_chars=form_params.combine_under_n_chars, - max_characters=form_params.max_characters, - multipage_sections=form_params.multipage_sections, - new_after_n_chars=form_params.new_after_n_chars, - overlap=form_params.overlap, - overlap_all=form_params.overlap_all, - starting_page_number=form_params.starting_page_number, - ) - - yield ( - json.dumps(response) - if is_multipart and type(response) not in [str, bytes] - else ( - PlainTextResponse(response) - if not is_multipart and form_params.output_format == "text/csv" - else response - ) - ) - - def join_responses( - responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse] - ) -> List[str | List[Dict[str, Any]]] | PlainTextResponse: - """Consolidate partitionings from multiple documents into single response payload.""" - if form_params.output_format != "text/csv": - return cast(List[Union[str, List[Dict[str, Any]]]], responses) - responses = cast(List[PlainTextResponse], responses) - data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] - 
io.BytesIO(responses[0].body) - ) - if len(responses) > 1: - for resp in responses[1:]: - resp_data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] - io.BytesIO(resp.body) - ) - data = data.merge( # pyright: ignore[reportUnknownMemberType] - resp_data, how="outer" - ) - return PlainTextResponse(data.to_csv()) - - return ( - MultipartMixedResponse( - response_generator(is_multipart=True), content_type=form_params.output_format - ) - if content_type == "multipart/mixed" - else ( - list(response_generator(is_multipart=False))[0] - if len(files) == 1 - else join_responses(list(response_generator(is_multipart=False))) - ) - ) - - -app.include_router(router) From bc72b1c1438bbcd25f5b4bcc4bd0ba80c5fefbaf Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 9 May 2024 12:17:57 +0200 Subject: [PATCH 2/7] refactor: improve router structure --- prepline_general/api/app.py | 9 ++++----- prepline_general/api/endpoints.py | 4 +--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/prepline_general/api/app.py b/prepline_general/api/app.py index 302629c3..f47d539c 100644 --- a/prepline_general/api/app.py +++ b/prepline_general/api/app.py @@ -1,10 +1,10 @@ -from fastapi import FastAPI, APIRouter, Request, status, HTTPException +from fastapi import FastAPI, Request, status, HTTPException from fastapi.responses import JSONResponse import logging import os -from .endpoints import router as general_router -from .openapi import set_custom_openapi +from prepline_general.api.endpoints import router as general_router +from prepline_general.api.openapi import set_custom_openapi logger = logging.getLogger("unstructured_api") @@ -29,7 +29,7 @@ openapi_tags=[{"name": "general"}], ) -router = APIRouter() +app.include_router(general_router) # Note(austin) - This logger just dumps exceptions # We'd rather handle those below, so disable this in deployments @@ -62,7 +62,6 @@ async def error_handler(request: Request, e: Exception): allow_headers=["Content-Type"], ) 
-app.include_router(general_router) set_custom_openapi(app) diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py index f98e240e..7a26ae09 100644 --- a/prepline_general/api/endpoints.py +++ b/prepline_general/api/endpoints.py @@ -6,7 +6,7 @@ from typing import List, Sequence, Dict, Any, cast, Union import pandas as pd -from fastapi import FastAPI, APIRouter, UploadFile, Depends, HTTPException +from fastapi import APIRouter, UploadFile, Depends, HTTPException from starlette import status from starlette.requests import Request from starlette.responses import PlainTextResponse @@ -20,9 +20,7 @@ ) from prepline_general.api.models.form_params import GeneralFormParams -app = FastAPI() router = APIRouter() -app.include_router(router) @router.post( From 20cefe7e9542470c6c0f3afdfad8a45c62ccfac8 Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 23 May 2024 11:09:25 +0200 Subject: [PATCH 3/7] refactor: move out code from app.py --- prepline_general/api/app.py | 22 ---------------------- prepline_general/api/endpoints.py | 5 +++++ prepline_general/api/logging.py | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 prepline_general/api/logging.py diff --git a/prepline_general/api/app.py b/prepline_general/api/app.py index f47d539c..b98e1506 100644 --- a/prepline_general/api/app.py +++ b/prepline_general/api/app.py @@ -65,26 +65,4 @@ async def error_handler(request: Request, e: Exception): set_custom_openapi(app) - -# Filter out /healthcheck noise -class HealthCheckFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - return record.getMessage().find("/healthcheck") == -1 - - -# Filter out /metrics noise -class MetricsCheckFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - return record.getMessage().find("/metrics") == -1 - - -logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter()) 
-logging.getLogger("uvicorn.access").addFilter(MetricsCheckFilter()) - - -@app.get("/healthcheck", status_code=status.HTTP_200_OK, include_in_schema=False) -def healthcheck(request: Request): - return {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"} - - logger.info("Started Unstructured API") diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py index 7a26ae09..266c3b11 100644 --- a/prepline_general/api/endpoints.py +++ b/prepline_general/api/endpoints.py @@ -163,3 +163,8 @@ async def handle_invalid_get_request(): raise HTTPException( status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported." ) + + +@router.get("/healthcheck", status_code=status.HTTP_200_OK, include_in_schema=False) +def healthcheck(request: Request): + return {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"} diff --git a/prepline_general/api/logging.py b/prepline_general/api/logging.py new file mode 100644 index 00000000..696bb3fc --- /dev/null +++ b/prepline_general/api/logging.py @@ -0,0 +1,17 @@ +import logging + + +# Filter out /healthcheck noise +class HealthCheckFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return record.getMessage().find("/healthcheck") == -1 + + +# Filter out /metrics noise +class MetricsCheckFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return record.getMessage().find("/metrics") == -1 + + +logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter()) +logging.getLogger("uvicorn.access").addFilter(MetricsCheckFilter()) From e42412394324c69e68bfc7a152c94326ecc97480 Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 23 May 2024 14:07:55 +0200 Subject: [PATCH 4/7] refactor: move out code from general.py --- prepline_general/api/app.py | 2 +- prepline_general/api/endpoints.py | 6 +- prepline_general/api/general.py | 152 +--------------------- prepline_general/api/memory_protection.py | 40 ++++++ 
prepline_general/api/validation.py | 108 +++++++++++++++ 5 files changed, 156 insertions(+), 152 deletions(-) create mode 100644 prepline_general/api/memory_protection.py create mode 100644 prepline_general/api/validation.py diff --git a/prepline_general/api/app.py b/prepline_general/api/app.py index b98e1506..90f20a2d 100644 --- a/prepline_general/api/app.py +++ b/prepline_general/api/app.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Request, status, HTTPException +from fastapi import FastAPI, Request, HTTPException from fastapi.responses import JSONResponse import logging import os diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py index 266c3b11..3b40454b 100644 --- a/prepline_general/api/endpoints.py +++ b/prepline_general/api/endpoints.py @@ -12,12 +12,10 @@ from starlette.responses import PlainTextResponse from prepline_general.api.general import ( - _validate_chunking_strategy, ungz_file, - get_validated_mimetype, - pipeline_api, - MultipartMixedResponse, + MultipartMixedResponse, pipeline_api, ) +from prepline_general.api.validation import _validate_chunking_strategy, get_validated_mimetype from prepline_general.api.models.form_params import GeneralFormParams router = APIRouter() diff --git a/prepline_general/api/general.py b/prepline_general/api/general.py index 7a27045d..11cab525 100644 --- a/prepline_general/api/general.py +++ b/prepline_general/api/general.py @@ -11,7 +11,6 @@ from base64 import b64encode from concurrent.futures import ThreadPoolExecutor from functools import partial -from types import TracebackType from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple import backoff @@ -19,34 +18,21 @@ import requests from fastapi import ( HTTPException, - Request, UploadFile, ) from fastapi.responses import StreamingResponse -from pypdf import PageObject, PdfReader, PdfWriter -from pypdf.errors import FileNotDecryptedError, PdfReadError +from pypdf import PdfReader, PageObject, PdfWriter from 
starlette.datastructures import Headers +from starlette.requests import Request from starlette.types import Send - from unstructured.documents.elements import Element from unstructured.partition.auto import partition -from unstructured.staging.base import ( - convert_to_dataframe, - convert_to_isd, - elements_from_json, -) +from unstructured.staging.base import elements_from_json, convert_to_dataframe, convert_to_isd from unstructured_inference.models.base import UnknownModelException from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES - -def is_compatible_response_type(media_type: str, response_type: type) -> bool: - """True when `response_type` can be converted to `media_type` for HTTP Response.""" - return ( - False - if media_type == "application/json" and response_type not in [dict, list] - else False if media_type == "text/csv" and response_type != str else True - ) - +from prepline_general.api.memory_protection import ChipperMemoryProtection +from prepline_general.api.validation import _check_pdf, _validate_hi_res_model_name, _validate_strategy logger = logging.getLogger("unstructured_api") @@ -235,37 +221,6 @@ def partition_pdf_splits( return results -is_chipper_processing = False - - -class ChipperMemoryProtection: - """Chipper calls are expensive, and right now we can only do one call at a time. - - If the model is in use, return a 503 error. The API should scale up and the user can try again - on a different server. - """ - - def __enter__(self): - global is_chipper_processing - if is_chipper_processing: - # Log here so we can track how often it happens - logger.error("Chipper is already is use") - raise HTTPException( - status_code=503, detail="Server is under heavy load. Please try again later." 
- ) - - is_chipper_processing = True - - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - exc_tb: Optional[TracebackType], - ): - global is_chipper_processing - is_chipper_processing = False - - def pipeline_api( file: IO[bytes], request: Request, @@ -521,108 +476,11 @@ def _check_free_memory(): ) -def _check_pdf(file: IO[bytes]): - """Check if the PDF file is encrypted, otherwise assume it is not a valid PDF.""" - try: - pdf = PdfReader(file) - - # This will raise if the file is encrypted - pdf.metadata - return pdf - except FileNotDecryptedError: - raise HTTPException( - status_code=400, - detail="File is encrypted. Please decrypt it with password.", - ) - except PdfReadError: - raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF") - - -def _validate_strategy(strategy: str) -> str: - strategy = strategy.lower() - strategies = ["fast", "hi_res", "auto", "ocr_only"] - if strategy not in strategies: - raise HTTPException( - status_code=400, detail=f"Invalid strategy: {strategy}. Must be one of {strategies}" - ) - return strategy - - -def _validate_hi_res_model_name( - hi_res_model_name: Optional[str], show_coordinates: bool -) -> Optional[str]: - # Make sure chipper aliases to the latest model - if hi_res_model_name and hi_res_model_name == "chipper": - hi_res_model_name = "chipperv2" - - if hi_res_model_name and hi_res_model_name in CHIPPER_MODEL_TYPES and show_coordinates: - raise HTTPException( - status_code=400, - detail=f"coordinates aren't available when using the {hi_res_model_name} model type", - ) - return hi_res_model_name - - -def _validate_chunking_strategy(chunking_strategy: Optional[str]) -> Optional[str]: - """Raise on `chunking_strategy` is not a valid chunking strategy name. - - Also provides case-insensitivity. 
- """ - if chunking_strategy is None: - return None - - chunking_strategy = chunking_strategy.lower() - available_strategies = ["basic", "by_title"] - - if chunking_strategy not in available_strategies: - raise HTTPException( - status_code=400, - detail=( - f"Invalid chunking strategy: {chunking_strategy}. Must be one of" - f" {available_strategies}" - ), - ) - - return chunking_strategy - - def _set_pdf_infer_table_structure(pdf_infer_table_structure: bool, strategy: str) -> bool: """Avoids table inference in "fast" and "ocr_only" runs.""" return strategy in ("hi_res", "auto") and pdf_infer_table_structure -def get_validated_mimetype(file: UploadFile) -> Optional[str]: - """The MIME-type of `file`. - - The mimetype is computed based on `file.content_type`, or the mimetypes lib if that's too - generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and - return HTTP 400 for an invalid type. - """ - content_type = file.content_type - filename = str(file.filename) # -- "None" when file.filename is None -- - if not content_type or content_type == "application/octet-stream": - content_type = mimetypes.guess_type(filename)[0] - - # Some filetypes missing for this library, just hardcode them for now - if not content_type: - if filename.endswith(".md"): - content_type = "text/markdown" - elif filename.endswith(".msg"): - content_type = "message/rfc822" - - allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES") - if allowed_mimetypes_str is not None: - allowed_mimetypes = allowed_mimetypes_str.split(",") - - if content_type not in allowed_mimetypes: - raise HTTPException( - status_code=400, - detail=(f"File type {content_type} is not supported."), - ) - - return content_type - - class MultipartMixedResponse(StreamingResponse): CRLF = b"\r\n" diff --git a/prepline_general/api/memory_protection.py b/prepline_general/api/memory_protection.py new file mode 100644 index 00000000..439acdb7 --- /dev/null +++ 
b/prepline_general/api/memory_protection.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import logging +from types import TracebackType +from typing import Optional + +from fastapi import HTTPException + +logger = logging.getLogger("unstructured_api") + + +is_chipper_processing = False + + +class ChipperMemoryProtection: + """Chipper calls are expensive, and right now we can only do one call at a time. + + If the model is in use, return a 503 error. The API should scale up and the user can try again + on a different server. + """ + + def __enter__(self): + global is_chipper_processing + if is_chipper_processing: + # Log here so we can track how often it happens + logger.error("Chipper is already is use") + raise HTTPException( + status_code=503, detail="Server is under heavy load. Please try again later." + ) + + is_chipper_processing = True + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType], + ): + global is_chipper_processing + is_chipper_processing = False diff --git a/prepline_general/api/validation.py b/prepline_general/api/validation.py new file mode 100644 index 00000000..df32d578 --- /dev/null +++ b/prepline_general/api/validation.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import mimetypes +import os + +from typing import IO, Optional + +from fastapi import HTTPException, UploadFile +from pypdf import PdfReader +from pypdf.errors import FileNotDecryptedError, PdfReadError +from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES + + +def _check_pdf(file: IO[bytes]): + """Check if the PDF file is encrypted, otherwise assume it is not a valid PDF.""" + try: + pdf = PdfReader(file) + + # This will raise if the file is encrypted + pdf.metadata + return pdf + except FileNotDecryptedError: + raise HTTPException( + status_code=400, + detail="File is encrypted. 
Please decrypt it with password.", + ) + except PdfReadError: + raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF") + + +def _validate_strategy(strategy: str) -> str: + strategy = strategy.lower() + strategies = ["fast", "hi_res", "auto", "ocr_only"] + if strategy not in strategies: + raise HTTPException( + status_code=400, detail=f"Invalid strategy: {strategy}. Must be one of {strategies}" + ) + return strategy + + +def _validate_hi_res_model_name( + hi_res_model_name: Optional[str], show_coordinates: bool +) -> Optional[str]: + # Make sure chipper aliases to the latest model + if hi_res_model_name and hi_res_model_name == "chipper": + hi_res_model_name = "chipperv2" + + if hi_res_model_name and hi_res_model_name in CHIPPER_MODEL_TYPES and show_coordinates: + raise HTTPException( + status_code=400, + detail=f"coordinates aren't available when using the {hi_res_model_name} model type", + ) + return hi_res_model_name + + +def _validate_chunking_strategy(chunking_strategy: Optional[str]) -> Optional[str]: + """Raise on `chunking_strategy` is not a valid chunking strategy name. + + Also provides case-insensitivity. + """ + if chunking_strategy is None: + return None + + chunking_strategy = chunking_strategy.lower() + available_strategies = ["basic", "by_title"] + + if chunking_strategy not in available_strategies: + raise HTTPException( + status_code=400, + detail=( + f"Invalid chunking strategy: {chunking_strategy}. Must be one of" + f" {available_strategies}" + ), + ) + + return chunking_strategy + + +def get_validated_mimetype(file: UploadFile) -> Optional[str]: + """The MIME-type of `file`. + + The mimetype is computed based on `file.content_type`, or the mimetypes lib if that's too + generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and + return HTTP 400 for an invalid type. 
+ """ + content_type = file.content_type + filename = str(file.filename) # -- "None" when file.filename is None -- + if not content_type or content_type == "application/octet-stream": + content_type = mimetypes.guess_type(filename)[0] + + # Some filetypes missing for this library, just hardcode them for now + if not content_type: + if filename.endswith(".md"): + content_type = "text/markdown" + elif filename.endswith(".msg"): + content_type = "message/rfc822" + + allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES") + if allowed_mimetypes_str is not None: + allowed_mimetypes = allowed_mimetypes_str.split(",") + + if content_type not in allowed_mimetypes: + raise HTTPException( + status_code=400, + detail=(f"File type {content_type} is not supported."), + ) + + return content_type From 8ec67e46c0e9ddc625ba651e2d88abf4e84253c1 Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 23 May 2024 14:43:07 +0200 Subject: [PATCH 5/7] fix: lint --- prepline_general/api/endpoints.py | 40 ++++++++++++++++--------------- prepline_general/api/general.py | 6 ++++- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py index ea948b5b..c109ea5d 100644 --- a/prepline_general/api/endpoints.py +++ b/prepline_general/api/endpoints.py @@ -13,7 +13,8 @@ from prepline_general.api.general import ( ungz_file, - MultipartMixedResponse, pipeline_api, + MultipartMixedResponse, + pipeline_api, ) from prepline_general.api.validation import _validate_chunking_strategy, get_validated_mimetype from prepline_general.api.models.form_params import GeneralFormParams @@ -31,14 +32,14 @@ ) @router.post("/general/v0.0.68/general", include_in_schema=False) def general_partition( - request: Request, - # cannot use annotated type here because of a bug described here: - # https://github.com/tiangolo/fastapi/discussions/10280 - # The openapi metadata must be added separately in openapi.py file. 
- # TODO: Check if the bug is fixed and change the declaration to use Annoteted[List[UploadFile], File(...)] - # For new parameters - add them in models/form_params.py - files: List[UploadFile], - form_params: GeneralFormParams = Depends(GeneralFormParams.as_form), + request: Request, + # cannot use annotated type here because of a bug described here: + # https://github.com/tiangolo/fastapi/discussions/10280 + # The openapi metadata must be added separately in openapi.py file. + # TODO: Check if the bug is fixed and change the declaration to use Annotated[List[UploadFile], File(...)] + # For new parameters - add them in models/form_params.py + files: List[UploadFile], + form_params: GeneralFormParams = Depends(GeneralFormParams.as_form), ): # -- must have a valid API key -- if api_key_env := os.environ.get("UNSTRUCTURED_API_KEY"): @@ -52,15 +53,15 @@ def general_partition( # -- detect response content-type conflict when multiple files are uploaded -- if ( - len(files) > 1 - and content_type - and content_type - not in [ - "*/*", - "multipart/mixed", - "application/json", - "text/csv", - ] + len(files) > 1 + and content_type + and content_type + not in [ + "*/*", + "multipart/mixed", + "application/json", + "text/csv", + ] ): raise HTTPException( detail=f"Conflict in media type {content_type} with response type 'multipart/mixed'.\n", @@ -123,7 +124,7 @@ def response_generator(is_multipart: bool): ) def join_responses( - responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse] + responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse] ) -> List[str | List[Dict[str, Any]]] | PlainTextResponse: """Consolidate partitionings from multiple documents into single response payload.""" if form_params.output_format != "text/csv": @@ -154,6 +155,7 @@ def join_responses( ) ) + @router.get("/general/v0/general", include_in_schema=False) @router.get("/general/v0.0.68/general", include_in_schema=False) async def handle_invalid_get_request(): diff --git
a/prepline_general/api/general.py b/prepline_general/api/general.py index 11cab525..9e6e97b6 100644 --- a/prepline_general/api/general.py +++ b/prepline_general/api/general.py @@ -32,7 +32,11 @@ from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES from prepline_general.api.memory_protection import ChipperMemoryProtection -from prepline_general.api.validation import _check_pdf, _validate_hi_res_model_name, _validate_strategy +from prepline_general.api.validation import ( + _check_pdf, + _validate_hi_res_model_name, + _validate_strategy, +) logger = logging.getLogger("unstructured_api") From 50efef5b2e4df16b1fe5e988b1ade0ba3e2656de Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 23 May 2024 14:46:56 +0200 Subject: [PATCH 6/7] fix: rename in makefile for versioning --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 250da0e5..82b39559 100644 --- a/Makefile +++ b/Makefile @@ -139,7 +139,7 @@ check-version: -s CHANGELOG.md \ -f preprocessing-pipeline-family.yaml release \ -f ${PACKAGE_NAME}/api/app.py release \ - -f ${PACKAGE_NAME}/api/general.py release + -f ${PACKAGE_NAME}/api/endpoints.py release ## version-sync: update references to version with most recent version from CHANGELOG.md .PHONY: version-sync @@ -148,4 +148,4 @@ version-sync: -s CHANGELOG.md \ -f preprocessing-pipeline-family.yaml release \ -f ${PACKAGE_NAME}/api/app.py release \ - -f ${PACKAGE_NAME}/api/general.py release + -f ${PACKAGE_NAME}/api/endpoints.py release From 212333f98a20de2fdd50c2caa715fd41c1b46165 Mon Sep 17 00:00:00 2001 From: "hubert.rutkowski" Date: Thu, 30 May 2024 09:51:54 +0200 Subject: [PATCH 7/7] refactor: extract the response_generator and join_responses functions --- prepline_general/api/endpoints.py | 154 +++++++++++++++++------------- 1 file changed, 85 insertions(+), 69 deletions(-) diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py index 
c109ea5d..4bfc990a 100644 --- a/prepline_general/api/endpoints.py +++ b/prepline_general/api/endpoints.py @@ -3,7 +3,7 @@ import io import json import os -from typing import List, Sequence, Dict, Any, cast, Union +from typing import List, Sequence, Dict, Any, cast, Union, Optional import pandas as pd from fastapi import APIRouter, UploadFile, Depends, HTTPException @@ -78,84 +78,100 @@ def general_partition( if is_content_type_gz or is_extension_gz: files[idx] = ungz_file(file, form_params.gz_uncompressed_content_type) - def response_generator(is_multipart: bool): - for file in files: - file_content_type = get_validated_mimetype(file) - - _file = file.file - - response = pipeline_api( - _file, - request=request, - coordinates=form_params.coordinates, - encoding=form_params.encoding, - hi_res_model_name=form_params.hi_res_model_name, - include_page_breaks=form_params.include_page_breaks, - ocr_languages=form_params.ocr_languages, - pdf_infer_table_structure=form_params.pdf_infer_table_structure, - skip_infer_table_types=form_params.skip_infer_table_types, - strategy=form_params.strategy, - xml_keep_tags=form_params.xml_keep_tags, - response_type=form_params.output_format, - filename=str(file.filename), - file_content_type=file_content_type, - languages=form_params.languages, - extract_image_block_types=form_params.extract_image_block_types, - unique_element_ids=form_params.unique_element_ids, - # -- chunking options -- - chunking_strategy=chunking_strategy, - combine_under_n_chars=form_params.combine_under_n_chars, - max_characters=form_params.max_characters, - multipage_sections=form_params.multipage_sections, - new_after_n_chars=form_params.new_after_n_chars, - overlap=form_params.overlap, - overlap_all=form_params.overlap_all, - starting_page_number=form_params.starting_page_number, - ) - - yield ( - json.dumps(response) - if is_multipart and type(response) not in [str, bytes] - else ( - PlainTextResponse(response) - if not is_multipart and 
form_params.output_format == "text/csv" - else response - ) - ) - - def join_responses( - responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse] - ) -> List[str | List[Dict[str, Any]]] | PlainTextResponse: - """Consolidate partitionings from multiple documents into single response payload.""" - if form_params.output_format != "text/csv": - return cast(List[Union[str, List[Dict[str, Any]]]], responses) - responses = cast(List[PlainTextResponse], responses) - data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] - io.BytesIO(responses[0].body) - ) - if len(responses) > 1: - for resp in responses[1:]: - resp_data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] - io.BytesIO(resp.body) - ) - data = data.merge( # pyright: ignore[reportUnknownMemberType] - resp_data, how="outer" - ) - return PlainTextResponse(data.to_csv()) - return ( MultipartMixedResponse( - response_generator(is_multipart=True), content_type=form_params.output_format + response_generator(files, request, form_params, chunking_strategy, is_multipart=True), + content_type=form_params.output_format, ) if content_type == "multipart/mixed" else ( - list(response_generator(is_multipart=False))[0] + list( + response_generator( + files, request, form_params, chunking_strategy, is_multipart=False + ) + )[0] if len(files) == 1 - else join_responses(list(response_generator(is_multipart=False))) + else join_responses( + form_params, + list( + response_generator( + files, request, form_params, chunking_strategy, is_multipart=False + ) + ), + ) ) ) +def join_responses( + form_params: GeneralFormParams, + responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse], +) -> List[str | List[Dict[str, Any]]] | PlainTextResponse: + """Consolidate partitionings from multiple documents into single response payload.""" + if form_params.output_format != "text/csv": + return cast(List[Union[str, List[Dict[str, Any]]]], responses) + responses = cast(List[PlainTextResponse], responses) + 
data = pd.read_csv(io.BytesIO(responses[0].body)) # pyright: ignore[reportUnknownMemberType] + if len(responses) > 1: + for resp in responses[1:]: + resp_data = pd.read_csv( # pyright: ignore[reportUnknownMemberType] + io.BytesIO(resp.body) + ) + data = data.merge(resp_data, how="outer") # pyright: ignore[reportUnknownMemberType] + return PlainTextResponse(data.to_csv()) + + +def response_generator( + files: List[UploadFile], + request: Request, + form_params: GeneralFormParams, + chunking_strategy: Optional[str], + is_multipart: bool, +): + for file in files: + file_content_type = get_validated_mimetype(file) + _file = file.file + + response = pipeline_api( + _file, + request=request, + coordinates=form_params.coordinates, + encoding=form_params.encoding, + hi_res_model_name=form_params.hi_res_model_name, + include_page_breaks=form_params.include_page_breaks, + ocr_languages=form_params.ocr_languages, + pdf_infer_table_structure=form_params.pdf_infer_table_structure, + skip_infer_table_types=form_params.skip_infer_table_types, + strategy=form_params.strategy, + xml_keep_tags=form_params.xml_keep_tags, + response_type=form_params.output_format, + filename=str(file.filename), + file_content_type=file_content_type, + languages=form_params.languages, + extract_image_block_types=form_params.extract_image_block_types, + unique_element_ids=form_params.unique_element_ids, + # -- chunking options -- + chunking_strategy=chunking_strategy, + combine_under_n_chars=form_params.combine_under_n_chars, + max_characters=form_params.max_characters, + multipage_sections=form_params.multipage_sections, + new_after_n_chars=form_params.new_after_n_chars, + overlap=form_params.overlap, + overlap_all=form_params.overlap_all, + starting_page_number=form_params.starting_page_number, + ) + + yield ( + json.dumps(response) + if is_multipart and type(response) not in [str, bytes] + else ( + PlainTextResponse(response) + if not is_multipart and form_params.output_format == "text/csv" + else 
response + ) + ) + + @router.get("/general/v0/general", include_in_schema=False) @router.get("/general/v0.0.68/general", include_in_schema=False) async def handle_invalid_get_request():