Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/badges/coverage.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"schemaVersion":1,"label":"coverage","message":"49.88%","color":"red"}
{"schemaVersion":1,"label":"coverage","message":"51.02%","color":"red"}
120 changes: 0 additions & 120 deletions api/endpoints/files.py

This file was deleted.

87 changes: 2 additions & 85 deletions api/endpoints/ocr.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
import base64
from contextvars import ContextVar

from fastapi import APIRouter, Depends, File, HTTPException, Request, Security, UploadFile
from fastapi import APIRouter, Depends, Request, Security
from fastapi.responses import JSONResponse
import pymupdf
from redis.asyncio import Redis as AsyncRedis
from sqlalchemy.ext.asyncio import AsyncSession

from api.helpers._accesscontroller import AccessController
from api.helpers.models import ModelRegistry
from api.schemas.core.context import RequestContext
from api.schemas.core.documents import FileType
from api.schemas.core.models import RequestContent
from api.schemas.exception import HTTPExceptionModel
from api.schemas.ocr import OCR, CreateOCR, DPIForm, ModelForm, PromptForm
from api.schemas.parse import ParsedDocument, ParsedDocumentMetadata, ParsedDocumentPage
from api.schemas.usage import Usage
from api.utils.context import global_context
from api.schemas.ocr import OCR, CreateOCR
from api.utils.dependencies import get_model_registry, get_postgres_session, get_redis_client, get_request_context
from api.utils.exceptions import (
FileSizeLimitExceededException,
ModelIsTooBusyException,
ModelNotFoundException,
WrongModelTypeException,
Expand Down Expand Up @@ -66,79 +59,3 @@ async def ocr(
)

return JSONResponse(content=OCR(**response.json()).model_dump(), status_code=response.status_code)


@router.post(path=EndpointRoute.OCR_BETA, dependencies=[Security(dependency=AccessController())], status_code=200, response_model=ParsedDocument)
@hooks
async def ocr_beta(
request: Request,
file: UploadFile = File(..., description="The file to parse."),
model: str = ModelForm,
dpi: int = DPIForm,
prompt: str = PromptForm,
model_registry: ModelRegistry = Depends(get_model_registry),
redis_client: AsyncRedis = Depends(get_redis_client),
postgres_session: AsyncSession = Depends(get_postgres_session),
request_context: ContextVar[RequestContext] = Depends(get_request_context),
) -> JSONResponse:
"""
Extracts text from PDF files using OCR.
"""
# check if file is a pdf (raises UnsupportedFileTypeException if not a PDF)
global_context.document_manager.parser_manager.check_file_type(file=file, type=FileType.PDF)

# check file size
if file.size > FileSizeLimitExceededException.MAX_CONTENT_SIZE:
raise FileSizeLimitExceededException()

file_content = await file.read()
pdf = pymupdf.open(stream=file_content, filetype="pdf")
document = ParsedDocument(data=[], usage=Usage())
for i, page in enumerate(pdf):
image = page.get_pixmap(dpi=dpi)
img_byte_arr = image.tobytes("png")
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(img_byte_arr).decode("utf-8")}"}},
],
}
],
"n": 1,
"stream": False,
}

model_provider = await model_registry.get_model_provider(
model=model,
endpoint=EndpointRoute.OCR_BETA,
postgres_session=postgres_session,
redis_client=redis_client,
request_context=request_context,
)

response = await model_provider.forward_request(
request_content=RequestContent(method="POST", endpoint=EndpointRoute.CHAT_COMPLETIONS, json=payload, model=model),
redis_client=redis_client,
)
status = response.status_code
body_json = response.json()
if status // 100 != 2:
pdf.close()
raise HTTPException(status_code=status, detail=body_json.get("detail", "OCR request failed"))
text = body_json.get("choices", [{}])[0].get("message", {}).get("content", "")
document.data.append(
ParsedDocumentPage(
content=text,
images={},
metadata=ParsedDocumentMetadata(page=i, document_name=file.filename, **pdf.metadata),
)
)
if body_json.get("usage"):
document.usage = Usage(**body_json["usage"])
pdf.close()

return JSONResponse(content=document.model_dump(), status_code=200)
8 changes: 7 additions & 1 deletion api/endpoints/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
router = APIRouter(prefix="/v1", tags=[RouterName.PARSE.title()])


