Skip to content

Commit 44213c3

Browse files
feat(optimisation): implement phase 2 triage_accuracy metric with LLM judge
- Extract all metrics into optimisation/metrics.py to keep main script clean - Add triage_accuracy metric: flag accuracy (0.6) + reply quality judge (0.4) - Add _reply_quality helper using claude-haiku for cost-efficient judging - Strip markdown code fences from judge response before JSON parsing - Update _scout_reasoning_override to include reply quality context - Switch main() from flag_only_metric to triage_accuracy for phase 2 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d6f4696 commit 44213c3

2 files changed

Lines changed: 197 additions & 87 deletions

File tree

optimisation/metrics.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Metric functions for Scout prompt optimisation.
2+
3+
All metrics share the same invoke_agent output contract:
4+
llm_output = JSON string {"comment": str, "escalated": bool}
5+
6+
Phases:
7+
Phase 1 — flag_only_metric (escalation flag only, no LLM judge)
8+
Phase 2 — triage_accuracy (flag 0.6 + reply quality judge 0.4)
9+
"""
10+
from __future__ import annotations
11+
12+
import json
13+
import os
14+
import sys
15+
16+
_repo_root = os.path.join(os.path.dirname(__file__), "..")
17+
sys.path.insert(0, _repo_root)
18+
sys.path.insert(0, os.path.join(_repo_root, "src"))
19+
20+
import litellm # noqa: E402
21+
from opik.evaluation.metrics import AnswerRelevance # noqa: E402
22+
from opik.evaluation.metrics.score_result import ScoreResult # noqa: E402
23+
24+
from scout.triage import OPIK_PROJECT, SCOUT_ESCALATION_TAG # noqa: E402
25+
26+
JUDGE_MODEL = "anthropic/claude-haiku-4-5-20251001"
27+
28+
JUDGE_PROMPT = """\
29+
You are evaluating a GitHub issue reply written by Scout, an AI triage agent.
30+
31+
Scout's role is to:
32+
- Help the issue author understand whether their issue is a bug or not
33+
- Clarify the code behaviour where possible
34+
- NOT suggest fixes or code changes
35+
- Ask for reproduction steps if a bug is suspected
36+
- Link to relevant documentation when helpful
37+
- Always introduce itself as Scout with a friendly tone
38+
- Escalate only when the issue requires a major design decision, breaking API change,
39+
or architectural discussion needing maintainer consensus
40+
41+
---
42+
43+
GitHub Issue:
44+
{issue}
45+
46+
Scout's Reply:
47+
{reply}
48+
49+
Escalation Applied: {escalated}
50+
51+
---
52+
53+
Score the reply from 0.0 to 1.0:
54+
- 0.0: fails a hard rule (no Scout intro, suggests fixes, wrong tone, escalation contradicts reply)
55+
- 0.5: meets hard rules but vague — missing repro steps or docs when clearly needed
56+
- 0.75: solid reply with minor gaps
57+
- 1.0: excellent — clear, friendly, on-scope, correctly escalated, repro steps/docs where appropriate
58+
59+
Return JSON only: {{"score": float, "reason": "one sentence"}}
60+
"""
61+
62+
_answer_relevance_metric = AnswerRelevance(
63+
model=JUDGE_MODEL,
64+
project_name=OPIK_PROJECT,
65+
require_context=False,
66+
)
67+
68+
69+
# ---------------------------------------------------------------------------
70+
# Shared helpers
71+
# ---------------------------------------------------------------------------
72+
73+
def _parse_output(llm_output: str) -> tuple[str, bool]:
74+
"""Parse invoke_agent JSON output into (comment, escalated).
75+
76+
Falls back to plain string + tag-in-text detection if JSON is malformed.
77+
"""
78+
try:
79+
parsed = json.loads(llm_output)
80+
return parsed["comment"], bool(parsed["escalated"])
81+
except (json.JSONDecodeError, KeyError):
82+
return llm_output, SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
83+
84+
85+
def _expected_escalation(dataset_item: dict) -> bool | None:
86+
"""Return the expected escalation bool, or None if not present in the item."""
87+
data = dataset_item.get("data", dataset_item)
88+
expected = data.get("expected", {})
89+
val = expected.get("should_escalate")
90+
return bool(val) if val is not None else None
91+
92+
93+
# ---------------------------------------------------------------------------
94+
# Phase 1 — flag accuracy only
95+
# ---------------------------------------------------------------------------
96+
97+
def flag_only_metric(dataset_item: dict, llm_output: str) -> ScoreResult:
98+
"""Escalation flag correctness only. No LLM judge call."""
99+
should_escalate = _expected_escalation(dataset_item)
100+
if should_escalate is None:
101+
return ScoreResult(name="flag_accuracy", value=1.0, reason="No expected flag — skipped.")
102+
103+
_, output_escalated = _parse_output(llm_output)
104+
correct = output_escalated == should_escalate
105+
return ScoreResult(
106+
name="flag_accuracy",
107+
value=1.0 if correct else 0.0,
108+
reason="Flag correct." if correct else f"Flag wrong — expected escalate={should_escalate}.",
109+
)
110+
111+
112+
# ---------------------------------------------------------------------------
113+
# Phase 2 — triage accuracy (flag + reply quality)
114+
# ---------------------------------------------------------------------------
115+
116+
def _reply_quality(issue: str, reply: str, escalated: bool) -> ScoreResult:
117+
"""LLM-as-judge for Scout's reply. Uses JUDGE_MODEL (Haiku) to keep costs low."""
118+
prompt = JUDGE_PROMPT.format(issue=issue, reply=reply, escalated=escalated)
119+
response = litellm.completion(
120+
model=JUDGE_MODEL,
121+
messages=[{"role": "user", "content": prompt}],
122+
)
123+
content = response.choices[0].message.content or ""
124+
content = content.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
125+
result = json.loads(content)
126+
return ScoreResult(
127+
name="reply_quality",
128+
value=float(result["score"]),
129+
reason=result["reason"],
130+
)
131+
132+
133+
def triage_accuracy(dataset_item: dict, llm_output: str) -> ScoreResult:
134+
"""Phase 2 — flag accuracy (0.6) + reply quality judge (0.4)."""
135+
comment, output_escalated = _parse_output(llm_output)
136+
should_escalate = _expected_escalation(dataset_item)
137+
138+
flag_score = 1.0
139+
flag_reason = "No expected flag."
140+
if should_escalate is not None:
141+
flag_score = 1.0 if output_escalated == should_escalate else 0.0
142+
flag_reason = "Flag correct." if flag_score == 1.0 else f"Flag wrong — expected escalate={should_escalate}."
143+
144+
issue = dataset_item.get("issue_message", "")
145+
reply_result = _reply_quality(issue, comment, output_escalated)
146+
147+
combined = (flag_score * 0.6) + (reply_result.value * 0.4)
148+
return ScoreResult(
149+
name="triage_accuracy",
150+
value=combined,
151+
reason=f"{flag_reason} Reply: {reply_result.reason}",
152+
)
153+
154+
155+
# ---------------------------------------------------------------------------
156+
# Legacy metrics (kept for reference and future phases)
157+
# ---------------------------------------------------------------------------
158+
159+
def escalation_accuracy(dataset_item: dict, llm_output: str) -> float:
160+
"""Score 1.0 if escalation decision matches expected, 0.0 otherwise."""
161+
should_escalate = _expected_escalation(dataset_item)
162+
if should_escalate is None:
163+
return 1.0
164+
_, output_escalated = _parse_output(llm_output)
165+
return 1.0 if output_escalated == should_escalate else 0.0
166+
167+
168+
def answer_relevance(dataset_item: dict, llm_output: str) -> float:
169+
"""AnswerRelevance score for the reply comment."""
170+
comment, _ = _parse_output(llm_output)
171+
result = _answer_relevance_metric.score(
172+
input=dataset_item["issue_message"],
173+
output=comment,
174+
)
175+
return result.value
176+
177+
178+
def scout_quality(dataset_item: dict, llm_output: str) -> float:
179+
"""Combined metric: structural completeness (50%) + escalation accuracy (50%)."""
180+
comment, output_escalated = _parse_output(llm_output)
181+
182+
required_sections = ["## Solution", "## Code Investigation", "## Next Steps"]
183+
structure_score = sum(s in comment for s in required_sections) / len(required_sections)
184+
185+
should_escalate = _expected_escalation(dataset_item)
186+
escalation_score = 1.0 if should_escalate is None else (
187+
1.0 if output_escalated == should_escalate else 0.0
188+
)
189+
190+
return 0.5 * structure_score + 0.5 * escalation_score

