forked from llm-d-incubation/llm-d-planner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreference_data.py
More file actions
173 lines (139 loc) · 6.64 KB
/
reference_data.py
File metadata and controls
173 lines (139 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Reference data endpoints (models, GPU types, benchmarks, etc.)."""
import csv
import json
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status
from planner.api.dependencies import get_model_catalog, get_slo_repo
from planner.knowledge_base.model_catalog import ModelCatalog
from planner.knowledge_base.slo_templates import SLOTemplateRepository
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that every endpoint in this file registers on; routes are served under /api/v1.
router = APIRouter(prefix="/api/v1", tags=["reference-data"])
def _get_data_path() -> Path:
"""Get the base data directory path."""
return Path(__file__).parent.parent.parent.parent.parent / "data"
@router.get("/models")
async def list_models(model_catalog: ModelCatalog = Depends(get_model_catalog)):
    """Return every model in the catalog together with the total count.

    Any failure while reading or serializing the catalog is logged and
    surfaced as an HTTP 500 whose detail is the error text.
    """
    try:
        catalog_models = model_catalog.get_all_models()
        serialized = [entry.to_dict() for entry in catalog_models]
        payload = {"models": serialized, "count": len(catalog_models)}
    except Exception as exc:
        logger.error(f"Failed to list models: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
    return payload
@router.get("/gpu-types")
async def list_gpu_types(model_catalog: ModelCatalog = Depends(get_model_catalog)):
    """Return every GPU type known to the catalog together with the total count.

    Any failure while reading or serializing the catalog is logged and
    surfaced as an HTTP 500 whose detail is the error text.
    """
    try:
        catalog_gpus = model_catalog.get_all_gpu_types()
        serialized = [entry.to_dict() for entry in catalog_gpus]
        payload = {"gpu_types": serialized, "count": len(catalog_gpus)}
    except Exception as exc:
        logger.error(f"Failed to list GPU types: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
    return payload
@router.get("/use-cases")
async def list_use_cases(slo_repo: SLOTemplateRepository = Depends(get_slo_repo)):
    """Return the SLO template of each supported use case, keyed by use case name.

    Any failure while reading or serializing the templates is logged and
    surfaced as an HTTP 500 whose detail is the error text.
    """
    try:
        slo_templates = slo_repo.get_all_templates()
        serialized = {}
        for name, slo_template in slo_templates.items():
            serialized[name] = slo_template.to_dict()
        payload = {"use_cases": serialized, "count": len(slo_templates)}
    except Exception as exc:
        logger.error(f"Failed to list use cases: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
        ) from exc
    return payload
@router.get("/benchmarks")
async def get_benchmarks():
    """Get benchmark data for all models from opensource_all_benchmarks.csv.

    Rows whose "Model Name" column is missing or blank are skipped.

    Returns:
        A dict with ``success``, ``count`` and ``benchmarks`` (one dict per CSV row).

    Raises:
        HTTPException: 404 when the CSV file does not exist, 500 on any
            other failure while reading or parsing it.
    """
    try:
        csv_path = _get_data_path() / "benchmarks" / "accuracy" / "opensource_all_benchmarks.csv"
        if not csv_path.exists():
            logger.error(f"Benchmark CSV not found at: {csv_path}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Benchmark data file not found"
            )

        # newline="" lets the csv module do its own newline handling, so line
        # breaks embedded inside quoted fields are parsed correctly.
        with open(csv_path, encoding="utf-8", newline="") as f:
            records = [
                row
                for row in csv.DictReader(f)
                # Keep only rows that carry a non-blank model name.
                if (row.get("Model Name") or "").strip()
            ]

        logger.info(f"Loaded {len(records)} benchmark records from CSV")
        return {"success": True, "count": len(records), "benchmarks": records}
    except HTTPException:
        # Pass the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to load benchmarks: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to load benchmarks: {str(e)}",
        ) from e
@router.get("/priority-weights")
async def get_priority_weights():
    """Get priority to weight mapping configuration.

    Returns the priority_weights.json data for UI to use
    when setting initial weights based on priority dropdowns.

    Returns:
        A dict with ``success`` and ``priority_weights`` (the parsed JSON document).

    Raises:
        HTTPException: 404 when the JSON file does not exist, 500 on any
            other failure while reading or parsing it.
    """
    try:
        json_path = _get_data_path() / "configuration" / "priority_weights.json"
        if not json_path.exists():
            logger.error(f"Priority weights config not found at: {json_path}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Priority weights configuration not found",
            )

        # Decode explicitly as UTF-8 rather than relying on the platform's
        # locale default, matching how the CSV loaders in this file open files.
        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)

        return {"success": True, "priority_weights": data}
    except HTTPException:
        # Pass the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        # exc_info keeps the traceback in the log, as the other file-loading endpoints do.
        logger.error(f"Failed to load priority weights: {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
@router.get("/weighted-scores/{use_case}")
async def get_weighted_scores(use_case: str):
    """Get use-case-specific weighted scores from CSV.

    Args:
        use_case: Use case identifier; must be one of the keys of the
            filename mapping below.

    Returns:
        A dict with ``success``, ``use_case``, ``count`` and ``scores``
        (one dict per CSV row).

    Raises:
        HTTPException: 400 for an unknown use case, 404 when the CSV file
            does not exist, 500 on any other failure while reading it.
    """
    try:
        # Map use case to CSV filename. Only these keys are accepted, so the
        # path segment taken from the URL never reaches the filesystem directly.
        use_case_to_file = {
            "chatbot_conversational": "opensource_chatbot_conversational.csv",
            "code_completion": "opensource_code_completion.csv",
            "code_generation_detailed": "opensource_code_generation_detailed.csv",
            "document_analysis_rag": "opensource_document_analysis_rag.csv",
            "summarization_short": "opensource_summarization_short.csv",
            "long_document_summarization": "opensource_long_document_summarization.csv",
            "translation": "opensource_translation.csv",
            "content_generation": "opensource_content_generation.csv",
            "research_legal_analysis": "opensource_research_legal_analysis.csv",
        }

        filename = use_case_to_file.get(use_case)
        if not filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid use case: {use_case}. Valid options: {list(use_case_to_file.keys())}",
            )

        csv_path = _get_data_path() / "benchmarks" / "accuracy" / "weighted_scores" / filename
        if not csv_path.exists():
            logger.error(f"Weighted scores CSV not found at: {csv_path}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Weighted scores file not found for use case: {use_case}",
            )

        # newline="" lets the csv module do its own newline handling, so line
        # breaks embedded inside quoted fields are parsed correctly.
        with open(csv_path, encoding="utf-8", newline="") as f:
            records = list(csv.DictReader(f))

        logger.info(f"Loaded {len(records)} weighted score records for use case: {use_case}")
        return {"success": True, "use_case": use_case, "count": len(records), "scores": records}
    except HTTPException:
        # Pass the deliberate 400/404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to load weighted scores: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to load weighted scores: {str(e)}",
        ) from e