-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_pipeline_api.py
More file actions
66 lines (48 loc) · 1.78 KB
/
inference_pipeline_api.py
File metadata and controls
66 lines (48 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import asyncio

import opik
from fastapi import FastAPI, HTTPException
from opik import opik_context
from pydantic import BaseModel

from core import settings
from core.application.rag.retriever import ContextRetriever
from core.application.utils import misc
from core.domain.embedded_chunks import EmbeddedChunk
from core.infrastructure.opik_utils import configure_opik
from core.model.inference import InferenceExecutor, LLMInferenceTransformers
# Configure Opik once, at import time, before any @opik.track-decorated
# function in this module can run.
# NOTE(review): what exactly is configured (project, API key, workspace) lives in
# core.infrastructure.opik_utils.configure_opik — confirm there.
configure_opik()
# The ASGI application; the @app.post route decorator in this module registers on it.
app = FastAPI()
class QueryRequest(BaseModel):
    # Request body for POST /rag. Documented with comments rather than a class
    # docstring, because pydantic publishes class docstrings as the JSON schema
    # description.
    query: str  # the user's question; rag_endpoint passes it to rag() unchanged
class QueryResponse(BaseModel):
    # Response body for POST /rag (used as the route's response_model). Documented
    # with comments rather than a class docstring, because pydantic publishes
    # class docstrings as the JSON schema description.
    answer: str  # the generated answer returned by rag()
@opik.track
def call_llm_service(query: str, context: str | None) -> str:
    """Generate an answer to *query* with the configured Hugging Face model.

    Args:
        query: The user's question, forwarded to ``InferenceExecutor``.
        context: Retrieved context used to ground the answer; may be ``None``.

    Returns:
        The result of ``InferenceExecutor(...).execute()``.
    """
    # NOTE(review): a new LLMInferenceTransformers is built on every call. If its
    # constructor loads model weights, this is costly. Before sharing one instance,
    # confirm that InferenceExecutor does not mutate the llm object; if it does,
    # a shared instance would be unsafe across concurrent requests.
    executor = InferenceExecutor(
        LLMInferenceTransformers(model_id=settings.HUGGINGFACE_INFERENCE_MODEL_ID),
        query,
        context,
    )
    return executor.execute()
@opik.track
def rag(query: str) -> str:
    """Answer *query* with retrieval-augmented generation and annotate the trace.

    Searches with ``k=3`` through a non-mock ``ContextRetriever`` and builds a
    context from the hits with ``EmbeddedChunk.to_context``. It then asks
    ``call_llm_service`` for an answer and attaches model settings and token
    counts to the current Opik trace.

    Args:
        query: The user's question.

    Returns:
        The answer produced by ``call_llm_service``.
    """
    hits = ContextRetriever(mock=False).search(query, k=3)
    context = EmbeddedChunk.to_context(hits)
    answer = call_llm_service(query, context)

    # Token counts are taken only after generation so the answer can be measured too.
    trace_metadata = {
        "model_id": settings.HUGGINGFACE_INFERENCE_MODEL_ID,
        "embedding_model_id": settings.TEXT_EMBEDDING_MODEL_ID,
        "temperature": settings.TEMPERATURE_INFERENCE,
        "query_tokens": misc.compute_num_tokens(query),
        "context_tokens": misc.compute_num_tokens(context),
        "answer_tokens": misc.compute_num_tokens(answer),
    }
    opik_context.update_current_trace(tags=["rag"], metadata=trace_metadata)
    return answer
@app.post("/rag", response_model=QueryResponse)
async def rag_endpoint(request: QueryRequest):
    """Answer a query with the RAG pipeline.

    ``rag`` is synchronous (retrieval plus LLM inference), so it runs in a worker
    thread. Calling it directly inside this ``async`` handler would block the
    event loop and stall every other request until inference finished.

    Raises:
        HTTPException: 500 carrying the error message if the pipeline fails.
    """
    try:
        # asyncio.to_thread copies the current contextvars context into the
        # worker thread, so context-local state is still visible inside rag().
        answer = await asyncio.to_thread(rag, query=request.query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"answer": answer}