Skip to content

Commit d6f4696

Browse files
feat(optimisation): implement phase 1 flag-only metric with correct escalation detection
- Switch metric to flag_only_metric (ScoreResult) for phase 1
- invoke_agent returns JSON with comment + escalated bool
- Escalation detected via sim.get_issue_data labels post-run, not comment text
- Add _parse_output helper used by all metrics to parse invoke_agent output
- Set enable_context=False + prompt_overrides to prevent {data}/{issue_message} placeholders
- Remove auto-save of best prompt to Opik — user promotes manually

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d7da2a8 commit d6f4696

1 file changed

Lines changed: 63 additions & 31 deletions

File tree

optimisation/run_prompt_optimisation.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313
from __future__ import annotations
1414

15+
import json
1516
import logging
1617
import os
1718
import sys
@@ -33,6 +34,7 @@
3334
from dotenv import load_dotenv # noqa: E402
3435
from opik_optimizer import ChatPrompt, MetaPromptOptimizer # noqa: E402
3536
from opik_optimizer.agents.optimizable_agent import OptimizableAgent # noqa: E402
37+
from opik_optimizer.utils.prompt_library import PromptLibrary # noqa: E402
3638

3739
load_dotenv()
3840

@@ -69,8 +71,6 @@ def invoke_agent(
6971
) -> str:
7072
prompt = next(iter(prompts.values()))
7173

