forked from llm-d-incubation/llm-d-planner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
133 lines (111 loc) · 4.1 KB
/
database.py
File metadata and controls
133 lines (111 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Database management API routes.
Provides endpoints for uploading benchmark data, checking DB status,
and resetting the benchmark database. These endpoints enable remote
management of NeuralNav deployments (e.g., on Kubernetes) without
needing shell access.
"""
import json
import logging
import os
import psycopg2
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from planner.knowledge_base.loader import get_db_stats, insert_benchmarks, reset_benchmarks
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that carries the "/api/v1" prefix and the "database" tag for every
# route declared in this module.
router = APIRouter(tags=["database"], prefix="/api/v1")

# Connection string used by the DB management endpoints. Read once at import
# time from the DATABASE_URL environment variable, with a fallback default.
# NOTE(review): the fallback embeds a password; presumably meant for local
# development only - confirm deployments always set DATABASE_URL.
_DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://postgres:neuralnav@localhost:5432/neuralnav",
)
def _get_connection():
    """Open a new psycopg2 connection using the module's database URL.

    The connection is returned still open; closing it is left to the caller.
    """
    connection = psycopg2.connect(_DATABASE_URL)
    return connection
@router.get("/db/status")
async def db_status():
    """Get current benchmark database statistics.

    Opens a connection, asks ``get_db_stats`` for the current statistics and
    always closes the connection afterwards.

    Returns:
        A dict containing ``"success": True`` merged with the mapping
        returned by ``get_db_stats``.

    Raises:
        HTTPException: 503 if connecting, reading the statistics, or closing
            the connection raises any exception.
    """
    try:
        conn = _get_connection()
        try:
            stats = get_db_stats(conn)
            return {"success": True, **stats}
        finally:
            # Runs on both the success and the failure path.
            conn.close()
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Failed to get DB status: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"Database not accessible: {e}"
        ) from e
@router.post("/db/upload-benchmarks")
async def upload_benchmarks(file: UploadFile = File(...)):
    """Upload a benchmark JSON file and load it into the database.

    The JSON file should have a top-level "benchmarks" array containing
    benchmark records. Duplicates (same model/hardware/traffic config)
    are silently skipped.

    Usage:
        curl -X POST -F 'file=@benchmarks.json' http://host/api/v1/db/upload-benchmarks

    Returns:
        A dict with ``"success": True``, the uploaded ``"filename"``, the
        number of ``"records_in_file"``, merged with the mapping returned by
        ``insert_benchmarks``.

    Raises:
        HTTPException: 400 if the filename does not end in ``.json``, the
            payload is not decodable/valid JSON, or it has no non-empty
            top-level "benchmarks" array; 500 if loading into the database
            fails.
    """
    if not file.filename or not file.filename.endswith(".json"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="File must be a .json file"
        )

    try:
        content = await file.read()
        data = json.loads(content)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads decodes bytes itself; undecodable bytes raise
        # UnicodeDecodeError rather than JSONDecodeError, so catch both.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {e}"
        ) from e

    # Valid JSON may still be a list, string, number, ... at the top level;
    # only an object can carry the "benchmarks" key.
    benchmarks = data.get("benchmarks") if isinstance(data, dict) else None
    if not isinstance(benchmarks, list) or not benchmarks:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='No benchmarks found. JSON must have a top-level "benchmarks" array.',
        )

    try:
        conn = _get_connection()
        try:
            stats = insert_benchmarks(conn, benchmarks)
            logger.info(
                "Uploaded %d benchmarks from %s, DB now has %s total",
                len(benchmarks),
                file.filename,
                stats["total_benchmarks"],
            )
            return {
                "success": True,
                "filename": file.filename,
                "records_in_file": len(benchmarks),
                **stats,
            }
        finally:
            # Runs on both the success and the failure path.
            conn.close()
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Failed to load benchmarks: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to load benchmarks: {e}",
        ) from e
@router.post("/db/reset")
async def reset_database():
    """Reset the benchmark database by removing all benchmark data.

    This truncates the exported_summaries table (cascading to related tables).
    The schema is preserved — only data is removed.

    Usage:
        curl -X POST http://host/api/v1/db/reset

    Returns:
        A dict with ``"success": True`` and a confirmation ``"message"``,
        merged with the mapping returned by ``get_db_stats`` after the reset.

    Raises:
        HTTPException: 500 if connecting, resetting, reading the statistics,
            or closing the connection raises any exception.
    """
    # NOTE(review): no access control is visible in this module for this
    # destructive endpoint - confirm it is enforced where the router is mounted.
    try:
        conn = _get_connection()
        try:
            reset_benchmarks(conn)
            # Statistics are read after the reset so the response reflects
            # the post-reset state.
            stats = get_db_stats(conn)
            logger.info("Benchmark database reset via API")
            return {
                "success": True,
                "message": "Benchmark database has been reset",
                **stats,
            }
        finally:
            # Runs on both the success and the failure path.
            conn.close()
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Failed to reset database: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to reset database: {e}",
        ) from e