Skip to content

Commit 6a8a9d3

Browse files
committed
Prompt comparison: generate multiple @task functions per prompt variant
When create_eval_config receives multiple prompts, it generates literal @task functions (eval_1, eval_2, ...) with prompt_template() solvers. Inspect's cartesian product (tasks × models) runs in a single eval with a shared run_id — no agent involvement needed.
- Normalize {question} → {prompt} for prompt_template() compatibility
- Results are keyed by a task/model composite when a group contains multiple tasks
- Column order: same model side by side (P1, P2) for easy comparison
- Collapsible prompt legend in the frontend
- /rebuild endpoint uses force=True to bust the precomputed cache
1 parent ff4af56 commit 6a8a9d3

7 files changed

Lines changed: 208 additions & 57 deletions

File tree

backend/api/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ async def get_comparison_detail(group_id: str, user_id: str = Depends(_get_user_
160160
if data:
161161
return data
162162

163-
# Fallback: pre-compute this group (handles old evals without pre-computed JSON)
163+
# Fallback: compute this group on demand
164164
await precompute_eval_results(user_id)
165165
data = load_eval_detail(user_id, group_id)
166166
if data:
@@ -174,7 +174,7 @@ async def rebuild_results(user_id: str = Depends(_get_user_id)):
174174
175175
Use this once to migrate existing evals, or to fix corrupted data.
176176
"""
177-
await precompute_eval_results(user_id)
177+
await precompute_eval_results(user_id, force=True)
178178
data = load_eval_groups(user_id)
179179
count = len(data["groups"]) if data else 0
180180
return {"ok": True, "groups_rebuilt": count}

backend/core/eval_results.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ async def _read_full_logs(log_files: list[str]) -> list[dict]:
7474
log = await read_eval_log_async(f, header_only=False)
7575
entry: dict = {
7676
"file": f,
77+
"task": log.eval.task,
7778
"model": log.eval.model,
7879
"status": log.status,
7980
"samples": [],
@@ -154,14 +155,20 @@ def _build_groups_from_headers(headers: list[dict]) -> dict:
154155

155156
groups = []
156157
for run_id, run_logs in groups_map.items():
158+
distinct_tasks = list(dict.fromkeys(l.get("task", "") for l in run_logs))
159+
is_prompt_comparison = len(distinct_tasks) > 1
157160
models = list(dict.fromkeys(l["model"] for l in run_logs))
158-
scores_by_model = {}
161+
162+
scores_by_key = {}
159163
for l in run_logs:
160164
if l.get("scores"):
161165
metrics = {}
162166
for s in l["scores"]:
163167
metrics.update(s["metrics"])
164-
scores_by_model[l["model"]] = metrics
168+
if is_prompt_comparison:
169+
scores_by_key[f"{l.get('task', '')}/{l['model']}"] = metrics
170+
else:
171+
scores_by_key[l["model"]] = metrics
165172

166173
task_name = run_logs[0].get("task", "unknown")
167174
config_name = task_name.replace("eval_task", "").strip("_") or task_name
@@ -170,16 +177,21 @@ def _build_groups_from_headers(headers: list[dict]) -> dict:
170177
if status == "error" and run_logs[0].get("dataset_samples", 0) > 0:
171178
status = "completed"
172179

173-
groups.append({
180+
group = {
174181
"id": run_id,
175182
"task": task_name,
176183
"configName": config_name,
177184
"created": run_logs[0].get("created", ""),
178185
"models": models,
179186
"sampleCount": run_logs[0].get("dataset_samples", 0),
180187
"status": status,
181-
"scores": scores_by_model,
182-
})
188+
"scores": scores_by_key,
189+
}
190+
if is_prompt_comparison:
191+
group["promptComparison"] = True
192+
group["promptCount"] = len(distinct_tasks)
193+
194+
groups.append(group)
183195

184196
groups.sort(key=lambda g: g["created"], reverse=True)
185197
return {"groups": groups}
@@ -191,15 +203,32 @@ def _build_detail_from_logs(
191203
full_logs: list[dict],
192204
user_dir: Path,
193205
) -> dict:
194-
models = [l["model"] for l in full_logs]
206+
# Detect prompt comparison: multiple distinct task names in one group
207+
distinct_tasks = list(dict.fromkeys(l.get("task", "") for l in full_logs))
208+
is_prompt_comparison = len(distinct_tasks) > 1
209+
210+
if is_prompt_comparison:
211+
models = list(dict.fromkeys(f"{l.get('task', '')}/{l['model']}" for l in full_logs))
212+
# Sort by model first, then prompt number — same model side by side: P1 P2 P1 P2
213+
def _col_sort_key(k: str) -> tuple:
214+
task_part, model_part = k.split("/", 1)
215+
# Extract number from eval_N
216+
num = int(task_part.replace("eval_", "")) if task_part.startswith("eval_") else 0
217+
return (model_part, num)
218+
models.sort(key=_col_sort_key)
219+
else:
220+
models = [l["model"] for l in full_logs]
195221

196222
samples_by_id: dict[str, dict] = {}
197223
criteria_set: set[str] = set()
198224
pipeline_stages: list[dict] = []
199225
is_pipeline = False
200226

201227
for log in full_logs:
202-
model = log["model"]
228+
if is_prompt_comparison:
229+
column_key = f"{log.get('task', '')}/{log['model']}"
230+
else:
231+
column_key = log["model"]
203232
for sample in log.get("samples", []):
204233
sid = sample["id"]
205234
if sid not in samples_by_id:
@@ -255,7 +284,7 @@ def _build_detail_from_logs(
255284
for cr in criteria_results:
256285
criteria_set.add(cr["name"])
257286

258-
samples_by_id[sid]["results"][model] = score_data
287+
samples_by_id[sid]["results"][column_key] = score_data
259288

260289
if is_pipeline:
261290
first_sample = next(iter(samples_by_id.values()), None)
@@ -325,7 +354,11 @@ def _build_detail_from_logs(
325354

326355
stats: dict[str, dict] = {}
327356
for log_header in group_logs:
328-
model = log_header["model"]
357+
raw_model = log_header["model"]
358+
if is_prompt_comparison:
359+
stat_key = f"{log_header.get('task', '')}/{raw_model}"
360+
else:
361+
stat_key = raw_model
329362
started = log_header.get("started_at")
330363
completed = log_header.get("completed_at")
331364

@@ -341,14 +374,14 @@ def _build_detail_from_logs(
341374
sample_count = log_header.get("dataset_samples", 1) or 1
342375
avg_latency = latency_seconds / sample_count if latency_seconds else None
343376

344-
stats[model] = {
377+
stats[stat_key] = {
345378
"startedAt": started,
346379
"completedAt": completed,
347380
"totalSeconds": latency_seconds,
348381
"latencySeconds": round(avg_latency, 2) if avg_latency else None,
349382
}
350383

351-
full_log = next((fl for fl in full_logs if fl.get("model") == model), None)
384+
full_log = next((fl for fl in full_logs if fl.get("model") == raw_model and fl.get("task") == log_header.get("task")), None)
352385
agent_usage = full_log.get("agent_model_usage", {}) if full_log else {}
353386

354387
if agent_usage:
@@ -373,11 +406,11 @@ def _build_detail_from_logs(
373406
"total_tokens": tot,
374407
"cost": model_cost,
375408
}
376-
stats[model]["input_tokens"] = agent_input
377-
stats[model]["output_tokens"] = agent_output
378-
stats[model]["total_tokens"] = agent_tokens
379-
stats[model]["cost"] = agent_cost
380-
stats[model]["modelUsage"] = per_model
409+
stats[stat_key]["input_tokens"] = agent_input
410+
stats[stat_key]["output_tokens"] = agent_output
411+
stats[stat_key]["total_tokens"] = agent_tokens
412+
stats[stat_key]["cost"] = agent_cost
413+
stats[stat_key]["modelUsage"] = per_model
381414
elif log_header.get("model_usage"):
382415
total_input = 0
383416
total_output = 0
@@ -386,38 +419,37 @@ def _build_detail_from_logs(
386419
total_input += usage.get("input_tokens", 0)
387420
total_output += usage.get("output_tokens", 0)
388421
total_tokens += usage.get("total_tokens", 0)
389-
stats[model]["input_tokens"] = total_input
390-
stats[model]["output_tokens"] = total_output
391-
stats[model]["total_tokens"] = total_tokens
392-
stats[model]["cost"] = calculate_cost(model, total_input, total_output)
422+
stats[stat_key]["input_tokens"] = total_input
423+
stats[stat_key]["output_tokens"] = total_output
424+
stats[stat_key]["total_tokens"] = total_tokens
425+
stats[stat_key]["cost"] = calculate_cost(raw_model, total_input, total_output)
393426

394427
if latency_seconds and latency_seconds > 0:
395-
stats[model]["tokensPerSecond"] = round(agent_output / latency_seconds, 1)
428+
stats[stat_key]["tokensPerSecond"] = round(total_output / latency_seconds, 1)
396429

397430
criteria_descriptions = _load_criteria_descriptions(user_dir, task_name, criteria_set)
398431

399432
# For agent evals, replace model name with agent image name
400433
agent_image = None
434+
config_data = None
401435
configs_dir = user_dir / "configs"
402436
if configs_dir.exists():
403-
# Get config file name from task_file in log header
404437
task_file = group_logs[0].get("task_file") if group_logs else None
405438
if task_file:
406439
config_json = configs_dir / Path(task_file).with_suffix(".json").name
407440
else:
408441
config_json = configs_dir / f"{config_name}.json"
409442
if config_json.exists():
410443
try:
411-
data = json.loads(config_json.read_text())
412-
if data.get("agent_image"):
413-
agent_image = data["agent_image"]
444+
config_data = json.loads(config_json.read_text())
445+
if config_data.get("agent_image"):
446+
agent_image = config_data["agent_image"]
414447
except Exception:
415448
pass
416449

417450
display_models = models
418451
if agent_image and len(models) == 1:
419452
display_models = [f"agent/{agent_image}"]
420-
# Remap aggregate, stats, and sample results
421453
old_model = models[0]
422454
if old_model in aggregate:
423455
aggregate[f"agent/{agent_image}"] = aggregate.pop(old_model)
@@ -441,10 +473,12 @@ def _build_detail_from_logs(
441473
result["pipeline"] = pipeline_stages
442474
if agent_image:
443475
result["agentImage"] = agent_image
476+
if is_prompt_comparison and config_data and config_data.get("prompts"):
477+
result["prompts"] = config_data["prompts"]
444478
return result
445479

446480

447-
async def precompute_eval_results(user_id: str) -> None:
481+
async def precompute_eval_results(user_id: str, force: bool = False) -> None:
448482
"""Parse all .eval files for a user and save pre-computed JSON to S3/disk.
449483
450484
Called after an eval completes and by the migration script.
@@ -463,10 +497,10 @@ async def precompute_eval_results(user_id: str) -> None:
463497
for group in groups_response["groups"]:
464498
group_id = group["id"]
465499

466-
# Always re-compute details for running evals; cache completed ones
467-
existing = load_eval_detail(user_id, group_id)
468-
if existing and group.get("status") == "success":
469-
continue
500+
if not force:
501+
existing = load_eval_detail(user_id, group_id)
502+
if existing and group.get("status") == "success":
503+
continue
470504

471505
group_logs = [h for h in headers if (h.get("run_id") or h["file"]) == group_id]
472506
if not group_logs:

backend/mcp_servers/synthetic/server_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ async def create_eval_config(
259259
- Google: "google/gemini-2.5-pro" (requires GOOGLE_API_KEY)
260260
Use list_available_models() to discover available providers and models.
261261
judge: Name of judge from list_judges (REQUIRED - criteria adapted to QA pairs)
262-
prompts: Single prompt string OR list of prompts (default: "{{question}}")
262+
prompts: Single prompt string OR list of prompts for comparison. Use {question} or {prompt} as placeholder for the input text. (default: "{{question}}")
263263
configName: Name for this evaluation (default: "evaluation")
264264
description: Optional description of the evaluation
265265
judge_models: Optional list of Bedrock model IDs to use as judges

0 commit comments

Comments
 (0)