-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
102 lines (84 loc) · 2.55 KB
/
api.py
File metadata and controls
102 lines (84 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
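FastAPI web service for the RAG system.

Exposes root, health, stats and query endpoints; the stats and query
endpoints require a valid X-API-Key header.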
"""
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
from src.rag_engine import SecureRAGEngine
from api_models import (
QueryRequest,
QueryResponse,
HealthResponse,
StatsResponse
)
# Global RAG engine.
# Starts as None; the lifespan handler assigns a SecureRAGEngine instance at
# startup, and the /stats and /query handlers read it.
rag_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    On startup, builds the module-level ``rag_engine`` and ingests everything
    under ``./documents/``. Control is then handed back to the application
    until it stops.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. The engine is shared through the module-level ``rag_engine``.
    """
    global rag_engine
    # Startup
    logging.info("Loading RAG engine...")
    rag_engine = SecureRAGEngine()
    rag_engine.ingest_documents("./documents/")
    logging.info("RAG engine ready!")
    try:
        yield
    finally:
        # Shutdown: placed in ``finally`` so it still runs when an exception
        # or cancellation is thrown into the generator at the ``yield``.
        logging.info("Shutting down...")
# Application instance; startup and shutdown are delegated to ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="Enterprise RAG API",
    version="1.0.0",
    description="Secure RAG with hybrid search",
)
# CORS
# Any origin may call the API, so credentialed requests (cookies, HTTP auth)
# are disabled: combining a wildcard origin with credentials would let any
# website make authenticated cross-origin calls on behalf of a browser user.
# Callers authenticate with the X-API-Key header, which does not rely on
# CORS credentials support.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Accepted API keys, checked against the X-API-Key request header.
# NOTE(review): these are hard-coded in source — presumably placeholders;
# confirm real keys come from configuration or a secret store.
API_KEYS = {
    "dev_key_123",
    "prod_key_456",
}
def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """Validate the caller's API key.

    Args:
        api_key: Value of the ``X-API-Key`` request header.

    Returns:
        The key unchanged when it matches one of ``API_KEYS``.

    Raises:
        HTTPException: 401 when the key is missing, not a string, or not
            one of the accepted keys.
    """
    import hmac  # local import: only this function needs it

    if not isinstance(api_key, str):
        # Anything that is not a string can never match a key.
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Encode to bytes because compare_digest rejects non-ASCII str input.
    supplied = api_key.encode("utf-8")
    # Check every accepted key with a timing-safe comparison and without
    # short-circuiting, so response time does not hint at which key (or how
    # much of one) matched.
    matches = [
        hmac.compare_digest(supplied, known.encode("utf-8"))
        for known in API_KEYS
    ]
    if not any(matches):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
@app.get("/", tags=["Root"])
def root():
    """Return the service name with pointers to the docs and health routes."""
    landing = dict(
        message="Enterprise RAG API",
        docs="/docs",
        health="/health",
    )
    return landing
@app.get("/health", response_model=HealthResponse, tags=["System"])
def health_check():
    """Report a fixed health payload: status and API version."""
    payload = {}
    payload["status"] = "healthy"
    payload["version"] = "1.0.0"
    return payload
@app.get("/stats", response_model=StatsResponse, tags=["System"])
def get_stats(api_key: str = Header(..., alias="X-API-Key")):
    """Return index statistics for the loaded RAG engine.

    Args:
        api_key: Value of the ``X-API-Key`` request header.

    Returns:
        A dict with ``total_chunks`` (length of the retriever's
        ``all_chunks``) and ``total_documents`` (number of distinct ``file``
        values across the collection's metadata; entries without one count
        as ``"unknown"``).

    Raises:
        HTTPException: 401 for an invalid API key; 503 when the engine has
            not been initialised.
    """
    verify_api_key(api_key)
    if rag_engine is None:
        # The engine is only assigned by the lifespan handler; report
        # "unavailable" instead of failing on a None attribute access.
        raise HTTPException(status_code=503, detail="RAG engine not ready")

    total_chunks = len(rag_engine.retriever.all_chunks)
    metadatas = rag_engine.collection.get()['metadatas'] or []
    # Tolerate entries that carry no metadata (None) by treating them the
    # same as entries lacking a 'file' key.
    files = {(meta or {}).get('file', 'unknown') for meta in metadatas}
    return {
        "total_chunks": total_chunks,
        "total_documents": len(files),
    }
@app.post("/query", response_model=QueryResponse, tags=["RAG"])
def query_documents(
    request: QueryRequest,
    api_key: str = Header(..., alias="X-API-Key")
):
    """Query the RAG system.

    Args:
        request: Request body; its ``question`` and ``user_id`` fields are
            forwarded to the engine.
        api_key: Value of the ``X-API-Key`` request header.

    Returns:
        A ``QueryResponse`` built from the engine result's ``answer`` and
        ``sources`` entries.

    Raises:
        HTTPException: 401 for an invalid API key; 503 when the engine has
            not been initialised; 400 when the engine result contains an
            ``error`` entry (its value becomes the error detail).
    """
    verify_api_key(api_key)
    if rag_engine is None:
        # The engine is only assigned by the lifespan handler; report
        # "unavailable" instead of failing on a None attribute access.
        raise HTTPException(status_code=503, detail="RAG engine not ready")

    result = rag_engine.query(request.question, user_id=request.user_id)
    if 'error' in result:
        raise HTTPException(status_code=400, detail=result['error'])
    return QueryResponse(
        answer=result['answer'],
        sources=result['sources'],
    )
if __name__ == "__main__":
    # Direct execution: serve the app on every interface, port 8000.
    import uvicorn

    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )