|
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 |
|
9 | | -from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceTree |
| 9 | +from agentlightning.adapter.triplet import RewardMatchPolicy, TracerTraceToTriplet, TraceTree |
10 | 10 | from agentlightning.types import Span |
11 | 11 | from agentlightning.types.tracer import SpanNames |
12 | 12 | from agentlightning.utils.otel import filter_and_unflatten_attributes |
@@ -598,6 +598,70 @@ def test_extract_prompt_image_urls_from_list_payload(): |
598 | 598 | assert image_urls == ["https://example.com/a.png", "https://example.com/b.png"] |
599 | 599 |
|
600 | 600 |
|
| 601 | +def test_tracer_trace_to_triplet_reward_match_first_sibling(): |
| 602 | + root = make_span("root", "session", parent_id=None, start_time=0.0, end_time=10.0) |
| 603 | + agent = make_span( |
| 604 | + "agent", |
| 605 | + "agent.node", |
| 606 | + parent_id="root", |
| 607 | + start_time=1.0, |
| 608 | + end_time=9.0, |
| 609 | + attributes={"agent.name": "sibling-agent"}, |
| 610 | + ) |
| 611 | + other_agent = make_span( |
| 612 | + "agent-2", |
| 613 | + "agent.node", |
| 614 | + parent_id="root", |
| 615 | + start_time=1.0, |
| 616 | + end_time=9.0, |
| 617 | + attributes={"agent.name": "sibling-agent"}, |
| 618 | + ) |
| 619 | + llm_1 = make_llm_span( |
| 620 | + "llm-1", |
| 621 | + parent_id="agent", |
| 622 | + start=2.0, |
| 623 | + end=3.0, |
| 624 | + prompt_ids=[1], |
| 625 | + response_ids=[2], |
| 626 | + response_id="resp-1", |
| 627 | + ) |
| 628 | + reward = make_span( |
| 629 | + "reward", |
| 630 | + "agent.reward", |
| 631 | + parent_id="agent", |
| 632 | + start_time=3.5, |
| 633 | + end_time=3.6, |
| 634 | + attributes=reward_attributes(0.8), |
| 635 | + ) |
| 636 | + llm_2 = make_llm_span( |
| 637 | + "llm-2", |
| 638 | + parent_id="agent-2", |
| 639 | + start=3.1, |
| 640 | + end=3.2, |
| 641 | + prompt_ids=[3], |
| 642 | + response_ids=[4], |
| 643 | + response_id="resp-2", |
| 644 | + ) |
| 645 | + |
| 646 | + spans = [root, agent, other_agent, llm_1, reward, llm_2] |
| 647 | + |
| 648 | + adapter = TracerTraceToTriplet( |
| 649 | + agent_match="sibling-agent", |
| 650 | + reward_match=RewardMatchPolicy.FIRST_SIBLING, |
| 651 | + _skip_empty_token_spans=True, |
| 652 | + ) |
| 653 | + triplets = adapter.adapt(spans) |
| 654 | + |
| 655 | + assert len(triplets) == 2 |
| 656 | + t1, t2 = triplets |
| 657 | + |
| 658 | + assert t1.metadata["response_id"] == "resp-1" |
| 659 | + assert t1.reward == 0.8 |
| 660 | + |
| 661 | + assert t2.metadata["response_id"] == "resp-2" |
| 662 | + assert t2.reward is None |
| 663 | + |
| 664 | + |
601 | 665 | def test_extract_prompt_image_urls_handles_numeric_dict_keys(): |
602 | 666 | tree = make_trace_tree_root() |
603 | 667 | prompt_raw_content = { |
|
0 commit comments