-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathapp.py
More file actions
179 lines (141 loc) · 5.04 KB
/
app.py
File metadata and controls
179 lines (141 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
""" Backend for Knowledge Base Chat App """
import json
import logging
import os
import sys
import asyncio
import httpx
from dotenv import dotenv_values, load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.staticfiles import StaticFiles
from starlette.exceptions import HTTPException as StarletteHTTPException
import chatbot
import collections_loader as cl
from helpers import logging_config
# Load local env vars if present
load_dotenv()

# Initialize logger
logging_config()
_logger = logging.getLogger(__name__)

# Get config. Later entries win, so real environment variables override
# values read from the two dotenv files.
config = {
    **dotenv_values(".env"),  # load shared development variables
    **dotenv_values(".env.secret"),  # load sensitive variables
    **os.environ,  # override loaded values with environment variables
}
_logger.info("Config loaded...")

# Load configuration from JSON file. Its top-level keys are merged into
# ``config`` and replace any entry of the same name.
config_file = config.get("CONFIG_FILE")
if config_file and os.path.exists(config_file):
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as file:
        config_data = json.load(file)
    config.update(config_data)
    _logger.info(f"Configuration loaded from {config_file}")
else:
    _logger.warning(f"Config file {config_file} not found or not specified")
# Get collections, LLMs, vectorstore and embeddings config
# (each name is None when its key is absent from the configuration)
collections_config, llms_config, vectorstore_config, embeddings_config = (
    config.get(section)
    for section in ("collections", "llms", "vectorstore", "embeddings")
)

# Load collections
collections_loader = cl.CollectionsLoader(
    collections_config,
    vectorstore_config,
    _logger,
)
collections = collections_loader.load_collections()

# Every collection exposes its versions; count them across all collections
total_milvus_collections = sum(len(entry.versions) for entry in collections)
_logger.info(
    f"Loaded {total_milvus_collections} versioned collection(s) across {len(collections)} collection(s)"
)
# Initialize Chatbot
# NOTE: from here on the name ``chatbot`` refers to this instance, no longer
# to the imported ``chatbot`` module.
chatbot = chatbot.Chatbot(config, _logger)

# App creation
app = FastAPI()

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended outside of development.
origins, methods, headers = ["*"], ["*"], ["*"]
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": methods,
    "allow_headers": headers,
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Connection Manager for Websockets
class ConnectionManager:
    """Keep track of the websocket connections that are currently open."""

    def __init__(self):
        # Connections in the order they were accepted
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the websocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``.

        Removing a connection that is not (or no longer) tracked is a no-op,
        so the method is safe to call from several cleanup paths.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already removed: nothing left to clean up
            pass

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` as a text frame to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send ``message`` as a text frame to every tracked connection."""
        # Iterate over a snapshot: connections may be added or removed while
        # this coroutine is suspended in ``send_text``, and mutating the list
        # being iterated would silently skip some recipients.
        for connection in tuple(self.active_connections):
            await connection.send_text(message)


manager = ConnectionManager()
#############################
# API Endpoints definitions #
#############################
# Status
@app.get("/health")
async def health():
    """Liveness probe: always answers with a fixed OK payload."""
    status = {"message": "Status:OK"}
    return status
@app.get("/api/llms")
async def get_llms():
    """Return the ``llms`` entry of the loaded configuration."""
    return llms_config
# Collections
@app.get("/api/collections")
async def get_collections():
    """Return the collections that were loaded when the module started."""
    return collections
async def handle_client_request(websocket: WebSocket, data: dict):
    """Stream the chatbot's answer to one query back over ``websocket``.

    Args:
        websocket: connection the answer items are written to.
        data: decoded client request. It must provide the keys ``model``,
            ``query``, ``collection``, ``collection_full_name``, ``version``
            and ``language``, which are passed to ``chatbot.stream`` in that
            order.

    Each item yielded by ``chatbot.stream`` is serialised with ``json.dumps``
    and sent as one text frame. A request that lacks a required key (or is
    not a mapping at all) is logged and dropped. A disconnect reported while
    sending ends the stream quietly. Both would otherwise surface only as an
    unretrieved exception of the background task running this coroutine.
    """
    required_fields = (
        "model",
        "query",
        "collection",
        "collection_full_name",
        "version",
        "language",
    )
    try:
        stream_args = [data[field] for field in required_fields]
    except (KeyError, TypeError) as ex:
        _logger.warning(f"Ignoring malformed query request: {ex!r}")
        return

    try:
        async for next_item in chatbot.stream(*stream_args):
            answer = json.dumps(next_item)
            await websocket.send_text(answer)
    except WebSocketDisconnect:
        _logger.info("Client disconnected while an answer was being streamed")
@app.websocket("/ws/query/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Serve chat queries for one websocket client until it disconnects.

    Every text frame received is decoded as JSON and answered by its own
    background task running ``handle_client_request``, so a new query can be
    received while earlier answers are still streaming.

    Args:
        websocket: the client connection.
        client_id: identifier taken from the URL path; only used in log lines.
    """
    await manager.connect(websocket)
    # The event loop keeps only weak references to tasks; hold strong ones so
    # a running answer cannot be garbage-collected before it completes.
    pending: set[asyncio.Task] = set()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                # One bad frame should not tear down the whole connection
                _logger.warning(f"Client {client_id} sent invalid JSON; message ignored")
                continue
            task = asyncio.create_task(handle_client_request(websocket, data))
            pending.add(task)
            task.add_done_callback(pending.discard)
    except WebSocketDisconnect:
        _logger.info(f"Client {client_id} disconnected")
    finally:
        # Runs on a normal disconnect and on any unexpected error alike:
        # stop answers that can no longer be delivered and forget the socket.
        for task in tuple(pending):
            task.cancel()
        manager.disconnect(websocket)
# Serve React App (frontend)
class SPAStaticFiles(StaticFiles):
    """Static file handler for the React single-page application.

    When the first command-line argument is ``dev``, requests are proxied to
    the React dev server at ``http://localhost:9000``. Otherwise files come
    from the mounted directory, and a path that is not found falls back to
    ``index.html``.
    """

    async def get_response(self, path: str, scope):
        """Build the response for ``path``.

        Args:
            path: requested path, relative to the mount point.
            scope: ASGI scope of the request, passed through to ``StaticFiles``.

        Raises:
            HTTPException: any error other than 404 raised by the static
                file lookup.
        """
        if len(sys.argv) > 1 and sys.argv[1] == "dev":
            # We are in Dev mode, proxy to the React dev server
            async with httpx.AsyncClient() as client:
                upstream = await client.get(f"http://localhost:9000/{path}")
            # Forward the raw bytes together with the upstream content type.
            # Decoding the body to text would corrupt binary assets (images,
            # fonts) and the browser would get no MIME type for the asset.
            content_type = upstream.headers.get("content-type")
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                headers={"content-type": content_type} if content_type else None,
            )

        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            if ex.status_code != 404:
                raise
            # Not found: serve index.html instead
            return await super().get_response("index.html", scope)
# Frontend mount, registered after the API and websocket routes defined above
app.mount("/", SPAStaticFiles(directory="public", html=True), name="spa-static-files")

# Launch the FastAPI server
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; the port comes from PORT and defaults to 5000
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", "5000")))