
# Commit b45de09

**Inspect AI integration (#7)**

* feat(inspect): add Inspect AI adapters
  - add `skill_solver` and `step_scorer` in `tk.llmbda.inspect`
  - expose optional `[inspect]` extra; also install in dev group
  - alias stdlib `inspect` in `__init__.py` to avoid submodule shadowing
  - document integration, install, and viewer in README
  - ignore `logs/` and `.pytest_cache/`
* refactor(examples): split triage into skill/main/scoring
  - `examples/triage/skill.py` holds the reusable Skill definition
  - `main.py` runs and prints traces; `scoring.py` runs an Inspect eval
  - `scoring.py` demos `step_scorer`, custom `@scorer`, custom `@metric`
  - target is a `[intent, priority]` list consumed by scorers
  - TODO: add an LLM-graded scorer for a non-deterministic judge path
* fixed crap code
Parent: `0076cc8`
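One item in the commit message above, aliasing the stdlib `inspect` in `__init__.py` so the new `tk.llmbda.inspect` submodule does not shadow it, might look like this minimal sketch (the alias name is an assumption, not the committed code):

```python
# tk/llmbda/__init__.py (sketch only; the committed code may differ)
import inspect as _stdlib_inspect  # bind the standard library module first

# Once this package gains a submodule named `inspect`, the attribute
# `tk.llmbda.inspect` resolves to that submodule; code inside the package
# keeps a working handle on the standard library via `_stdlib_inspect`.
```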

9 files changed · 539 additions & 96 deletions


## .gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,3 +10,5 @@ build/
 _version.py
 uv.lock
 .coverage
+logs/
+.pytest_cache/
```

## README.md

Lines changed: 48 additions & 1 deletion
````diff
@@ -188,12 +188,59 @@ uv run examples/date_extraction.py
 uv run examples/calendar_booking.py
 
 # support triage: extraction, classification, validation loop
-uv run examples/support_triage.py
+uv run examples/triage/main.py
+
+# same skill, scored step-by-step with Inspect AI (see section below)
+uv run examples/triage/scoring.py
 
 # all 20 use cases in one file (no external deps)
 uv run examples/showcase.py
 ```
 
+## Inspect AI integration
+
+Score individual steps of a skill with [Inspect AI](https://inspect.aisi.org.uk/) scorers.
+
+- `skill_solver(skill)` wraps a skill as an Inspect `Solver`. Final `result.value` becomes the completion; full trace lands in `state.metadata["llmbda.trace"]`.
+- `step_scorer(name, inner)` adapts any Inspect scorer to read a named step's value instead of the final completion.
+
+```python
+from inspect_ai import Task
+from inspect_ai.scorer import match, model_graded_qa
+from tk.llmbda.inspect import skill_solver, step_scorer
+
+Task(
+    dataset=tickets,
+    solver=skill_solver(support_triage),
+    scorer=[
+        step_scorer("λ::identifiers", match(location="any")),
+        step_scorer("ψ::draft", model_graded_qa()),
+        match(),  # final completion
+    ],
+)
+```
+
+- `entry=` on `skill_solver` customises how the skill input is extracted from `TaskState` (default: `s.input_text`).
+- `project=` on `step_scorer` stringifies non-str step values before the inner scorer sees them (default: `str`; pass `json.dumps` for dicts).
+- Metrics are inherited from the inner scorer; override with `metrics=[...]`.
+
+### Install and run
+
+- **As a library user:** `pip install tk-llmbda[inspect]` — the `inspect` extra pulls in `inspect-ai`.
+- **In this repo:** `inspect-ai` is already in the dev dependency group, so `uv sync` installs it automatically.
+
+A runnable end-to-end example lives in `examples/triage/scoring.py` (the skill itself is defined in `examples/triage/skill.py` and reused by `main.py`):
+
+```bash
+# run the skill + Inspect eval (scripted model, no API keys needed)
+uv run examples/triage/scoring.py
+
+# browse per-sample traces, scorer breakdowns, and step values in the viewer
+uv run inspect view
+```
+
+Every `inspect_eval(...)` call writes a `.eval` log file under `./logs/` which `inspect view` picks up automatically.
+
 ## Development
 
 ```bash
````
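The `project=` hook from the README bullets above deserves a concrete usage sketch. The step name `"λ::identifiers"` is borrowed from the README example; the dict shape in the comment is an assumption:

```python
import json

from inspect_ai.scorer import match

from tk.llmbda.inspect import step_scorer

# Suppose the identifiers step yields a dict such as
# {"account_ids": ["ACME-42"], "invoice_ids": []} (assumed shape).
# json.dumps projects it to a stable string before match() compares
# it against the sample target.
identifiers_match = step_scorer(
    "λ::identifiers",       # step name from the README example
    match(location="any"),  # inner scorer sees only the projected string
    project=json.dumps,     # stringify the non-str step value
)
```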

## examples/triage/main.py

Lines changed: 31 additions & 0 deletions
```python
# %% [markdown]
# # Support triage: deterministic extraction, scripted LLM steps, repair loop
#
# Runs `support_triage` from `skill.py` on the bundled tickets. See
# `scoring.py` for the Inspect AI evaluation of the same skill.

# %%
import json

from skill import TICKETS, run_skill, support_triage

# %% [markdown]
# ## Run the skill on every ticket

# %%
for ticket in TICKETS:
    result = run_skill(support_triage, ticket)
    print(f"\n{ticket['id']} · {ticket['subject']}")
    print(f"resolved_by: {result.resolved_by}")
    print(json.dumps(result.value, indent=2))
    print(f"validation: {result.metadata}")

# %% [markdown]
# ## Inspect one trace

# %%
result = run_skill(support_triage, TICKETS[1])
for name, step_result in result.trace.items():
    print(f"\n{name}")
    print(f"value: {json.dumps(step_result.value, indent=2)}")
    print(f"metadata: {json.dumps(step_result.metadata, indent=2)}")
```

## examples/triage/scoring.py

Lines changed: 200 additions & 0 deletions
```python
# %%
"""Inspect AI scoring for the support triage skill.

Run: uv run python examples/triage/scoring.py
With LLM grader:
    INSPECT_GRADER=openai/gpt-4o-mini uv run python examples/triage/scoring.py
View logs: uv run inspect view (from examples/triage/)
"""

import os

from inspect_ai import Task
from inspect_ai import eval as inspect_eval
from inspect_ai.dataset import Sample
from inspect_ai.log import EvalLog
from inspect_ai.scorer import (
    Metric,
    Score,
    Target,
    accuracy,
    match,
    mean,
    metric,
    model_graded_qa,
    scorer,
    stderr,
)
from inspect_ai.solver import TaskState
from skill import CLASSIFY, DRAFT, IDENTIFIERS, SUMMARIZE, TICKETS, support_triage

from tk.llmbda.inspect import skill_solver, step_scorer

# %%
EXPECTED = {
    "SUP-1001": ("billing_refund", "P2"),
    "SUP-1002": ("production_incident", "P0"),
    "SUP-1003": ("account_access", "P1"),  # mis-expects P1 so one cell fails
}

EVAL_SAMPLES = [
    Sample(
        id=t["id"],
        input=t["subject"],
        target=list(EXPECTED[t["id"]]),
        metadata={"ticket": t},
    )
    for t in TICKETS
]


# %%
def _trace(state: TaskState) -> dict:
    return (state.metadata or {}).get("llmbda.trace", {})


classify_matches_intent = step_scorer(
    CLASSIFY,
    match(location="exact"),
    project=lambda v: v["intent"],
)


@scorer(metrics=[accuracy(), stderr()])
def draft_priority_scorer():
    async def score(state: TaskState, target: Target) -> Score:
        got = _trace(state)[DRAFT].value["priority"]
        want = target[1]
        return Score(
            value="C" if got == want else "I",
            answer=got,
            explanation=f"expected {want!r}, got {got!r}",
        )

    return score


# %% LLM-graded reply quality (requires INSPECT_GRADER env var)
REPLY_QUALITY_TEMPLATE = """\
You are evaluating a customer support reply for quality.

[BEGIN DATA]
***
[Customer request]: {question}
***
[Support reply]: {answer}
***
[END DATA]

Grade the reply as CORRECT if it:
- Acknowledges the customer's specific issue
- Is professional and actionable
- Requests missing information when identifiers are absent

Grade as INCORRECT if the reply is generic, dismissive, or ignores
key details from the request.

{instructions}
"""

# Default to None so the `is not None` check below cannot raise a
# NameError when no grader model is configured.
draft_reply_quality = None
if grader := os.environ.get(gradevar := "INSPECT_GRADER"):
    g = model_graded_qa(template=REPLY_QUALITY_TEMPLATE, model=grader)
    draft_reply_quality = step_scorer(DRAFT, g, project=lambda v: v["customer_reply"])
else:
    print(f"[W] set {gradevar} to run the model judge")


# %% Heuristic reply scorer — partial credit (0.0 / 0.5 / 1.0)
_ISSUE_KW = ["refund", "charge", "outage", "escalat", "access", "restore"]


@scorer(metrics=[mean(), stderr()])
def draft_reply_heuristic():
    async def score(state: TaskState, target: Target) -> Score:  # noqa: ARG001
        tr = _trace(state)
        reply = tr[DRAFT].value.get("customer_reply", "").lower()
        missing_ids = not tr[IDENTIFIERS].value["account_ids"]
        ack = 0.5 if any(kw in reply for kw in _ISSUE_KW) else 0.0
        info = (0.5 if "account" in reply else 0.0) if missing_ids else 0.5
        pts = ack + info
        reasons = []
        if ack:
            reasons.append("acknowledges issue")
        else:
            reasons.append("generic reply")
        if missing_ids:
            msg = "requests missing id" if info else "missing id not requested"
            reasons.append(msg)
        else:
            reasons.append("no missing ids")
        return Score(value=pts, answer=reply, explanation="; ".join(reasons))

    return score


# %%
@metric
def strict_accuracy() -> Metric:
    def m(scores: list) -> float:
        if not scores:
            return 0.0
        return sum(float(s.score.value) == 1.0 for s in scores) / len(scores)

    return m


@scorer(metrics=[accuracy(), stderr(), strict_accuracy()])
def final_status_scorer():
    async def score(state: TaskState, target: Target) -> Score:  # noqa: ARG001
        status = _trace(state)[SUMMARIZE].value.get("status")
        return Score(
            value="C" if status == "validated" else "I",
            answer=str(status),
            explanation=f"status={status!r}",
        )

    return score


# %%
_scorers = [
    classify_matches_intent,
    draft_priority_scorer(),
    draft_reply_heuristic(),
    final_status_scorer(),
]
if draft_reply_quality is not None:
    _scorers.insert(2, draft_reply_quality)

eval_task = Task(
    name="support_triage_eval",
    dataset=EVAL_SAMPLES,
    solver=skill_solver(support_triage, entry=lambda s: s.metadata["ticket"]),
    scorer=_scorers,
)

eval_logs = inspect_eval(eval_task, model="none/none", display="none")
assert isinstance((log := eval_logs[0]), EvalLog), f"{log=}"  # noqa: RUF018

# %%
print(f"status: {log.status}")
if log.status != "success":
    if log.error:
        print(f"error: {log.error.message}")
        if log.error.traceback:
            print(log.error.traceback)
    raise SystemExit(1)

assert log.results is not None
for sr in log.results.scores:
    print(f"\n{sr.name}")
    for name, mr in sr.metrics.items():
        print(f"  {name:16s} = {mr.value:.3f}")

# %%
assert log.samples is not None
for sample in log.samples:
    print(f"\n{sample.id}")
    assert sample.scores is not None
    for name, sc in sample.scores.items():
        print(f"  {name:28s} {sc.value} ({sc.explanation})")
```
