Skip to content

Commit a6078ca

Browse files
fix TraceTree/match_rewards assign_to elements. (#403)
1 parent: f1a8072 · commit: a6078ca

2 files changed

Lines changed: 68 additions & 4 deletions

File tree

agentlightning/adapter/triplet.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -492,12 +492,12 @@ def match_rewards(self, reward_match: str, llm_calls: List["TraceTree"]) -> dict
492492
assign_to: List[Tuple[str, int]] = []
493493
for child in item.children:
494494
if child.id in llm_call_ids:
495-
assign_to.append(child.id) # type: ignore
495+
assign_to.append((child.id, child.end_time)) # type: ignore
496496

497-
agentops_output = item.maybe_reward_dict()
497+
agentops_output = child.maybe_reward_dict()
498498
if agentops_output and agentops_output.get("type") == "reward":
499499
for assign_to_id, assign_to_end_time in reversed(assign_to):
500-
if assign_to_end_time > item.start_time: # type: ignore
500+
if assign_to_end_time > child.start_time: # type: ignore
501501
# This reward happens before the end of the LLM call.
502502
continue
503503
if assign_to_id in rewards:

tests/adapter/test_triplet_trace_tree.py

Lines changed: 65 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceTree
9+
from agentlightning.adapter.triplet import RewardMatchPolicy, TracerTraceToTriplet, TraceTree
1010
from agentlightning.types import Span
1111
from agentlightning.types.tracer import SpanNames
1212
from agentlightning.utils.otel import filter_and_unflatten_attributes
@@ -598,6 +598,70 @@ def test_extract_prompt_image_urls_from_list_payload():
598598
assert image_urls == ["https://example.com/a.png", "https://example.com/b.png"]
599599

600600

601+
def test_tracer_trace_to_triplet_reward_match_first_sibling():
    """Check reward assignment under ``RewardMatchPolicy.FIRST_SIBLING``.

    The trace contains two agent spans with the same ``agent.name`` under a
    single root. The first agent owns an LLM call (2.0-3.0) followed by a
    reward span (3.5-3.6); the second agent owns only an LLM call (3.1-3.2).
    The test expects the reward to be attached to the first agent's LLM call
    and the second agent's LLM call to end up without a reward.
    """
    # Spans are built in the same order in which they are handed to the adapter.
    trace = [
        make_span("root", "session", parent_id=None, start_time=0.0, end_time=10.0),
        make_span(
            "agent",
            "agent.node",
            parent_id="root",
            start_time=1.0,
            end_time=9.0,
            attributes={"agent.name": "sibling-agent"},
        ),
        make_span(
            "agent-2",
            "agent.node",
            parent_id="root",
            start_time=1.0,
            end_time=9.0,
            attributes={"agent.name": "sibling-agent"},
        ),
        # LLM call belonging to the first agent; it finishes before the reward starts.
        make_llm_span(
            "llm-1",
            parent_id="agent",
            start=2.0,
            end=3.0,
            prompt_ids=[1],
            response_ids=[2],
            response_id="resp-1",
        ),
        # Reward span emitted as a child of the first agent.
        make_span(
            "reward",
            "agent.reward",
            parent_id="agent",
            start_time=3.5,
            end_time=3.6,
            attributes=reward_attributes(0.8),
        ),
        # LLM call belonging to the second agent; it has no reward span next to it.
        make_llm_span(
            "llm-2",
            parent_id="agent-2",
            start=3.1,
            end=3.2,
            prompt_ids=[3],
            response_ids=[4],
            response_id="resp-2",
        ),
    ]

    converter = TracerTraceToTriplet(
        agent_match="sibling-agent",
        reward_match=RewardMatchPolicy.FIRST_SIBLING,
        _skip_empty_token_spans=True,
    )
    produced = converter.adapt(trace)

    assert len(produced) == 2
    rewarded, unrewarded = produced

    assert rewarded.metadata["response_id"] == "resp-1"
    assert rewarded.reward == 0.8

    assert unrewarded.metadata["response_id"] == "resp-2"
    assert unrewarded.reward is None
664+
601665
def test_extract_prompt_image_urls_handles_numeric_dict_keys():
602666
tree = make_trace_tree_root()
603667
prompt_raw_content = {

0 commit comments

Comments
 (0)