Skip to content

Commit 51c7531

Browse files
authored
Merge pull request #2087 from confident-ai/mayank/agents_trace_support_1
open ai agents trace support
2 parents 12353b1 + cd65595 commit 51c7531

5 files changed

Lines changed: 113 additions & 38 deletions

File tree

deepeval/openai_agents/callback_handler.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from deepeval.tracing.tracing import (
22
Observer,
33
current_span_context,
4+
trace_manager,
45
)
56
from deepeval.openai_agents.extractors import *
67
from deepeval.tracing.context import current_trace_context
8+
from deepeval.tracing.utils import make_json_serializable
9+
from time import perf_counter
10+
from deepeval.tracing.types import TraceSpanStatus
711

812
try:
913
from agents.tracing import Span, Trace, TracingProcessor
@@ -33,14 +37,49 @@ def _check_openai_agents_available():
3337
class DeepEvalTracingProcessor(TracingProcessor):
3438
def __init__(self) -> None:
3539
_check_openai_agents_available()
36-
self.root_span_observers: dict[str, Observer] = {}
3740
self.span_observers: dict[str, Observer] = {}
3841

3942
def on_trace_start(self, trace: "Trace") -> None:
    """Mirror the start of an agents-SDK trace into DeepEval's tracing state.

    If the exported trace carries a ``group_id`` or ``metadata``, a new
    DeepEval trace is started under the same id, made the current trace,
    and given a dummy root span. Otherwise only the name of the trace that
    is already current (if any) is updated.

    Args:
        trace: Trace object handed over by the agents SDK. Only the dict
            returned by its ``export()`` method is read.
    """
    # NOTE(review): export() is treated as possibly returning None so that a
    # missing payload falls through to the "no group/metadata" path instead
    # of raising AttributeError -- confirm against the SDK's Trace contract.
    trace_dict = trace.export() or {}
    _trace_uuid = trace_dict.get("id")
    _thread_id = trace_dict.get("group_id")
    _trace_name = trace_dict.get("workflow_name")
    _trace_metadata = trace_dict.get("metadata")

    if not (_thread_id or _trace_metadata):
        # No grouping information: just rename whatever trace is current.
        current_trace = current_trace_context.get()
        if current_trace and _trace_name is not None:
            current_trace.name = str(_trace_name)
        return

    _trace = trace_manager.start_new_trace(trace_uuid=str(_trace_uuid))
    # Only copy values that are actually present; str(None) would otherwise
    # store the literal text "None" as the thread id / name.
    if _thread_id is not None:
        _trace.thread_id = str(_thread_id)
    if _trace_name is not None:
        _trace.name = str(_trace_name)
    _trace.metadata = make_json_serializable(_trace_metadata)
    current_trace_context.set(_trace)

    trace_manager.add_span(  # adds a dummy root span
        BaseSpan(
            uuid=_trace_uuid,
            trace_uuid=_trace_uuid,
            parent_uuid=None,
            start_time=perf_counter(),
            name=_trace_name,
            status=TraceSpanStatus.IN_PROGRESS,
            children=[],
        )
    )
4171

4272
def on_trace_end(self, trace: "Trace") -> None:
    """Close the DeepEval trace that was opened for this agents-SDK trace.

    Acts only when the exported trace carries a ``group_id`` or
    ``metadata`` -- the same condition under which ``on_trace_start``
    starts a trace. In that case the dummy root span is removed, the trace
    is ended, and the current-trace context is reset to ``None``.

    Args:
        trace: Trace object handed over by the agents SDK. Only the dict
            returned by its ``export()`` method is read.
    """
    # NOTE(review): guards against export() returning None -- confirm
    # against the SDK's Trace contract.
    trace_dict = trace.export() or {}
    if not (trace_dict.get("group_id") or trace_dict.get("metadata")):
        return

    _trace_uuid = trace_dict.get("id")
    trace_manager.remove_span(_trace_uuid)  # removing the dummy root span
    trace_manager.end_trace(_trace_uuid)
    current_trace_context.set(None)
4483

4584
def on_span_start(self, span: "Span") -> None:
4685
if not span.started_at:

