Skip to content

Commit aa5bfcf

Browse files
authored
Merge pull request #2084 from confident-ai/mayank/openai_agents_types
OpenAI agents types
2 parents f82e3a6 + 0f9d811 commit aa5bfcf

6 files changed

Lines changed: 490 additions & 215 deletions

File tree

deepeval/openai_agents/agent.py

Lines changed: 113 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,40 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field, replace
4-
from typing import Any, Optional, Awaitable, Callable, Generic, TypeVar
4+
from typing import Any, Optional, Awaitable, Callable, Generic, TypeVar, List
55

66
from deepeval.tracing import observe
77
from deepeval.prompt import Prompt
8+
from deepeval.tracing.tracing import Observer
9+
from deepeval.metrics import BaseMetric
10+
from deepeval.tracing.utils import make_json_serializable
11+
from deepeval.tracing.types import LlmSpan
12+
from deepeval.tracing.context import current_span_context
813

914
try:
1015
from agents.agent import Agent as BaseAgent
1116
from agents.models.interface import Model, ModelProvider
17+
from openai.types.responses import ResponseCompletedEvent
1218
except Exception as e:
1319
raise RuntimeError(
1420
"openai-agents is required for this integration. Please install it."
1521
) from e
1622

1723
TContext = TypeVar("TContext")
1824

19-
2025
class _ObservedModel(Model):
2126
def __init__(
    self,
    inner: Model,
    llm_metric_collection: Optional[str] = None,
    llm_metrics: Optional[List[BaseMetric]] = None,
    confident_prompt: Optional[Prompt] = None,
) -> None:
    """Wrap ``inner`` so DeepEval evaluation metadata rides along with it.

    Args:
        inner: The real ``Model`` implementation every call is delegated to.
        llm_metric_collection: Name of a metric collection to evaluate LLM
            spans against, or ``None`` to skip collection-based evaluation.
        llm_metrics: Metrics run against each LLM span, or ``None`` for none.
        confident_prompt: ``Prompt`` recorded on each LLM span, or ``None``.
    """
    self._inner = inner
    self._llm_metric_collection = llm_metric_collection
    self._llm_metrics = llm_metrics
    self._confident_prompt = confident_prompt
3337

34-
# Delegate attributes not overridden
3538
def __getattr__(self, name: str) -> Any:
3639
return getattr(self._inner, name)
3740

@@ -59,29 +62,48 @@ async def get_response(
5962
previous_response_id,
6063
conversation_id,
6164
prompt,
65+
**kwargs,
6266
):
6367
model_name = self._get_model_name()
64-
65-
wrapped = observe(
66-
metrics=self._metrics,
67-
metric_collection=self._metric_collection,
68-
type="llm",
69-
model=model_name,
70-
prompt=self._deepeval_prompt,
71-
)(self._inner.get_response)
72-
73-
return await wrapped(
74-
system_instructions,
75-
input,
76-
model_settings,
77-
tools,
78-
output_schema,
79-
handoffs,
80-
tracing,
81-
previous_response_id=previous_response_id,
82-
conversation_id=conversation_id,
83-
prompt=prompt,
84-
)
68+
with Observer(
69+
span_type="llm",
70+
func_name="LLM",
71+
function_kwargs={
72+
"system_instructions": system_instructions,
73+
"input": input,
74+
"model_settings": model_settings,
75+
"tools": tools,
76+
"output_schema": output_schema,
77+
"handoffs": handoffs,
78+
# "tracing": tracing, # not important for llm spans
79+
# "previous_response_id": previous_response_id, # not important for llm spans
80+
# "conversation_id": conversation_id, # not important for llm spans
81+
"prompt": prompt,
82+
**kwargs,
83+
},
84+
observe_kwargs={"model": model_name},
85+
metrics=self._llm_metrics,
86+
metric_collection=self._llm_metric_collection,
87+
) as observer:
88+
result = await self._inner.get_response(
89+
system_instructions,
90+
input,
91+
model_settings,
92+
tools,
93+
output_schema,
94+
handoffs,
95+
tracing,
96+
previous_response_id=previous_response_id,
97+
conversation_id=conversation_id,
98+
prompt=prompt,
99+
**kwargs,
100+
)
101+
llm_span: LlmSpan = current_span_context.get()
102+
llm_span.prompt = self._confident_prompt
103+
104+
observer.result = make_json_serializable(result.output)
105+
106+
return result
85107

