11from fastapi import FastAPI , BackgroundTasks , HTTPException
22from pydantic import BaseModel
3- import numpy as np
43import pickle
4+ import numpy as np
5+ import os
56from datetime import datetime
6- from logger import log_prediction , get_recent_predictions
7- from drift_detector import compute_drift_score , should_alert
8- from discord_alert import send_alert
9- import asyncio
7+ from typing import List , Optional
8+ import json
9+
10+ # Import your modules
11+ from app .logger import log_prediction , get_recent_predictions , get_initial_predictions , init_db
12+ from app .drift_detector import compute_drift_score , should_alert
13+ from app .discord_alert import send_alert
14+ from app import config
15+
16+ # Initialize database on startup
17+ init_db ()
18+
19+ app = FastAPI (title = "MLOps Drift Detection API" )
20+
21+ # Load model
22+ model_path = os .path .join (os .path .dirname (__file__ ), "model.pkl" )
23+ if not os .path .exists (model_path ):
24+ print (f"Warning: Model file not found at { model_path } " )
25+ model = None
26+ else :
27+ with open (model_path , 'rb' ) as f :
28+ model = pickle .load (f )
29+
30+ class InputData (BaseModel ):
31+ features : List [float ]
32+
33+ class PredictionResponse (BaseModel ):
34+ prediction : int
35+ timestamp : str
36+
37+ class DriftResponse (BaseModel ):
38+ drift_score : float
39+ alert : bool
40+ threshold : float
41+ recent_samples : int
42+ baseline_samples : int
43+
44+ @app .get ("/" )
45+ async def root ():
46+ return {
47+ "message" : "MLOps Drift Detection API" ,
48+ "status" : "operational" ,
49+ "model_loaded" : model is not None
50+ }
51+
52+ @app .get ("/healthz" )
53+ async def healthz ():
54+ """Health check endpoint for Render"""
55+ return {"status" : "healthy" , "timestamp" : datetime .now ().isoformat ()}
56+
57+ @app .post ("/predict" , response_model = PredictionResponse )
58+ async def predict (data : InputData , background_tasks : BackgroundTasks ):
59+ """Make a prediction and log it for drift detection"""
60+ if model is None :
61+ raise HTTPException (status_code = 503 , detail = "Model not loaded" )
62+
63+ # Convert to numpy array and reshape
64+ features = np .array (data .features ).reshape (1 , - 1 )
65+
66+ # Make prediction
67+ pred = model .predict (features )[0 ]
68+
69+ # Log prediction asynchronously
70+ timestamp = datetime .now ()
71+ background_tasks .add_task (log_prediction , data .features , int (pred ), timestamp )
72+
73+ return {
74+ "prediction" : int (pred ),
75+ "timestamp" : timestamp .isoformat ()
76+ }
77+
78+ @app .get ("/drift_score" , response_model = DriftResponse )
79+ async def get_drift (background_tasks : BackgroundTasks ):
80+ """Calculate current drift score"""
81+ # Get recent predictions (last 100)
82+ recent = get_recent_predictions (limit = 100 )
83+
84+ # Get baseline predictions (first 100)
85+ baseline = get_initial_predictions (limit = 100 )
86+
87+ if len (recent ) < 10 or len (baseline ) < 10 :
88+ return {
89+ "drift_score" : 0.0 ,
90+ "alert" : False ,
91+ "threshold" : config .ALERT_THRESHOLD ,
92+ "recent_samples" : len (recent ),
93+ "baseline_samples" : len (baseline )
94+ }
95+
96+ # Compute drift score
97+ score = compute_drift_score (baseline , recent , method = 'psi' )
98+ alert = should_alert (score , config .ALERT_THRESHOLD )
99+
100+ # Send alert if needed
101+ if alert :
102+ background_tasks .add_task (send_alert , f"🚨 Drift detected! Score: { score :.3f} " )
103+
104+ return {
105+ "drift_score" : round (score , 3 ),
106+ "alert" : alert ,
107+ "threshold" : config .ALERT_THRESHOLD ,
108+ "recent_samples" : len (recent ),
109+ "baseline_samples" : len (baseline )
110+ }
10111
112+ @app .get ("/stats" )
113+ async def get_stats ():
114+ """Get prediction statistics"""
115+ recent = get_recent_predictions (limit = 1000 )
116+ baseline = get_initial_predictions (limit = 100 )
117+
118+ return {
119+ "total_predictions" : len (recent ),
120+ "baseline_size" : len (baseline ),
121+ "recent_distribution" : {
122+ "0" : recent .count (0 ) if recent else 0 ,
123+ "1" : recent .count (1 ) if recent else 0 ,
124+ "2" : recent .count (2 ) if recent else 0
125+ } if recent else {}
126+ }
0 commit comments