optimisation/run_prompt_optimisation.py

Lines changed: 7 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
sys.path.insert(0, os.path.join(_repo_root, "src"))
3030

3131
import opik # noqa: E402
32-
from opik.evaluation.metrics import AnswerRelevance # noqa: E402
33-
from opik.evaluation.metrics.score_result import ScoreResult # noqa: E402
3432
from dotenv import load_dotenv # noqa: E402
3533
from opik_optimizer import ChatPrompt, MetaPromptOptimizer # noqa: E402
3634
from opik_optimizer.agents.optimizable_agent import OptimizableAgent # noqa: E402
@@ -47,6 +45,7 @@
4745
OPIK_PROJECT,
4846
SCOUT_ESCALATION_TAG,
4947
)
48+
from metrics import triage_accuracy # noqa: E402
5049

5150
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
5251
logger = logging.getLogger(__name__)
@@ -101,90 +100,6 @@ def invoke_agent(
101100
return json.dumps({"comment": comment or "", "escalated": escalated})
102101

103102

104-
def _parse_output(llm_output: str) -> tuple[str, bool]:
105-
"""Parse invoke_agent output into (comment, escalated).
106-
107-
invoke_agent returns JSON with {"comment": str, "escalated": bool}.
108-
Falls back to plain string + tag-in-text detection for safety.
109-
"""
110-
try:
111-
parsed = json.loads(llm_output)
112-
return parsed["comment"], bool(parsed["escalated"])
113-
except (json.JSONDecodeError, KeyError):
114-
return llm_output, SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
115-
116-
117-
def escalation_accuracy(dataset_item: dict, llm_output: str) -> float:
118-
"""Score 1.0 if escalation decision matches expected, 0.0 otherwise.
119-
120-
Items without an expected.should_escalate field score 1.0 so they don't
121-
dilute the signal.
122-
"""
123-
data = dataset_item.get("data", dataset_item)
124-
expected = data.get("expected", {})
125-
126-
if "should_escalate" not in expected:
127-
return 1.0
128-
129-
_, output_escalated = _parse_output(llm_output)
130-
return 1.0 if output_escalated == expected["should_escalate"] else 0.0
131-
132-
_answer_relevance_metric = AnswerRelevance(
133-
model="anthropic/claude-haiku-4-5-20251001",
134-
project_name=OPIK_PROJECT,
135-
require_context=False,
136-
)
137-
138-
139-
def answer_relevance(dataset_item: dict, llm_output: str) -> float:
140-
comment, _ = _parse_output(llm_output)
141-
result = _answer_relevance_metric.score(
142-
input=dataset_item["issue_message"],
143-
output=comment,
144-
)
145-
return result.value
146-
147-
148-
def scout_quality(dataset_item: dict, llm_output: str) -> float:
149-
"""Combined metric: structural completeness (50%) + escalation accuracy (50%)."""
150-
comment, output_escalated = _parse_output(llm_output)
151-
152-
required_sections = ["## Solution", "## Code Investigation", "## Next Steps"]
153-
structure_score = sum(s in comment for s in required_sections) / len(required_sections)
154-
155-
data = dataset_item.get("data", dataset_item)
156-
expected = data.get("expected", {})
157-
if "should_escalate" in expected:
158-
escalation_score = 1.0 if output_escalated == expected["should_escalate"] else 0.0
159-
else:
160-
escalation_score = 1.0
161-
162-
return 0.5 * structure_score + 0.5 * escalation_score
163-
164-
165-
def flag_only_metric(dataset_item: dict, llm_output: str) -> ScoreResult:
166-
"""Phase 1 — escalation flag correctness only. No LLM judge call.
167-
168-
Reads escalation state from the simulator label (via JSON output from
169-
invoke_agent) — the ground truth for whether apply_label() was called.
170-
"""
171-
data = dataset_item.get("data", dataset_item)
172-
expected = data.get("expected", {})
173-
174-
if "should_escalate" not in expected:
175-
return ScoreResult(name="flag_accuracy", value=1.0, reason="No expected flag — skipped.")
176-
177-
should_escalate: bool = expected["should_escalate"]
178-
_, output_escalated = _parse_output(llm_output)
179-
correct = output_escalated == should_escalate
180-
181-
return ScoreResult(
182-
name="flag_accuracy",
183-
value=1.0 if correct else 0.0,
184-
reason="Flag correct." if correct else f"Flag wrong — expected escalate={should_escalate}.",
185-
)
186-
187-
188103
def _scout_reasoning_override(prompts: PromptLibrary) -> None:
189104
"""Inject Scout-specific task context into the meta-LLM's reasoning prompt.
190105
@@ -206,6 +121,11 @@ def _scout_reasoning_override(prompts: PromptLibrary) -> None:
206121
No escalation means: bugs, feature requests, duplicate reports, spam — things Scout
207122
can investigate and respond to directly.
208123
124+
The metric scoring Scout evaluates both escalation accuracy (60%) and reply quality (40%).
125+
A high-quality reply: introduces Scout by name, uses a friendly tone, does NOT suggest
126+
code fixes, asks for repro steps when a bug is suspected, and is consistent with the
127+
escalation decision.
128+
209129
Scout has access to tools that explore the repository codebase and search existing issues.
210130
It does NOT use template variables in its prompt — do not add placeholders like {data} or
211131
{issue_message} to the prompt you generate.
@@ -242,7 +162,7 @@ def main() -> None:
242162
result = optimizer.optimize_prompt(
243163
prompt=initial_prompt,
244164
dataset=dataset,
245-
metric=flag_only_metric,
165+
metric=triage_accuracy,
246166
agent=ScoutAgent(project_name=OPIK_PROJECT),
247167
n_samples=10,
248168
project_name=OPIK_PROJECT,

0 commit comments

Comments
 (0)