86108
def stream_response(
87109
self,
@@ -96,91 +118,74 @@ def stream_response(
96118
previous_response_id,
97119
conversation_id,
98120
prompt,
121+
**kwargs,
99122
):
100-
# Optional: if you also want to observe streaming, uncomment and wrap similarly.
101-
# wrapped = observe(
102-
# metrics=self._metrics,
103-
# metric_collection=self._metric_collection,
104-
# type="llm",
105-
# model=model_name,
106-
# )(self._inner.stream_response)
107-
# return wrapped(
108-
# system_instructions,
109-
# input,
110-
# model_settings,
111-
# tools,
112-
# output_schema,
113-
# handoffs,
114-
# tracing,
115-
# previous_response_id=previous_response_id,
116-
# conversation_id=conversation_id,
117-
# prompt=prompt,
118-
# )
119-
return self._inner.stream_response(
120-
system_instructions,
121-
input,
122-
model_settings,
123-
tools,
124-
output_schema,
125-
handoffs,
126-
tracing,
127-
previous_response_id=previous_response_id,
128-
conversation_id=conversation_id,
129-
prompt=prompt,
130-
)
131-
132-
133-
class _ObservedProvider(ModelProvider):
134-
def __init__(
135-
self,
136-
base: ModelProvider,
137-
*,
138-
metrics: Optional[list[Any]] = None,
139-
metric_collection: Optional[str] = None,
140-
deepeval_prompt: Optional[Any] = None,
141-
) -> None:
142-
self._base = base
143-
self._metrics = metrics
144-
self._metric_collection = metric_collection
145-
self._deepeval_prompt = deepeval_prompt
123+
model_name = self._get_model_name()
146124

147-
def get_model(self, model_name: str | None) -> Model:
148-
model = self._base.get_model(model_name)
149-
return _ObservedModel(
150-
model,
151-
metrics=self._metrics,
152-
metric_collection=self._metric_collection,
153-
deepeval_prompt=self._deepeval_prompt,
154-
)
125+
async def _gen():
126+
observer = Observer(
127+
span_type="llm",
128+
func_name="LLM",
129+
function_kwargs={
130+
"system_instructions": system_instructions,
131+
"input": input,
132+
"model_settings": model_settings,
133+
"tools": tools,
134+
"output_schema": output_schema,
135+
"handoffs": handoffs,
136+
# "tracing": tracing,
137+
# "previous_response_id": previous_response_id,
138+
# "conversation_id": conversation_id,
139+
"prompt": prompt,
140+
**kwargs,
141+
},
142+
observe_kwargs={"model": model_name},
143+
metrics=self._llm_metrics,
144+
metric_collection=self._llm_metric_collection,
145+
)
146+
observer.__enter__()
147+
148+
llm_span: LlmSpan = current_span_context.get()
149+
llm_span.prompt = self._confident_prompt
155150

151+
try:
152+
async for event in self._inner.stream_response(
153+
system_instructions,
154+
input,
155+
model_settings,
156+
tools,
157+
output_schema,
158+
handoffs,
159+
tracing,
160+
previous_response_id=previous_response_id,
161+
conversation_id=conversation_id,
162+
prompt=prompt,
163+
):
164+
165+
if isinstance(event, ResponseCompletedEvent):
166+
observer.result = event.response.output_text #TODO: support other response types
167+
168+
yield event
169+
170+
observer.__exit__(None, None, None)
171+
except Exception as e:
172+
observer.__exit__(type(e), e, e.__traceback__)
173+
raise
174+
finally:
175+
176+
observer.__exit__(None, None, None)
177+
178+
return _gen()
156179

157180
@dataclass
class DeepEvalAgent(BaseAgent[TContext], Generic[TContext]):
    """A drop-in subclass of ``agents.Agent`` carrying DeepEval configuration.

    The extra fields hold metric and prompt settings that the integration
    attaches to LLM spans when this agent's model is invoked; the agent's
    own behavior is otherwise unchanged.
    """

    # Name of a metric collection to evaluate LLM spans against, if any.
    llm_metric_collection: Optional[str] = None
    # Metrics run against each LLM span, if any.
    llm_metrics: Optional[List[BaseMetric]] = None
    # Prompt recorded on each LLM span, if any.
    confident_prompt: Optional[Prompt] = None

    def __post_init__(self):
        # Defer entirely to the base agent's initialization hook.
        super().__post_init__()

deepeval/openai_agents/callback_handler.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,7 @@ def on_span_start(self, span: "Span") -> None:
4646
if not span.started_at:
4747
return
4848
span_type = self.get_span_kind(span.span_data)
49-
if span_type == "agent":
50-
if isinstance(span.span_data, AgentSpanData):
51-
current_trace = current_trace_context.get()
52-
if current_trace:
53-
current_trace.name = span.span_data.name
54-
55-
if span_type == "tool":
56-
return
57-
elif span_type == "llm":
58-
return
59-
else:
49+
if span_type and span_type == "agent":
6050
observer = Observer(span_type=span_type, func_name="NA")
6151
observer.update_span_properties = (
6252
lambda base_span: update_span_properties(
@@ -68,13 +58,13 @@ def on_span_start(self, span: "Span") -> None:
6858

6959
def on_span_end(self, span: "Span") -> None:
    """Finish the DeepEval observer opened for *span*.

    Only agent spans are handled: ``get_span_kind`` maps every other span
    kind to ``None``, and ``on_span_start`` never registers an observer for
    them, so non-agent spans are ignored here with an early return.
    """
    span_type = self.get_span_kind(span.span_data)
    # Guard clause; the original `span_type and span_type == "agent"` was
    # redundant (None == "agent" is already False).
    if span_type != "agent":
        return
    current_span = current_span_context.get()
    if current_span:
        update_span_properties(current_span, span.span_data)
    # NOTE(review): assumes span_observers only ever holds observers for
    # agent spans (registered by on_span_start) — confirm against that hook.
    observer = self.span_observers.pop(span.span_id, None)
    if observer:
        observer.__exit__(None, None, None)
7868

7969
def force_flush(self) -> None:
8070
pass
@@ -85,18 +75,19 @@ def shutdown(self) -> None:
8575
def get_span_kind(self, span_data: "SpanData") -> "Optional[str]":
    """Map an OpenAI Agents span payload to a DeepEval span type.

    Only agent spans are currently traced; every other span kind yields
    ``None`` so the handler skips it. The commented mappings below record
    the kinds supported before this commit narrowed the integration.
    """
    if isinstance(span_data, AgentSpanData):
        return "agent"
    # if isinstance(span_data, FunctionSpanData):
    #     return "tool"
    # if isinstance(span_data, MCPListToolsSpanData):
    #     return "tool"
    # if isinstance(span_data, GenerationSpanData):
    #     return "llm"
    # if isinstance(span_data, ResponseSpanData):
    #     return "llm"
    # if isinstance(span_data, HandoffSpanData):
    #     return "custom"
    # if isinstance(span_data, CustomSpanData):
    #     return "base"
    # if isinstance(span_data, GuardrailSpanData):
    #     return "base"
    # return "base"
    return None

0 commit comments

Comments
 (0)