11"""FastAPI endpoints for OEWS Data Agent."""
22
3- from fastapi import FastAPI , HTTPException
3+ from fastapi import FastAPI , HTTPException , Request
44from fastapi .middleware .cors import CORSMiddleware
55from contextlib import asynccontextmanager
66import time
77import os
88import asyncio
9+ from asyncio import Semaphore
910from concurrent .futures import ThreadPoolExecutor
1011from typing import AsyncGenerator
12+ from slowapi import Limiter , _rate_limit_exceeded_handler
13+ from slowapi .util import get_remote_address
14+ from slowapi .errors import RateLimitExceeded
1115
1216from src .api .models import (
1317 QueryRequest ,
@@ -62,6 +66,22 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
6266 lifespan = lifespan
6367)
6468
# --- Rate limiting ------------------------------------------------------
# Throttle requests per client, keyed by the caller's remote address.
# Counters live in process memory only: they reset on restart and are not
# shared between worker processes.
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=[os.getenv("RATE_LIMIT_DEFAULT", "100/hour")],
    storage_uri="memory://"
)

# Attach the limiter to the app so @limiter.limit decorators resolve it,
# and let slowapi's stock handler render 429 responses.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# --- Concurrency cap ----------------------------------------------------
# Bounds the number of workflow executions allowed to run simultaneously.
# NOTE(review): this asyncio primitive is built at import time, before any
# event loop exists — safe on Python 3.10+ where primitives are
# loop-agnostic; confirm if older interpreters must be supported.
max_concurrent_requests = Semaphore(
    int(os.getenv("MAX_CONCURRENT_REQUESTS", "8"))
)
6585# Add CORS middleware
6686# Get frontend origins from environment (comma-separated)
6787cors_origins_str = os .getenv ("CORS_ORIGINS" , "http://localhost:3000" )
@@ -151,7 +171,8 @@ async def health_check():
151171 503 : {"model" : ErrorResponse }
152172 }
153173)
154- async def query (request : QueryRequest ) -> QueryResponse :
174+ @limiter .limit (os .getenv ("RATE_LIMIT_QUERY_ENDPOINT" , "10/hour" ))
175+ async def query (request : Request , query_request : QueryRequest ) -> QueryResponse :
155176 """
156177 Process a natural language query about OEWS employment data.
157178
@@ -166,7 +187,8 @@ async def query(request: QueryRequest) -> QueryResponse:
166187 5. Format the response
167188
168189 Args:
169- request: Query request with natural language question and optional model overrides
190+ request: FastAPI Request object (for rate limiting)
191+ query_request: Query request with natural language question and optional model overrides
170192
171193 Returns:
172194 Formatted response with answer, charts, and metadata
@@ -180,92 +202,101 @@ async def query(request: QueryRequest) -> QueryResponse:
180202 detail = "Workflow not initialized. Check API keys and configuration."
181203 )
182204
205+ # Check if system is overloaded
206+ if max_concurrent_requests .locked ():
207+ raise HTTPException (
208+ status_code = 503 ,
209+ detail = "System at capacity. Please retry in 1-2 minutes." ,
210+ headers = {"Retry-After" : "60" }
211+ )
212+
183213 # Record start time
184214 start_time = time .time ()
185215
186216 # DIAGNOSTIC: Test if logging works in API process
187217 api_logger .debug ("query_received" , extra = {
188218 "data" : {
189- "query" : request .query [:100 ],
190- "enable_charts" : request .enable_charts ,
191- "reasoning_model" : request .reasoning_model or "default" ,
192- "implementation_model" : request .implementation_model or "default"
219+ "query" : query_request .query [:100 ],
220+ "enable_charts" : query_request .enable_charts ,
221+ "reasoning_model" : query_request .reasoning_model or "default" ,
222+ "implementation_model" : query_request .implementation_model or "default"
193223 }
194224 })
195225
196226 try :
197- # Prepare initial state with model overrides
198- enabled_agents = ["cortex_researcher" , "synthesizer" ]
199- if request .enable_charts :
200- enabled_agents .insert (- 1 , "chart_generator" )
201-
202- initial_state = {
203- "messages" : [],
204- "user_query" : request .query ,
205- "enabled_agents" : enabled_agents ,
206- "plan" : {},
207- "current_step" : 0 ,
208- "max_steps" : 10 ,
209- "replans" : 0 ,
210- "model_usage" : {},
211- # Pass model overrides to workflow (note: lowercase key names match state structure)
212- "reasoning_model" : request .reasoning_model ,
213- "implementation_model" : request .implementation_model
214- }
215-
216- # Invoke workflow with timeout
217- # Run blocking workflow_graph.invoke() in thread pool to avoid blocking event loop
218- # Force flush API log
219- import logging
220- for handler in api_logger .handlers :
221- handler .flush ()
222-
223- # Define blocking workflow execution
224- def run_workflow ():
225- return workflow_graph .invoke (
226- initial_state ,
227- config = {"recursion_limit" : 100 }
228- )
229-
230- # Run with timeout (5 minutes)
231- loop = asyncio .get_event_loop ()
232- try :
233- result = await asyncio .wait_for (
234- loop .run_in_executor (executor , run_workflow ),
235- timeout = REQUEST_TIMEOUT
236- )
237- except asyncio .TimeoutError :
238- raise HTTPException (
239- status_code = 504 ,
240- detail = f"Request processing exceeded { REQUEST_TIMEOUT } second timeout. The query may be too complex or the system is under heavy load."
241- )
227+ async with max_concurrent_requests :
228+ # Prepare initial state with model overrides
229+ enabled_agents = ["cortex_researcher" , "synthesizer" ]
230+ if query_request .enable_charts :
231+ enabled_agents .insert (- 1 , "chart_generator" )
232+
233+ initial_state = {
234+ "messages" : [],
235+ "user_query" : query_request .query ,
236+ "enabled_agents" : enabled_agents ,
237+ "plan" : {},
238+ "current_step" : 0 ,
239+ "max_steps" : 10 ,
240+ "replans" : 0 ,
241+ "model_usage" : {},
242+ # Pass model overrides to workflow (note: lowercase key names match state structure)
243+ "reasoning_model" : query_request .reasoning_model ,
244+ "implementation_model" : query_request .implementation_model
245+ }
242246
243- # Extract formatted response
244- formatted = result .get ("formatted_response" , {})
245-
246- # Calculate execution time
247- execution_time = int ((time .time () - start_time ) * 1000 )
248-
249- # Build response
250- response = QueryResponse (
251- answer = formatted .get ("answer" , result .get ("final_answer" , "No answer generated." )),
252- charts = [
253- ChartSpec (** chart )
254- for chart in formatted .get ("charts" , [])
255- ],
256- data_sources = [
257- DataSource (** source )
258- for source in formatted .get ("data_sources" , [])
259- ],
260- metadata = Metadata (
261- models_used = result .get ("model_usage" , {}),
262- execution_time_ms = execution_time ,
263- plan = result .get ("plan" ),
264- replans = result .get ("replans" , 0 )
247+ # Invoke workflow with timeout
248+ # Run blocking workflow_graph.invoke() in thread pool to avoid blocking event loop
249+ # Force flush API log
250+ import logging
251+ for handler in api_logger .handlers :
252+ handler .flush ()
253+
254+ # Define blocking workflow execution
255+ def run_workflow ():
256+ return workflow_graph .invoke (
257+ initial_state ,
258+ config = {"recursion_limit" : 100 }
259+ )
260+
261+ # Run with timeout (5 minutes)
262+ loop = asyncio .get_event_loop ()
263+ try :
264+ result = await asyncio .wait_for (
265+ loop .run_in_executor (executor , run_workflow ),
266+ timeout = REQUEST_TIMEOUT
267+ )
268+ except asyncio .TimeoutError :
269+ raise HTTPException (
270+ status_code = 504 ,
271+ detail = f"Request processing exceeded { REQUEST_TIMEOUT } second timeout. The query may be too complex or the system is under heavy load."
272+ )
273+
274+ # Extract formatted response
275+ formatted = result .get ("formatted_response" , {})
276+
277+ # Calculate execution time
278+ execution_time = int ((time .time () - start_time ) * 1000 )
279+
280+ # Build response
281+ response = QueryResponse (
282+ answer = formatted .get ("answer" , result .get ("final_answer" , "No answer generated." )),
283+ charts = [
284+ ChartSpec (** chart )
285+ for chart in formatted .get ("charts" , [])
286+ ],
287+ data_sources = [
288+ DataSource (** source )
289+ for source in formatted .get ("data_sources" , [])
290+ ],
291+ metadata = Metadata (
292+ models_used = result .get ("model_usage" , {}),
293+ execution_time_ms = execution_time ,
294+ plan = result .get ("plan" ),
295+ replans = result .get ("replans" , 0 )
296+ )
265297 )
266- )
267298
268- return response
299+ return response
269300
270301 except Exception as e :
271302 # Sanitize error message before sending to client
@@ -274,7 +305,7 @@ def run_workflow():
274305 # Log full error server-side
275306 api_logger .error ("query_failed" , extra = {
276307 "data" : {
277- "query" : request .query ,
308+ "query" : query_request .query ,
278309 "error_type" : type (e ).__name__ ,
279310 "error_message" : str (e )
280311 }
0 commit comments