Skip to content

Commit f9b6094

Browse files
Ready for deployment
1 parent 689aed0 commit f9b6094

16 files changed

Lines changed: 626 additions & 48 deletions

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
__pycache__/
2+
*.pyc
3+
*.db
4+
.env
5+
.DS_Store
6+
.pytest_cache/
7+
reset_db.py
File renamed without changes.

app/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# app/config.py
2+
import os
3+
4+
# Drift score above which an alert is raised; defaults to 0.15 (15%).
ALERT_THRESHOLD = float(os.environ.get("ALERT_THRESHOLD", 0.15))

# Discord webhook endpoint; an empty string when the variable is unset.
DISCORD_WEBHOOK = os.environ.get("DISCORD_WEBHOOK", "")

# Database path, overridable through the DB_PATH environment variable.
DB_PATH = os.environ.get("DB_PATH", "predictions.db")

# The pickled model is looked up in the same directory as this module.
_PACKAGE_DIR = os.path.dirname(__file__)
MODEL_PATH = os.path.join(_PACKAGE_DIR, "model.pkl")

app/discord_alert.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import os
22
import requests
3+
import logging
34

45
# The webhook URL embeds a secret token, so it is read from the environment
# only and never hard-coded as a fallback. With DISCORD_WEBHOOK unset this is
# an empty string, which send_alert() treats as "alerts disabled".
# NOTE: any webhook URL that was previously committed here should be revoked.
WEBHOOK_URL = os.getenv("DISCORD_WEBHOOK", "")
56