deepeval/openai_agents/runner.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,9 @@ async def run(
109109
metric_collection=metric_collection,
110110
metrics=metrics,
111111
func_name="run",
112-
function_kwargs={"input": input},
112+
function_kwargs={"input": input}, # also set below
113113
) as observer:
114114
update_trace_attributes(
115-
input=input,
116115
name=name,
117116
tags=tags,
118117
metadata=metadata,
@@ -123,7 +122,8 @@ async def run(
123122
)
124123
current_span = current_span_context.get()
125124
current_trace = current_trace_context.get()
126-
current_trace.input = input
125+
if not current_trace.input:
126+
current_trace.input = input
127127
if current_span:
128128
current_span.input = input
129129
res = await super().run(
@@ -138,8 +138,9 @@ async def run(
138138
session=session,
139139
**kwargs, # backwards compatibility
140140
)
141+
current_trace_thread_id = current_trace_context.get().thread_id
141142
_output = None
142-
if thread_id:
143+
if current_trace_thread_id:
143144
_output = res.final_output
144145
else:
145146
_output = str(res)
@@ -170,30 +171,30 @@ def run_sync(
170171
**kwargs,
171172
) -> RunResult:
172173
is_agents_available()
173-
input_val = input
174-
175-
update_trace_attributes(
176-
input=input_val,
177-
name=name,
178-
tags=tags,
179-
metadata=metadata,
180-
thread_id=thread_id,
181-
user_id=user_id,
182-
metric_collection=metric_collection,
183-
metrics=metrics,
184-
)
185174

186175
with Observer(
187176
span_type="custom",
188177
metric_collection=metric_collection,
189178
metrics=metrics,
190179
func_name="run_sync",
191-
function_kwargs={"input": input_val},
180+
function_kwargs={"input": input}, # also set below
192181
) as observer:
182+
update_trace_attributes(
183+
name=name,
184+
tags=tags,
185+
metadata=metadata,
186+
thread_id=thread_id,
187+
user_id=user_id,
188+
metric_collection=metric_collection,
189+
metrics=metrics,
190+
)
191+
193192
current_span = current_span_context.get()
194193
current_trace = current_trace_context.get()
194+
if not current_trace.input:
195+
current_trace.input = input
195196
if current_span:
196-
current_span.input = input_val
197+
current_span.input = input
197198
res = super().run_sync(
198199
starting_agent,
199200
input,
@@ -206,8 +207,9 @@ def run_sync(
206207
session=session,
207208
**kwargs, # backwards compatibility
208209
)
210+
current_trace_thread_id = current_trace_context.get().thread_id
209211
_output = None
210-
if thread_id:
212+
if current_trace_thread_id:
211213
_output = res.final_output
212214
else:
213215
_output = str(res)
@@ -250,7 +252,6 @@ def run_streamed(
250252
observer.__enter__()
251253

252254
update_trace_attributes(
253-
input=input,
254255
name=name,
255256
tags=tags,
256257
metadata=metadata,
@@ -259,7 +260,10 @@ def run_streamed(
259260
metric_collection=metric_collection,
260261
metrics=metrics,
261262
)
262-
263+
current_trace = current_trace_context.get()
264+
if not current_trace.input:
265+
current_trace.input = input
266+
263267
current_span = current_span_context.get()
264268
if current_span:
265269
current_span.input = input

tests/test_integrations/test_openai_agents/agents_app.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DeepEvalTracingProcessor,
88
)
99

10+
from deepeval.tracing.context import update_current_trace
1011
from deepeval.prompt import Prompt
1112

1213
add_trace_processor(DeepEvalTracingProcessor())
@@ -98,24 +99,55 @@ async def run_weather_agent(user_input: str):
9899
weather_agent,
99100
user_input,
100101
metric_collection="test_collection_1",
101-
name="test_name_1",
102-
user_id="test_user_id_1",
103-
thread_id="test_thread_id_1",
104-
tags=["test_tag_1"],
105-
metadata={"test_metadata_1": "test_metadata_1"},
102+
# name="test_name_1",
103+
# user_id="test_user_id_1",
104+
# thread_id="test_thread_id_1",
105+
# tags=["test_tag_1"],
106+
# metadata={"test_metadata_1": "test_metadata_1"},
106107
)
107108
return result.final_output
108109

110+
from agents import trace
111+
from multi_agents import triage_agent
112+
# with trace (group_id and metadata)
113+
async def main1():
    """Run the triage and weather agents inside one explicit trace.

    The trace is opened with a workflow name, a group id and metadata;
    after both runs the trace-level input and output are overridden.
    """
    trace_settings = {
        "workflow_name": "test_workflow_1",
        "group_id": "test_group_id_1",
        "metadata": {"test_metadata_1": "test_metadata_1"},
    }
    with trace(**trace_settings):
        weather_question = "What's the weather like in London today?"
        triage_result = await Runner.run(
            triage_agent,
            "Hola, ¿cómo estás?",
            metric_collection="test_collection_1",
            thread_id="test",
        )
        weather_result = await Runner.run(
            weather_agent,
            weather_question,
            metric_collection="test_collection_1",
        )
        update_current_trace(input="initial input", output="final output")
119+
120+
# without trace (group_id and metadata not present)
121+
async def main2():
    """Run the triage and weather agents without an enclosing ``trace``."""
    weather_question = "What's the weather like in London today?"
    triage_result = await Runner.run(
        triage_agent,
        "Hola, ¿cómo estás?",
        metric_collection="test_collection_1",
        thread_id="test",
    )
    weather_result = await Runner.run(
        weather_agent,
        weather_question,
        metric_collection="test_collection_1",
    )
125+
109126

110-
# Usage example
111-
async def main():
127+
async def main3():
    """Run each agent under its own explicit trace, one after the other."""
    # NOTE(review): the two traces are taken to be sequential, not nested --
    # the flattened diff does not show indentation; confirm.
    weather_question = "What's the weather like in London today?"

    first_trace = trace(
        workflow_name="test_workflow_1",
        group_id="test_group_id_1",
        metadata={"test_metadata_1": "test_metadata_1"},
    )
    with first_trace:
        weather_result = await Runner.run(
            weather_agent,
            weather_question,
            metric_collection="test_collection_1",
        )

    second_trace = trace(
        workflow_name="test_workflow_2",
        group_id="test_group_id_2",
        metadata={"test_metadata_2": "test_metadata_2"},
    )
    with second_trace:
        triage_result = await Runner.run(
            triage_agent,
            "Hola, ¿cómo estás?",
            metric_collection="test_collection_1",
            thread_id="test",
        )
115133

134+
async def main4():
    """Stream both agents inside one explicit trace, printing every event."""
    # NOTE(review): both streamed runs are taken to sit inside the single
    # trace -- the flattened diff does not show indentation; confirm.
    weather_question = "What's the weather like in London today?"
    divider = "=" * 50
    with trace(
        workflow_name="test_workflow_1",
        group_id="test_group_id_1",
        metadata={"test_metadata_1": "test_metadata_1"},
    ):
        weather_stream = Runner.run_streamed(
            weather_agent,
            weather_question,
            metric_collection="test_collection_1",
        )
        async for event in weather_stream.stream_events():
            print(event, end="", flush=True)
        print(divider)

        triage_stream = Runner.run_streamed(
            triage_agent,
            "Hola, ¿cómo estás?",
            metric_collection="test_collection_1",
            thread_id="test",
        )
        async for event in triage_stream.stream_events():
            print(event, end="", flush=True)
        print(divider)
116145

117146
def execute_agent():
    """Drive the ``main1`` scenario to completion with ``asyncio.run``."""
    # Alternative scenarios, switched on by hand when needed:
    #   asyncio.run(main2())
    #   asyncio.run(main3())
    #   asyncio.run(main4())
    scenario = main1()
    asyncio.run(scenario)
119151

120152

121-
execute_agent()
153+
# execute_agent()

tests/test_integrations/test_openai_agents/multi_agents.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
from deepeval.openai_agents import Agent, Runner
33
from deepeval.prompt import Prompt
4-
from deepeval.openai_agents import DeepEvalTracingProcessor
4+
# from deepeval.openai_agents import DeepEvalTracingProcessor
55

6-
from agents import add_trace_processor
6+
# from agents import add_trace_processor
77

8-
add_trace_processor(DeepEvalTracingProcessor())
8+
# add_trace_processor(DeepEvalTracingProcessor())
99

1010
prompt = Prompt(alias="asd")
1111
prompt.pull(version="00.00.01")

tests/test_integrations/test_pydanticai/ff.py

Whitespace-only changes.

0 commit comments

Comments
 (0)