72-
# get_messages handles both system= and messages= formats and substitutes
73-
# {issue_message} from dataset_item automatically.
7474
messages = prompt.get_messages(dataset_item)
7575
system_prompt = next(
7676
(m["content"] for m in messages if m.get("role") == "system"), ""
@@ -97,7 +97,21 @@ def invoke_agent(
9797
model=MODEL,
9898
max_tokens=MAX_TOKENS,
9999
)
100-
return comment or ""
100+
escalated = SCOUT_ESCALATION_TAG in sim.get_issue_data(target).get("labels", [])
101+
return json.dumps({"comment": comment or "", "escalated": escalated})
102+
103+
104+
def _parse_output(llm_output: str) -> tuple[str, bool]:
105+
"""Parse invoke_agent output into (comment, escalated).
106+
107+
invoke_agent returns JSON with {"comment": str, "escalated": bool}.
108+
Falls back to plain string + tag-in-text detection for safety.
109+
"""
110+
try:
111+
parsed = json.loads(llm_output)
112+
return parsed["comment"], bool(parsed["escalated"])
113+
except (json.JSONDecodeError, KeyError):
114+
return llm_output, SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
101115

102116

103117
def escalation_accuracy(dataset_item: dict, llm_output: str) -> float:
@@ -112,10 +126,8 @@ def escalation_accuracy(dataset_item: dict, llm_output: str) -> float:
112126
if "should_escalate" not in expected:
113127
return 1.0
114128

115-
should_escalate: bool = expected["should_escalate"]
116-
output_escalated = SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
117-
118-
return 1.0 if output_escalated == should_escalate else 0.0
129+
_, output_escalated = _parse_output(llm_output)
130+
return 1.0 if output_escalated == expected["should_escalate"] else 0.0
119131

120132
_answer_relevance_metric = AnswerRelevance(
121133
model="anthropic/claude-haiku-4-5-20251001",
@@ -125,23 +137,25 @@ def escalation_accuracy(dataset_item: dict, llm_output: str) -> float:
125137

126138

127139
def answer_relevance(dataset_item: dict, llm_output: str) -> float:
    """Score the agent's comment against the issue text with AnswerRelevance.

    Only the comment part of the invoke_agent output is judged; the
    escalation flag is discarded.
    """
    reply_text, _escalated = _parse_output(llm_output)
    return _answer_relevance_metric.score(
        input=dataset_item["issue_message"],
        output=reply_text,
    ).value
133146

134147

135148
def scout_quality(dataset_item: dict, llm_output: str) -> float:
136149
"""Combined metric: structural completeness (50%) + escalation accuracy (50%)."""
150+
comment, output_escalated = _parse_output(llm_output)
151+
137152
required_sections = ["## Solution", "## Code Investigation", "## Next Steps"]
138-
structure_score = sum(s in llm_output for s in required_sections) / len(required_sections)
153+
structure_score = sum(s in comment for s in required_sections) / len(required_sections)
139154

140155
data = dataset_item.get("data", dataset_item)
141156
expected = data.get("expected", {})
142157
if "should_escalate" in expected:
143-
escalated = SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
144-
escalation_score = 1.0 if escalated == expected["should_escalate"] else 0.0
158+
escalation_score = 1.0 if output_escalated == expected["should_escalate"] else 0.0
145159
else:
146160
escalation_score = 1.0
147161

@@ -151,19 +165,17 @@ def scout_quality(dataset_item: dict, llm_output: str) -> float:
151165
def flag_only_metric(dataset_item: dict, llm_output: str) -> ScoreResult:
152166
"""Phase 1 — escalation flag correctness only. No LLM judge call.
153167
154-
Maps the actual dataset shape (data.expected.should_escalate) and plain-string
155-
agent output (escalation detected via SCOUT_ESCALATION_TAG) to the dev-plan
156-
flag_only_metric interface.
168+
Reads escalation state from the simulator label (via JSON output from
169+
invoke_agent) — the ground truth for whether apply_label() was called.
157170
"""
158171
data = dataset_item.get("data", dataset_item)
159172
expected = data.get("expected", {})
160173

161174
if "should_escalate" not in expected:
162-
# No ground truth for this item — treat as correct so it doesn't dilute signal.
163175
return ScoreResult(name="flag_accuracy", value=1.0, reason="No expected flag — skipped.")
164176

165177
should_escalate: bool = expected["should_escalate"]
166-
output_escalated = SCOUT_ESCALATION_TAG.lower() in llm_output.lower()
178+
_, output_escalated = _parse_output(llm_output)
167179
correct = output_escalated == should_escalate
168180

169181
return ScoreResult(
@@ -173,6 +185,34 @@ def flag_only_metric(dataset_item: dict, llm_output: str) -> ScoreResult:
173185
)
174186

175187

188+
def _scout_reasoning_override(prompts: PromptLibrary) -> None:
    """Append Scout task context to the "reasoning_system" prompt.

    Stands in for enable_context dataset sampling: the meta-LLM is told what
    the task is without dataset fields being advertised as template variables.
    The existing "reasoning_system" text is kept and the context is added
    after it.
    """
    scout_context = """

Task context: You are optimising the system prompt for Scout, a GitHub issue triage agent.
Scout receives a GitHub issue (title, body, author, labels) and must:
1. Decide whether to escalate (true) or not (false)
2. Write a reply to the issue author

Escalation means: the issue requires a major design decision, breaking API change, or
architectural discussion that needs maintainer consensus.
No escalation means: bugs, feature requests, duplicate reports, spam — things Scout
can investigate and respond to directly.

Scout has access to tools that explore the repository codebase and search existing issues.
It does NOT use template variables in its prompt — do not add placeholders like {data} or
{issue_message} to the prompt you generate.
"""
    base_reasoning = prompts.get("reasoning_system")
    prompts.set("reasoning_system", base_reasoning + scout_context)
214+
215+
176216
def main() -> None:
177217
opik_client = opik.Opik()
178218
dataset = opik_client.get_dataset(DATASET_NAME)
@@ -184,16 +224,15 @@ def main() -> None:
184224
if chat_prompt_obj is None:
185225
sys.exit(f"ERROR: Opik chat prompt {PROMPT_NAME!r} not found in project {OPIK_PROJECT!r}.")
186226

187-
initial_prompt = ChatPrompt(
188-
messages=[*chat_prompt_obj.template, {"role": "user", "content": "{issue_message}"}]
189-
)
227+
initial_prompt = ChatPrompt(messages=chat_prompt_obj.template)
190228

191229
optimizer = MetaPromptOptimizer(
192230
model=f"anthropic/{MODEL}",
193231
model_parameters={"temperature": 0.0},
194232
prompts_per_round=4,
195233
n_threads=4,
196-
enable_context=True,
234+
enable_context=False,
235+
prompt_overrides=_scout_reasoning_override,
197236
seed=42,
198237
skip_perfect_score=False,
199238
)
@@ -212,18 +251,11 @@ def main() -> None:
212251

213252
result.display()
214253

215-
# Save the best prompt back to Opik — strip the user template, keep only system messages.
216-
best_prompt = result.prompt if isinstance(result.prompt, ChatPrompt) else list(result.prompt.values())[0]
217-
all_messages = best_prompt.messages or (
218-
[{"role": "system", "content": best_prompt.system}] if best_prompt.system else []
219-
)
220-
system_messages = [m for m in all_messages if m.get("role") == "system"]
221-
opik_client.create_chat_prompt(
222-
name=PROMPT_NAME,
223-
messages=system_messages,
224-
project_name=OPIK_PROJECT,
254+
logger.info(
255+
"Review results in the Opik dashboard and manually promote the best prompt "
256+
"to a new version of %r if you decide it is better.",
257+
PROMPT_NAME,
225258
)
226-
logger.info("Best prompt saved to Opik under %r as a new version.", PROMPT_NAME)
227259

228260

229261
if __name__ == "__main__":

0 commit comments

Comments
 (0)