|
| 1 | +""" |
| 2 | +RAG Knowledge API Main Application Module |
| 3 | +
|
| 4 | +This module initializes and configures the FastAPI application for the RAG backend. |
| 5 | +It sets up CORS middleware, loads configuration and data, and wires together the |
| 6 | +Gemini-based Router, Retriever, and Responder components into a chat endpoint. |
| 7 | +""" |
| 8 | + |
1 | 9 | import pandas as pd |
2 | 10 | import structlog |
| 11 | +import uvicorn |
| 12 | +from fastapi import APIRouter, FastAPI |
| 13 | +from fastapi.middleware.cors import CORSMiddleware |
3 | 14 | from qdrant_client import QdrantClient |
4 | 15 |
|
5 | 16 | from flare_ai_rag.ai import GeminiEmbedding, GeminiProvider |
| 17 | +from flare_ai_rag.api import ChatRouter |
6 | 18 | from flare_ai_rag.responder import GeminiResponder, ResponderConfig |
7 | 19 | from flare_ai_rag.retriever import QdrantRetriever, RetrieverConfig, generate_collection |
8 | 20 | from flare_ai_rag.router import GeminiRouter, RouterConfig |
9 | 21 | from flare_ai_rag.settings import settings |
10 | | -from flare_ai_rag.utils import load_json, load_txt, save_json |
| 22 | +from flare_ai_rag.utils import load_json |
11 | 23 |
|
12 | 24 | logger = structlog.get_logger(__name__) |
13 | 25 |
|
@@ -83,58 +95,72 @@ def setup_responder(input_config: dict) -> GeminiResponder: |
83 | 95 | return GeminiResponder(client=gemini_provider, responder_config=responder_config) |
84 | 96 |
|
85 | 97 |
|
86 | | -def main() -> None: |
| 98 | +def create_app() -> FastAPI: |
| 99 | + """ |
| 100 | + Create and configure the FastAPI application instance. |
| 101 | +
|
| 102 | + This function: |
| 103 | + 1. Creates a new FastAPI instance with optional CORS middleware. |
| 104 | + 2. Loads configuration. |
| 105 | + 3. Sets up the Gemini Router, Qdrant Retriever, and Gemini Responder. |
| 106 | + 4. Loads RAG data and (re)generates the Qdrant collection. |
| 107 | + 5. Initializes a ChatRouter that wraps the RAG pipeline. |
| 108 | + 6. Registers the chat endpoint under the /chat prefix. |
| 109 | +
|
| 110 | + Returns: |
| 111 | + FastAPI: The configured FastAPI application instance. |
| 112 | + """ |
| 113 | + app = FastAPI(title="RAG Knowledge API", version="1.0", redirect_slashes=False) |
| 114 | + |
| 115 | + # Optional: configure CORS middleware using settings. |
| 116 | + app.add_middleware( |
| 117 | + CORSMiddleware, |
| 118 | + allow_origins=settings.cors_origins, |
| 119 | + allow_credentials=True, |
| 120 | + allow_methods=["*"], |
| 121 | + allow_headers=["*"], |
| 122 | + ) |
| 123 | + |
87 | 124 | # Load input configuration. |
88 | 125 | input_config = load_json(settings.input_path / "input_parameters.json") |
89 | 126 |
|
90 | | - # Set up the Gemini Router |
91 | | - router = setup_router(input_config) |
92 | | - |
93 | | - # Load data |
| 127 | + # Load RAG data. |
94 | 128 | df_docs = pd.read_csv(settings.data_path / "docs.csv", delimiter=",") |
95 | 129 | logger.info("Loaded CSV Data.", num_rows=len(df_docs)) |
96 | 130 |
|
97 | | - # Set up qdrant client. |
| 131 | + # Set up the RAG components: 1. Gemini Router |
| 132 | + router_component = setup_router(input_config) |
| 133 | + |
| 134 | + # 2a. Set up Qdrant client. |
98 | 135 | qdrant_client = setup_qdrant(input_config) |
99 | 136 |
|
100 | | - # Set up retriever. (Use Gemini Embedding.) |
101 | | - retriever = setup_retriever(qdrant_client, input_config, df_docs) |
| 137 | + # 2b. Set up the Retriever. |
| 138 | + retriever_component = setup_retriever(qdrant_client, input_config, df_docs) |
102 | 139 |
|
103 | | - # Set up responder. (Use Gemini Provider.) |
104 | | - responder = setup_responder(input_config) |
| 140 | + # 3. Set up the Responder. |
| 141 | + responder_component = setup_responder(input_config) |
105 | 142 |
|
106 | | - # Process user query. |
107 | | - query = load_txt(settings.input_path / "query.txt") |
108 | | - classification = router.route_query(query) |
109 | | - logger.info( |
110 | | - "Queried has been classified by the Router.", classification=classification |
| 143 | + # Create an APIRouter for chat endpoints and initialize ChatRouter. |
| 144 | + chat_router = ChatRouter( |
| 145 | + router=APIRouter(), |
| 146 | + query_router=router_component, |
| 147 | + retriever=retriever_component, |
| 148 | + responder=responder_component, |
111 | 149 | ) |
| 150 | + app.include_router(chat_router.router, prefix="/api/routes/chat", tags=["chat"]) |
| 151 | + |
| 152 | + return app |
112 | 153 |
|
113 | | - if classification == "ANSWER": |
114 | | - retrieved_docs = retriever.semantic_search(query, top_k=5) |
115 | | - logger.info("Docs have been retrieved.") |
116 | 154 |
|
117 | | - # Prepare answer |
118 | | - answer = responder.generate_response(query, retrieved_docs) |
119 | | - logger.info("Response has been generated.", answer=answer) |
| 155 | +app = create_app() |
120 | 156 |
|
121 | | - # Save answer |
122 | | - output_file = settings.data_path / "rag_answer.json" |
123 | | - save_json( |
124 | | - { |
125 | | - "query": query, |
126 | | - "answer": answer, |
127 | | - }, |
128 | | - output_file, |
129 | | - ) |
130 | 157 |
|
131 | | - elif classification == "CLARIFY": |
132 | | - logger.info("Your query needs clarification. Please provide more details.") |
133 | | - elif classification == "REJECT": |
134 | | - logger.info("Your query has been rejected as it is out of scope.") |
135 | | - else: |
136 | | - logger.info("Unexpected classification.", classification=classification) |
| 158 | +def start() -> None: |
| 159 | + """ |
| 160 | + Start the FastAPI application server. |
| 161 | + """ |
| 162 | + uvicorn.run(app, host="0.0.0.0", port=8080) # noqa: S104 |
137 | 163 |
|
138 | 164 |
|
139 | 165 | if __name__ == "__main__": |
140 | | - main() |
| 166 | + start() |
0 commit comments