@router.post(path=EndpointRoute.PARSE, dependencies=[Security(dependency=AccessController())], status_code=200, response_model=ParsedDocument)
@router.post(
path=EndpointRoute.PARSE,
dependencies=[Security(dependency=AccessController())],
status_code=200,
response_model=ParsedDocument,
deprecated=True,
)
async def parse(
request: Request,
data: Annotated[CreateParseForm, Depends(CreateParseForm.as_form)],
Expand Down
13 changes: 0 additions & 13 deletions api/helpers/_accesscontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ async def __call__(
if request.url.path.endswith(EndpointRoute.EMBEDDINGS) and request.method in ["POST"]:
await self._check_embeddings(body=body, user_info=user_info, postgres_session=postgres_session)

if request.url.path.endswith(EndpointRoute.FILES) and request.method in ["POST"]:
await self._check_files(user_info=user_info, postgres_session=postgres_session)

if request.url.path.endswith(EndpointRoute.OCR) and request.method in ["POST"]:
await self._check_ocr(body=body, user_info=user_info, postgres_session=postgres_session)

Expand Down Expand Up @@ -160,16 +157,6 @@ async def _check_embeddings(body: dict, user_info: UserInfo, postgres_session: A
prompt_tokens = global_context.tokenizer.get_prompt_tokens(endpoint=EndpointRoute.EMBEDDINGS, body=body)
await global_context.limiter.check_user_limits(user_info=user_info, router_id=router_id, prompt_tokens=prompt_tokens)

@staticmethod
async def _check_files(user_info: UserInfo, postgres_session: AsyncSession) -> None:
router_id = await global_context.model_registry.get_router_id_from_model_name(
model_name=global_context.document_manager.vector_store_model,
postgres_session=postgres_session,
)
if router_id is None:
return
await global_context.limiter.check_user_limits(user_info=user_info, router_id=router_id)

@staticmethod
async def _check_ocr(body: dict, user_info: UserInfo, postgres_session: AsyncSession) -> None:
router_id = await global_context.model_registry.get_router_id_from_model_name(model_name=body.get("model"), postgres_session=postgres_session)
Expand Down
1 change: 1 addition & 0 deletions api/helpers/_documentmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ async def create_document(
content = await self.parser_manager.parse(file=file)
except Exception as e:
logger.exception(f"failed to parse {document_name} ({e}).")
print(e)
raise ParsingDocumentFailedException()

# split the content into chunks
Expand Down
1 change: 0 additions & 1 deletion api/helpers/models/_modelregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class ModelRegistry:
EndpointRoute.CHAT_COMPLETIONS: [ModelType.TEXT_GENERATION, ModelType.IMAGE_TEXT_TO_TEXT],
EndpointRoute.EMBEDDINGS: [ModelType.TEXT_EMBEDDINGS_INFERENCE],
EndpointRoute.OCR: [ModelType.IMAGE_TO_TEXT],
EndpointRoute.OCR_BETA: [ModelType.IMAGE_TEXT_TO_TEXT],
EndpointRoute.RERANK: [ModelType.TEXT_CLASSIFICATION],
}

Expand Down
1 change: 0 additions & 1 deletion api/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class RouterDefinition:
RouterDefinition(name=RouterName.COLLECTIONS, module_path="api.endpoints.collections"),
RouterDefinition(name=RouterName.DOCUMENTS, module_path="api.endpoints.documents"),
RouterDefinition(name=RouterName.EMBEDDINGS, module_path="api.endpoints.embeddings"),
RouterDefinition(name=RouterName.FILES, module_path="api.endpoints.files"), # Inexistant ?
RouterDefinition(name=RouterName.MODELS, module_path="api.infrastructure.fastapi.endpoints.models"),
RouterDefinition(name=RouterName.OCR, module_path="api.endpoints.ocr"),
RouterDefinition(name=RouterName.PARSE, module_path="api.endpoints.parse"),
Expand Down
6 changes: 3 additions & 3 deletions api/schemas/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class Model(ConfigBaseModel):

name: constr(strip_whitespace=True, min_length=1, max_length=64) = Field(..., description="Unique name exposed to clients when selecting the model.", examples=["gpt-4o"]) # fmt: off
type: ModelType = Field(..., description="Type of the model. It will be used to identify the model type.", examples=["text-generation"]) # fmt: off
aliases: list[constr(strip_whitespace=True, min_length=1, max_length=64)] = Field(default_factory=list, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]], json_extra_schema={"default": []}) # fmt: off
aliases: list[constr(strip_whitespace=True, min_length=1, max_length=64)] = Field(default_factory=list, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]], json_schema_extra={"default": []}) # fmt: off
load_balancing_strategy: RouterLoadBalancingStrategy = Field(default=RouterLoadBalancingStrategy.SHUFFLE, description="Routing strategy for load balancing between providers of the model.", examples=["least_busy"]) # fmt: off
cost_prompt_tokens: float = Field(default=0.0, ge=0.0, description="Model costs prompt tokens for user budget computation. The cost is by 1M tokens.", examples=[0.1]) # fmt: off
cost_completion_tokens: float = Field(default=0.0, ge=0.0, description="Model costs completion tokens for user budget computation. The cost is by 1M tokens. Set to `0.0` to disable budget computation for this model.", examples=[0.1]) # fmt: off
Expand Down Expand Up @@ -183,7 +183,7 @@ class AlbertDependency(ConfigBaseModel):
"""

url: constr(strip_whitespace=True, min_length=1) = Field(default="https://albert.api.etalab.gouv.fr", description="Albert API url.") # fmt: off
headers: dict[str, str] = Field(default_factory=dict, description="Albert API request headers.", examples=[{"Authorization": "Bearer my-api-key"}], json_extra_schema={"default": {}}) # fmt: off
headers: dict[str, str] = Field(default_factory=dict, description="Albert API request headers.", examples=[{"Authorization": "Bearer my-api-key"}], json_schema_extra={"default": {}}) # fmt: off
timeout: int = Field(default=DEFAULT_TIMEOUT, ge=1, description="Timeout for the Albert API requests.", examples=[10]) # fmt: off


Expand Down Expand Up @@ -375,7 +375,7 @@ class Settings(ConfigBaseModel):
swagger_contact: dict | None = Field(default=None, description="Contact informations of the API in swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.") # fmt: off
swagger_license_info: dict = Field(default={"name": "MIT Licence", "identifier": "MIT", "url": "https://raw.githubusercontent.com/etalab-ia/opengatellm/refs/heads/main/LICENSE"}, description="Licence informations of the API in swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.") # fmt: off
swagger_terms_of_service: str | None = Field(default=None, description="A URL to the Terms of Service for the API in swagger UI. If provided, this has to be a URL.", examples=["https://example.com/terms-of-service"]) # fmt: off
swagger_openapi_tags: list[dict[str, str | dict[str, str]]] = Field(default_factory=list, description="OpenAPI tags of the API in swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.", json_extra_schema={"default": []}) # fmt: off
swagger_openapi_tags: list[dict[str, str | dict[str, str]]] = Field(default_factory=list, description="OpenAPI tags of the API in swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.", json_schema_extra={"default": []}) # fmt: off
swagger_openapi_url: str = Field(default="/openapi.json", pattern=r"^/", description="OpenAPI URL of swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.") # fmt: off
swagger_docs_url: str = Field(default="/docs", pattern=r"^/", description="Docs URL of swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.") # fmt: off
swagger_redoc_url: str = Field(default="/redoc", pattern=r"^/", description="Redoc URL of swagger UI, see https://fastapi.tiangolo.com/tutorial/metadata for more information.") # fmt: off
Expand Down
3 changes: 0 additions & 3 deletions api/schemas/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class ProviderEndpoints(BaseModel):
embeddings: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, pattern=r"^/", to_lower=True), Field(default=None)]
models: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, pattern=r"^/", to_lower=True), Field(default=None)]
ocr: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, pattern=r"^/", to_lower=True), Field(default=None)]
ocr_beta: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, pattern=r"^/", to_lower=True), Field(default=None)]
rerank: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, pattern=r"^/", to_lower=True), Field(default=None)]

def get_endpoint(self, endpoint: EndpointRoute) -> str | None:
Expand All @@ -30,8 +29,6 @@ def get_endpoint(self, endpoint: EndpointRoute) -> str | None:
return self.models
elif endpoint == EndpointRoute.OCR:
return self.ocr
elif endpoint == EndpointRoute.OCR_BETA:
return self.ocr_beta
elif endpoint == EndpointRoute.RERANK:
return self.rerank
else:
Expand Down
Loading