Skip to content

Commit 8bbe253

Browse files
committed
drop bad training examples
1 parent 897a977 commit 8bbe253

2 files changed

Lines changed: 131 additions & 20 deletions

File tree

surogates/jobs/training_collector.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33+
_SKILL_TRAJECTORY_BOUNDARY: frozenset[str] = frozenset({
34+
EventType.USER_MESSAGE.value,
35+
EventType.SKILL_INVOKED.value,
36+
EventType.SESSION_COMPLETE.value,
37+
EventType.SESSION_FAIL.value,
38+
})
39+
40+
3341
def _strip_skill_prefix(raw_message: str, skill_name: str) -> str:
3442
"""Remove the leading ``/<skill_name>`` from *raw_message*.
3543
@@ -302,7 +310,12 @@ async def collect_for_skill(
302310
exclude_tainted:
303311
When True (default), sessions with ``policy.denied``,
304312
``harness.crash``, ``saga.compensate``, or
305-
``expert.override`` events are skipped entirely.
313+
``expert.override`` events are skipped entirely, and any
314+
individual trajectory whose assistant response received a
315+
``user.feedback`` with ``rating: "down"`` is rejected.
316+
Per-trajectory granularity is intentional: one thumbs-down
317+
on an unrelated response in the same session should not
318+
poison sibling invocations with their own class labels.
306319
"""
307320
invocations = await self._session_store.find_skill_invocations(
308321
org_id, skill_name, since=since,
@@ -322,6 +335,7 @@ async def collect_for_skill(
322335

323336
examples: list[TrainingExample] = []
324337
skipped_tainted = 0
338+
skipped_thumbs_down = 0
325339

326340
for session_id, session_invocations in by_session.items():
327341
if exclude_tainted:
@@ -332,20 +346,37 @@ async def collect_for_skill(
332346
events = await self._session_store.get_events(session_id)
333347
events_by_id = {e.id: i for i, e in enumerate(events)}
334348

349+
# Scanned once per session so every trajectory can share it.
350+
down_rated_response_ids: set[int] = set()
351+
if exclude_tainted:
352+
for e in events:
353+
if e.type != EventType.USER_FEEDBACK.value:
354+
continue
355+
if e.data.get("rating") != "down":
356+
continue
357+
target = e.data.get("target_event_id")
358+
if isinstance(target, int):
359+
down_rated_response_ids.add(target)
360+
335361
for inv in session_invocations:
336362
start_idx = events_by_id.get(inv.id)
337363
if start_idx is None:
338364
continue
339-
example = self._collect_skill_trajectory(
365+
example, rejected_thumbs_down = self._collect_skill_trajectory(
340366
events, start_idx, skill_name, session_id, inv,
367+
down_rated_response_ids=down_rated_response_ids,
341368
)
342-
if example is not None:
369+
if rejected_thumbs_down:
370+
skipped_thumbs_down += 1
371+
elif example is not None:
343372
examples.append(example)
344373

345374
logger.info(
346375
"Collected %d training examples for skill '%s' "
347-
"(%d invocations, %d tainted sessions skipped)",
348-
len(examples), skill_name, len(invocations), skipped_tainted,
376+
"(%d invocations, %d tainted sessions skipped, "
377+
"%d trajectories skipped for thumbs-down)",
378+
len(examples), skill_name, len(invocations),
379+
skipped_tainted, skipped_thumbs_down,
349380
)
350381
return examples
351382

@@ -356,18 +387,28 @@ def _collect_skill_trajectory(
356387
skill_name: str,
357388
session_id: UUID,
358389
skill_event: Any,
359-
) -> TrainingExample | None:
390+
*,
391+
down_rated_response_ids: set[int] | None = None,
392+
) -> tuple[TrainingExample | None, bool]:
360393
"""Walk events from a ``skill.invoked`` through its trajectory.
361394
362-
Returns a :class:`TrainingExample` when a complete trajectory
363-
is found (at least one assistant response with content). The
364-
trajectory ends at the first of: next ``user.message``, next
365-
``skill.invoked``, ``session.complete`` or ``session.fail``.
395+
Returns ``(example, rejected_thumbs_down)``. ``example`` is a
396+
:class:`TrainingExample` when a complete trajectory is found
397+
(at least one assistant response with content), otherwise
398+
``None``. ``rejected_thumbs_down`` is ``True`` iff the
399+
trajectory was dropped because an ``llm.response`` within it
400+
received a judge thumbs-down (id in *down_rated_response_ids*);
401+
a down-rated assistant turn is the wrong class label for
402+
*skill_name* and must not be exported.
403+
404+
The trajectory ends at the first of: next ``user.message``,
405+
next ``skill.invoked``, ``session.complete`` or
406+
``session.fail``.
366407
"""
367408
raw_message = skill_event.data.get("raw_message", "")
368409
user_text = _strip_skill_prefix(raw_message, skill_name)
369410
if not user_text:
370-
return None
411+
return None, False
371412

372413
messages: list[dict[str, Any]] = [
373414
{"role": "user", "content": user_text},
@@ -378,15 +419,15 @@ def _collect_skill_trajectory(
378419
event = events[i]
379420
etype = event.type
380421

381-
if etype in (
382-
EventType.USER_MESSAGE.value,
383-
EventType.SKILL_INVOKED.value,
384-
EventType.SESSION_COMPLETE.value,
385-
EventType.SESSION_FAIL.value,
386-
):
387-
break # trajectory boundary
422+
if etype in _SKILL_TRAJECTORY_BOUNDARY:
423+
break
388424

389425
if etype == EventType.LLM_RESPONSE.value:
426+
if (
427+
down_rated_response_ids
428+
and event.id in down_rated_response_ids
429+
):
430+
return None, True
390431
msg = event.data.get("message")
391432
if not isinstance(msg, dict):
392433
continue
@@ -409,14 +450,15 @@ def _collect_skill_trajectory(
409450
# ``llm.response`` message's ``tool_calls`` field.
410451

411452
if not has_final_assistant_content:
412-
return None
453+
return None, False
413454

414-
return TrainingExample(
455+
example = TrainingExample(
415456
messages=messages,
416457
session_id=session_id,
417458
expert_name=skill_name,
418459
created_at=getattr(skill_event, "created_at", None),
419460
)
461+
return example, False
420462

421463
async def export_jsonl(
422464
self,

tests/integration/test_training_collector.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,75 @@ async def test_collect_for_skill_excludes_tainted_session(
246246
assert {e.session_id for e in all_examples} == {clean.id, tainted.id}
247247

248248

249+
async def test_collect_for_skill_excludes_trajectory_with_thumbs_down(
250+
session_store, session_factory,
251+
):
252+
"""A ``user.feedback`` rating=down on an LLM response in a trajectory
253+
rejects that trajectory only — sibling invocations in the same session
254+
still yield training examples.
255+
256+
Regression: ``session_has_taint`` only checks ``policy.denied`` and
257+
friends, so the judge's thumbs-down verdicts used to pass the filter
258+
and poison the training set with negative class labels.
259+
"""
260+
org_id = await create_org(session_factory)
261+
user_id = await create_user(session_factory, org_id)
262+
session = await session_store.create_session(
263+
user_id=user_id, org_id=org_id, agent_id="test-agent",
264+
)
265+
266+
# First invocation: rated down — should be excluded.
267+
await session_store.emit_event(
268+
session.id, EventType.USER_MESSAGE,
269+
{"content": "/sql_writer bad query"},
270+
)
271+
await session_store.emit_event(
272+
session.id, EventType.SKILL_INVOKED,
273+
{"skill": "sql_writer", "raw_message": "/sql_writer bad query",
274+
"staged_at": None},
275+
)
276+
bad_response_id = await session_store.emit_event(
277+
session.id, EventType.LLM_RESPONSE,
278+
{
279+
"message": {"role": "assistant", "content": "SELECT 1;"},
280+
"model": "gpt-4o",
281+
"input_tokens": 1,
282+
"output_tokens": 1,
283+
},
284+
)
285+
await session_store.emit_event(
286+
session.id, EventType.USER_FEEDBACK,
287+
{
288+
"target_event_id": bad_response_id,
289+
"rating": "down",
290+
"source": "service_account",
291+
"rated_by_service_account_id": "00000000-0000-0000-0000-000000000001",
292+
"reason": "query missed the WHERE clause",
293+
},
294+
)
295+
296+
# Second invocation: untouched — should survive.
297+
await _seed_skill_invocation(
298+
session_store, session.id,
299+
raw_message="/sql_writer good query",
300+
assistant_content="SELECT * FROM users;",
301+
)
302+
303+
collector = TrainingDataCollector(session_store=session_store)
304+
examples = await collector.collect_for_skill("sql_writer", org_id)
305+
306+
assert len(examples) == 1
307+
assert examples[0].messages[0]["content"] == "good query"
308+
assert examples[0].messages[-1]["content"] == "SELECT * FROM users;"
309+
310+
# With exclude_tainted=False the rejected trajectory comes back.
311+
all_examples = await collector.collect_for_skill(
312+
"sql_writer", org_id, exclude_tainted=False,
313+
)
314+
asks = sorted(ex.messages[0]["content"] for ex in all_examples)
315+
assert asks == ["bad query", "good query"]
316+
317+
249318
async def test_collect_for_skill_skips_trajectory_with_no_final_assistant(
250319
session_store, session_factory,
251320
):

0 commit comments

Comments
 (0)