-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
64 lines (54 loc) · 2.44 KB
/
main.py
File metadata and controls
64 lines (54 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from typing import List, Union
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
from fastembed.embedding import FlagEmbedding as Embedding
import numpy as np
import uvicorn
class Prompt(BaseModel):
    """Request body for ``POST /embeddings/prompt``: one text to embed."""

    prompt: str = "what time is it?"
class Document(BaseModel):
    """Request body for ``POST /embeddings/document``: a batch of texts to embed."""

    # A list literal is safe as a pydantic field default: pydantic copies
    # non-hashable defaults per instance, so requests never share this list.
    document: List[str] = [
        "what time is it?",
        "get the time",
        "i want to debug",
    ]
class Config(BaseModel):
    """Request body for ``POST /embeddings``: which embedding model to load and how."""

    # Passed to the embedding constructor as ``model_name``.
    model: str = "BAAI/bge-small-en"
    # Passed through as ``max_length`` (presumably the max token length per input — confirm).
    max_length: int = 512
    # Passed through as ``threads`` unchanged; what 0 means is decided by the
    # embedding backend (likely "use the default" — TODO confirm).
    threads: int = 0
# Directory handed to the embedding model as its cache; override via the
# CACHE_DIR environment variable.
cache_dir = os.environ.get("CACHE_DIR", "./cache")

# No model is loaded at import time: POST /embeddings attaches one to
# app.state.embedding_model, and the embed endpoints check for it.
app = FastAPI()
@app.post("/embeddings/prompt")
def submit_prompt(body_json: Prompt):
    """Embed a single prompt and return its vector.

    Args:
        body_json: Request body whose ``prompt`` field is the text to embed.

    Returns:
        The embedding of ``prompt``, converted from a numpy array to a plain
        Python list so FastAPI can serialise it as JSON.

    Raises:
        HTTPException: 500 if no model has been loaded via ``POST /embeddings``.
    """
    if not hasattr(app.state, "embedding_model"):
        raise HTTPException(status_code=500, detail="Embedding model not loaded yet.")
    # passage_embed is declared to take an iterable of texts. A bare str is an
    # iterable of characters, so wrap the prompt to guarantee it is embedded
    # as exactly one text.
    # NOTE(review): BGE-style models distinguish queries from passages; if
    # prompts are search queries, query_embed may be the intended call —
    # confirm before switching, since it changes the returned vectors.
    embeddings: List[np.ndarray] = list(
        app.state.embedding_model.passage_embed([body_json.prompt])
    )  # passage_embed returns a lazy generator; materialise it
    return embeddings[0].tolist()
@app.post("/embeddings/document")
def submit_document(body_json: Document):
    """Embed every text in a document batch and return one vector per text.

    Args:
        body_json: Request body whose ``document`` field is the list of texts.

    Returns:
        A list of embeddings in the same order as the input texts. Each
        embedding is converted from a numpy array to a plain Python list for
        JSON serialisation. An empty input list yields an empty list.

    Raises:
        HTTPException: 500 if no model has been loaded via ``POST /embeddings``.
    """
    if not hasattr(app.state, "embedding_model"):
        raise HTTPException(status_code=500, detail="Embedding model not loaded yet.")
    texts = body_json.document
    if not texts:
        # Nothing to embed; skip the model call entirely.
        return []
    # passage_embed returns a lazy generator; materialise it before converting.
    embeddings: List[np.ndarray] = list(app.state.embedding_model.passage_embed(texts))
    return [embedding.tolist() for embedding in embeddings]
@app.post("/embeddings")
def setup(config: Config):
    """Load (or replace) the embedding model used by the embed endpoints.

    Creates ``cache_dir`` if needed, constructs the model from ``config``, and
    stores it on ``app.state.embedding_model``.

    Args:
        config: Model name, max length, and thread count to construct the model with.

    Returns:
        ``{"status": "ok"}`` once the model is stored.

    Raises:
        HTTPException: 500 carrying the underlying exception text if creating
            the cache directory or constructing the model fails.
    """
    try:
        # exist_ok=True replaces the racy exists-then-create check: another
        # process or request creating the directory in between is no longer an error.
        os.makedirs(cache_dir, exist_ok=True)
        embedding_model = Embedding(
            model_name=config.model,
            max_length=config.max_length,
            cache_dir=cache_dir,
            threads=config.threads,
        )
    except Exception as e:
        # Request boundary: report any load failure to the client, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    app.state.embedding_model = embedding_model
    return {"status": "ok"}
if __name__ == "__main__":
    # Development server on all interfaces, port 8000, with auto-reload.
    # Reload requires the app as a "module:attribute" import string rather
    # than the app object, hence "main:app".
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)