Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autointent/_dump_tools/unit_dumpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from peft import PeftModel
from pydantic import BaseModel
from sklearn.base import BaseEstimator
from transformers import ( # type: ignore[attr-defined]
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
PreTrainedModel,
Expand Down
2 changes: 1 addition & 1 deletion autointent/context/data_handler/_stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from numpy import typing as npt
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import IterativeStratification
from transformers import set_seed # type: ignore[attr-defined]
from transformers import set_seed

from autointent import Dataset
from autointent.custom_types import LabelType
Expand Down
2 changes: 1 addition & 1 deletion autointent/modules/scoring/_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import ( # type: ignore[attr-defined]
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Expand Down
1 change: 1 addition & 0 deletions autointent/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

114 changes: 114 additions & 0 deletions autointent/server/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""FastAPI application for AutoIntent pipeline inference."""

import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from functools import lru_cache
from pathlib import Path

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from autointent import Pipeline
from autointent.custom_types import ListOfLabelsWithOOS


class Settings(BaseSettings):
"""Application settings loaded from environment variables."""

model_config = SettingsConfigDict(env_file=".env", env_prefix="AUTOINTENT_")
path: str = Field(..., description="Path to the optimized pipeline assets")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
path: str = Field(..., description="Path to the optimized pipeline assets")
path: Path = Field(..., description="Path to the optimized pipeline assets")

host: str = "127.0.0.1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
host: str = "127.0.0.1"
host: str = "0.0.0.0"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ruff ругается на это

port: int = 8013


class PredictRequest(BaseModel):
"""Request model for the predict endpoint."""

utterances: list[str] = Field(..., description="List of text utterances to classify")


class PredictResponse(BaseModel):
"""Response model for the predict endpoint."""

predictions: ListOfLabelsWithOOS = Field(..., description="List of predicted class labels")


settings = Settings()
logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def load_pipeline() -> Pipeline:
"""Load the optimized pipeline from disk."""
pipeline_path = Path(settings.path)
if not pipeline_path.exists():
msg = f"Pipeline path does not exist: {pipeline_path}"
logger.error(msg)
raise HTTPException(status_code=404, detail=msg)

try:
msg = f"Loading pipeline from: {pipeline_path}"
logger.info(msg)
pipeline = Pipeline.load(pipeline_path)
logger.info("Pipeline loaded successfully")

except Exception:
logger.exception("Failed to load pipeline")
raise
else:
return pipeline


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
"""Load pipe."""
load_pipeline()
yield


app = FastAPI(
title="AutoIntent Pipeline API",
description="API for serving AutoIntent predictions",
version="0.0.1",
lifespan=lifespan,
)


@app.get("/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy"}


@app.post("/predict")
async def predict(request: PredictRequest) -> PredictResponse:
"""Predict class labels for the given utterances.

Args:
request: Request containing list of utterances to classify

Returns:
Response containing predicted class labels
"""
current_pipeline = load_pipeline()

if not request.utterances:
return PredictResponse(predictions=[])

predictions = current_pipeline.predict(request.utterances)

return PredictResponse(predictions=predictions)


def main() -> None:
"""Main entry point for the HTTP server."""
import uvicorn

uvicorn.run(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Можно потом добавить пример как запускать с fastapi run

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

а зачем это может быть полезно? я никогда так не запускал, не шарю

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Альтернатива такому запуску https://fastapi.tiangolo.com/deployment/manually/. Тут можно поменять на uvicorn.run(app)

"autointent.server.http:app",
host=settings.host,
port=settings.port,
reload=False,
)
Loading
Loading