67
def send_alert(message, timeout=10):
    """Post ``message`` to the configured Discord webhook.

    Args:
        message: Text placed in the ``content`` field of the webhook payload.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot block the caller indefinitely.

    Returns:
        True when the request completed without an HTTP error status; False
        when no webhook is configured or the request failed.
    """
    if not WEBHOOK_URL:
        logging.warning("DISCORD_WEBHOOK not set, skipping alert.")
        return False

    try:
        response = requests.post(
            WEBHOOK_URL, json={"content": message}, timeout=timeout
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        # Alerting is best-effort: report the failure instead of raising.
        logging.error("Failed to send Discord alert: %s", exc)
        return False
    return True
18+
919

app/drift_detector.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,92 @@
11
import numpy as np
2-
from scipy.stats import ks_2samp
2+
from scipy.stats import ks_2samp, chi2_contingency
3+
import logging
34

4-
def compute_drift_score(baseline, recent, method='psi'):
    """Return a drift score between 0 and 1 (higher means more drift).

    Args:
        baseline: Reference sample of predictions.
        recent: Sample of predictions to compare against ``baseline``.
        method: Which statistic to use:
            - 'psi': Population Stability Index (categorical predictions)
            - 'chi2': Chi-square test (categorical predictions)
            - 'ks': Kolmogorov-Smirnov test (continuous values)
            Any other value falls back to 'ks' and logs a warning.

    Returns:
        A Python float in [0, 1]; 0.0 when either sample is empty.
    """
    if len(baseline) == 0 or len(recent) == 0:
        return 0.0

    if method == 'psi':
        return float(compute_psi(baseline, recent))
    if method == 'chi2':
        return float(compute_chi2_drift(baseline, recent))

    if method != 'ks':
        # Keep the historical fallback, but make a mistyped method visible.
        logging.warning("Unknown drift method %r, falling back to 'ks'", method)

    _, p_value = ks_2samp(baseline, recent)
    # A low p-value means the samples differ, so its complement is the score.
    # Clamp and coerce so callers always receive a plain float in [0, 1].
    return float(min(max(1 - p_value, 0.0), 1.0))
27+
28+
def compute_psi(baseline, recent, bins=3):
    """Population Stability Index between two samples, scaled to 0-1.

    Conventional reading of the raw PSI:
        PSI < 0.1:   no significant drift
        PSI 0.1-0.2: moderate drift
        PSI > 0.2:   significant drift

    The raw PSI is divided by 0.2 and capped at 1.0, so a returned value of
    1.0 means "significant drift or worse".

    Args:
        baseline: Reference sample of class labels.
        recent: Sample of class labels to compare against ``baseline``.
        bins: Not used by the computation; kept so existing callers that
            pass it continue to work.

    Returns:
        A float in [0, 1]. Returns 0.0 when either sample is empty. When the
        two samples contain more than three distinct labels the function
        falls back to ``1 - p_value`` of a two-sample KS test.
    """
    if len(baseline) == 0 or len(recent) == 0:
        return 0.0

    base = np.asarray(baseline)
    rec = np.asarray(recent)
    # Sorted distinct labels seen in either sample; these are the PSI bins.
    labels = np.union1d(base, rec)

    if len(labels) > 3:
        _, p_value = ks_2samp(base, rec)
        return float(1 - p_value)

    # Count per observed label rather than per integer index, so labels that
    # are not exactly 0..2 (negative, sparse, or large) are handled too.
    baseline_pct = np.array([np.mean(base == label) for label in labels])
    recent_pct = np.array([np.mean(rec == label) for label in labels])

    # Keep proportions away from 0 and 1 so the log ratio stays finite.
    baseline_pct = np.clip(baseline_pct, 0.001, 0.999)
    recent_pct = np.clip(recent_pct, 0.001, 0.999)

    psi = float(np.sum((recent_pct - baseline_pct) * np.log(recent_pct / baseline_pct)))
    normalized_psi = min(psi / 0.2, 1.0)

    logging.info("PSI: %.3f, Normalized: %.3f", psi, normalized_psi)
    return normalized_psi
66+
67+
def compute_chi2_drift(baseline, recent):
    """Chi-square test for categorical drift detection.

    Builds a 2 x K contingency table (one row per sample, one column per
    distinct label seen in either sample) and runs a chi-square test on it.

    Args:
        baseline: Reference sample of class labels.
        recent: Sample of class labels to compare against ``baseline``.

    Returns:
        ``1 - p_value`` as a float, so a lower p-value yields a higher drift
        score. Returns 0.0 when either sample is empty, because the test is
        undefined for a row of all-zero counts.
    """
    if len(baseline) == 0 or len(recent) == 0:
        return 0.0

    base = np.asarray(baseline)
    rec = np.asarray(recent)
    # Every label occurs in at least one sample, so no column is all zeros.
    labels = np.union1d(base, rec)

    contingency = np.array([
        [np.count_nonzero(base == label) for label in labels],
        [np.count_nonzero(rec == label) for label in labels],
    ])

    chi2, p_value, _, _ = chi2_contingency(contingency)
    drift_score = float(1 - p_value)

    logging.info(
        "Chi2: %.3f, p-value: %.3f, Drift score: %.3f", chi2, p_value, drift_score
    )
    return drift_score
1789

1890
def should_alert(score, threshold=0.15):
    """Tell whether a drift score is strictly above the alert threshold.

    A score exactly equal to ``threshold`` does not trigger an alert.
    """
    exceeds_threshold = score > threshold
    return exceeds_threshold

app/logger.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ def init_db():
88
conn=sqlite3.connect(DB_FILE)
99
c=conn.cursor()
1010
c.execute('''CREATE TABLE IF NOT EXISTS predictions
11-
id INTEGER PRIMARY KEY AUTOINCREMENT,
11+
(id INTEGER PRIMARY KEY AUTOINCREMENT,
1212
features TEXT,
13-
prediciton INTEGER,
14-
timestamp TEXT''')
13+
prediction INTEGER,
14+
timestamp TEXT)''')
1515
conn.commit()
1616
conn.close()
1717

1818
def log_prediction(features, pred, timestamp):
    """Insert one prediction row into the ``predictions`` table.

    Args:
        features: Input features; stored as a JSON string.
        pred: Predicted value stored in the ``prediction`` column.
        timestamp: Object with ``isoformat()`` (e.g. a datetime); stored as
            its ISO-8601 string.
    """
    # Serialise before opening the database so bad input never opens a
    # connection in the first place.
    row = (json.dumps(features), pred, timestamp.isoformat())

    conn = sqlite3.connect(DB_FILE)
    try:
        # The connection context manager commits on success and rolls back
        # if the INSERT raises; it does not close the connection.
        with conn:
            conn.execute(
                "INSERT INTO predictions (features, prediction, timestamp) VALUES (?,?,?)",
                row,
            )
    finally:
        conn.close()
2424

@@ -31,7 +31,7 @@ def get_recent_predictions(limit=100):
3131
return [row[0] for row in rows]
3232

3333

34-
def initial_get_predictions(limit=100):
34+
def get_initial_predictions(limit=100):
3535
# For demo, using first 100 predictions as baseline
3636
conn=sqlite3.connect(DB_FILE)
3737
c=conn.cursor()
@@ -40,6 +40,5 @@ def initial_get_predictions(limit=100):
4040
conn.close()
4141
return [row[0] for row in rows]
4242

43-
init_db() # ensures table exists
4443

4544

app/main.py

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,126 @@
11
from fastapi import FastAPI, BackgroundTasks, HTTPException
22
from pydantic import BaseModel
3-
import numpy as np
43
import pickle
4+
import numpy as np
5+
import os
56
from datetime import datetime
6-
from logger import log_prediction, get_recent_predictions
7-
from drift_detector import compute_drift_score, should_alert
8-
from discord_alert import send_alert
9-
import asyncio
7+
from typing import List, Optional
8+
import json
9+
10+
# Import your modules
11+
from app.logger import log_prediction, get_recent_predictions, get_initial_predictions, init_db
12+
from app.drift_detector import compute_drift_score, should_alert
13+
from app.discord_alert import send_alert
14+
from app import config
15+
16+
# Initialize database on startup
init_db()

app = FastAPI(title="MLOps Drift Detection API")

# Load the model from the single location defined in app.config, so this
# module and the configuration cannot disagree about where the model lives.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load model files that are produced by this project.
model_path = config.MODEL_PATH
try:
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
except FileNotFoundError:
    # Start without a model; the endpoints check for ``model is None``.
    print(f"Warning: Model file not found at {model_path}")
    model = None
29+
30+
class InputData(BaseModel):
    # Request body for /predict: the feature values of a single sample.
    features: List[float]
32+
33+
class PredictionResponse(BaseModel):
    # Response body of /predict.
    prediction: int  # predicted value, converted to int by the endpoint
    timestamp: str   # ISO-8601 string of when the prediction was made
36+
37+
class DriftResponse(BaseModel):
    # Response body of /drift_score.
    drift_score: float     # computed drift score (0.0 when too few samples)
    alert: bool            # whether the score exceeded the threshold
    threshold: float       # alert threshold that was applied
    recent_samples: int    # number of recent predictions compared
    baseline_samples: int  # number of baseline predictions compared
43+
44+
@app.get("/")
async def root():
    # Landing endpoint: reports the service name and whether a model is loaded.
    model_loaded = model is not None
    return {
        "message": "MLOps Drift Detection API",
        "status": "operational",
        "model_loaded": model_loaded,
    }
51+
52+
@app.get("/healthz")
async def healthz():
    """Health check endpoint for Render"""
    # Current local time in ISO-8601 form, taken at request time.
    checked_at = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": checked_at}
56+
57+
@app.post("/predict", response_model=PredictionResponse)
async def predict(data: InputData, background_tasks: BackgroundTasks):
    """Make a prediction and log it for drift detection"""
    # Without a loaded model the service cannot answer: 503.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Shape the flat feature list into a single-row 2-D array.
    features = np.array(data.features).reshape(1, -1)

    try:
        raw_prediction = model.predict(features)[0]
    except ValueError as exc:
        # The model rejected the input (for example, an unexpected number of
        # features). Report it as a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail=f"Invalid features: {exc}"
        ) from exc

    pred = int(raw_prediction)
    timestamp = datetime.now()

    # Queue the log write as a background task.
    background_tasks.add_task(log_prediction, data.features, pred, timestamp)

    return {
        "prediction": pred,
        "timestamp": timestamp.isoformat()
    }
77+
78+
@app.get("/drift_score", response_model=DriftResponse)
async def get_drift(background_tasks: BackgroundTasks):
    """Calculate current drift score"""
    # Recent window (up to 100) compared against the baseline (up to 100).
    recent = get_recent_predictions(limit=100)
    baseline = get_initial_predictions(limit=100)

    # Defaults describe the "not enough data yet" outcome.
    drift_score = 0.0
    alert = False

    has_enough_samples = len(recent) >= 10 and len(baseline) >= 10
    if has_enough_samples:
        score = compute_drift_score(baseline, recent, method='psi')
        alert = should_alert(score, config.ALERT_THRESHOLD)

        # Notify in the background so the request is not held up.
        if alert:
            background_tasks.add_task(send_alert, f"🚨 Drift detected! Score: {score:.3f}")

        drift_score = round(score, 3)

    return {
        "drift_score": drift_score,
        "alert": alert,
        "threshold": config.ALERT_THRESHOLD,
        "recent_samples": len(recent),
        "baseline_samples": len(baseline)
    }
10111

112+
@app.get("/stats")
async def get_stats():
    """Get prediction statistics"""
    recent = get_recent_predictions(limit=1000)
    baseline = get_initial_predictions(limit=100)

    # Per-class counts for classes 0, 1 and 2; empty when nothing is logged.
    if recent:
        distribution = {str(label): recent.count(label) for label in (0, 1, 2)}
    else:
        distribution = {}

    return {
        "total_predictions": len(recent),
        "baseline_size": len(baseline),
        "recent_distribution": distribution
    }

app/model.pkl

169 KB
Binary file not shown.

app/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def train_and_save():
77
X,y = iris.data, iris.target
88
model=RandomForestClassifier()
99
model.fit(X,y)
10-
with open('model.pkl', 'wb') as f:
10+
with open('app/model.pkl', 'wb') as f:
1111
pickle.dump(model, f)
1212
# Also store the training data distribution for drift baseline
1313
return y

0 commit comments

Comments
 (0)