Skip to content

Commit 57b2135

Browse files
committed
Reformat
1 parent aa5bfcf commit 57b2135

6 files changed

Lines changed: 150 additions & 126 deletions

File tree

deepeval/openai_agents/agent.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
TContext = TypeVar("TContext")
2424

25+
2526
class _ObservedModel(Model):
2627
def __init__(
2728
self,
@@ -69,12 +70,12 @@ async def get_response(
6970
span_type="llm",
7071
func_name="LLM",
7172
function_kwargs={
72-
"system_instructions": system_instructions,
73-
"input": input,
74-
"model_settings": model_settings,
75-
"tools": tools,
76-
"output_schema": output_schema,
77-
"handoffs": handoffs,
73+
"system_instructions": system_instructions,
74+
"input": input,
75+
"model_settings": model_settings,
76+
"tools": tools,
77+
"output_schema": output_schema,
78+
"handoffs": handoffs,
7879
# "tracing": tracing, # not important for llm spans
7980
# "previous_response_id": previous_response_id, # not important for llm spans
8081
# "conversation_id": conversation_id, # not important for llm spans
@@ -102,7 +103,7 @@ async def get_response(
102103
llm_span.prompt = self._confident_prompt
103104

104105
observer.result = make_json_serializable(result.output)
105-
106+
106107
return result
107108

108109
def stream_response(
@@ -163,7 +164,9 @@ async def _gen():
163164
):
164165

165166
if isinstance(event, ResponseCompletedEvent):
166-
observer.result = event.response.output_text #TODO: support other response types
167+
observer.result = (
168+
event.response.output_text
169+
) # TODO: support other response types
167170

168171
yield event
169172

@@ -177,15 +180,16 @@ async def _gen():
177180

178181
return _gen()
179182

183+
180184
@dataclass
181185
class DeepEvalAgent(BaseAgent[TContext], Generic[TContext]):
182186
"""
183187
A subclass of agents.Agent.
184188
"""
189+
185190
llm_metric_collection: str = None
186191
llm_metrics: List[BaseMetric] = None
187192
confident_prompt: Prompt = None
188193

189194
def __post_init__(self):
190195
super().__post_init__()
191-

deepeval/openai_agents/runner.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from agents.run import AgentRunner
2020
from agents.run_context import TContext
2121
from agents.models.interface import Model
22+
2223
agents_available = True
2324
except:
2425
agents_available = False
2526

27+
2628
def is_agents_available():
2729
if not agents_available:
2830
raise ImportError(
@@ -44,33 +46,37 @@ def _patch_default_agent_runner_get_model():
4446
global _PATCHED_DEFAULT_GET_MODEL
4547
if _PATCHED_DEFAULT_GET_MODEL:
4648
return
47-
49+
4850
original_get_model = AgentRunner._get_model
49-
51+
5052
@classmethod
51-
def patched_get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
53+
def patched_get_model(
54+
cls, agent: Agent[Any], run_config: RunConfig
55+
) -> Model:
5256
model = original_get_model(agent, run_config)
53-
57+
5458
# Extract attributes from agent if it's a DeepEvalAgent
55-
llm_metrics = getattr(agent, 'llm_metrics', None)
56-
llm_metric_collection = getattr(agent, 'llm_metric_collection', None)
57-
confident_prompt = getattr(agent, 'confident_prompt', None)
59+
llm_metrics = getattr(agent, "llm_metrics", None)
60+
llm_metric_collection = getattr(agent, "llm_metric_collection", None)
61+
confident_prompt = getattr(agent, "confident_prompt", None)
5862
model = _ObservedModel(
5963
inner=model,
6064
llm_metric_collection=llm_metric_collection,
6165
llm_metrics=llm_metrics,
6266
confident_prompt=confident_prompt,
6367
)
64-
68+
6569
return model
66-
70+
6771
# Replace the method
6872
AgentRunner._get_model = patched_get_model
6973
_PATCHED_DEFAULT_GET_MODEL = True
7074

75+
7176
if agents_available:
7277
_patch_default_agent_runner_get_model()
7378

79+
7480
class Runner(AgentsRunner):
7581

7682
@classmethod
@@ -86,15 +92,14 @@ async def run(
8692
previous_response_id: Optional[str] = None,
8793
conversation_id: Optional[str] = None,
8894
session: Optional[Session] = None,
89-
9095
metrics: Optional[List[BaseMetric]] = None,
9196
metric_collection: Optional[str] = None,
9297
name: Optional[str] = None,
9398
tags: Optional[List[str]] = None,
9499
metadata: Optional[dict] = None,
95100
thread_id: Optional[str] = None,
96101
user_id: Optional[str] = None,
97-
**kwargs, # backwards compatibility
102+
**kwargs, # backwards compatibility
98103
) -> RunResult:
99104
is_agents_available()
100105
# _patch_default_agent_runner_get_model()
@@ -131,7 +136,7 @@ async def run(
131136
previous_response_id=previous_response_id,
132137
conversation_id=conversation_id,
133138
session=session,
134-
**kwargs, # backwards compatibility
139+
**kwargs, # backwards compatibility
135140
)
136141
_output = None
137142
if thread_id:
@@ -155,7 +160,6 @@ def run_sync(
155160
previous_response_id: Optional[str] = None,
156161
conversation_id: Optional[str] = None,
157162
session: Optional[Session] = None,
158-
159163
metrics: Optional[List[BaseMetric]] = None,
160164
metric_collection: Optional[str] = None,
161165
name: Optional[str] = None,
@@ -200,7 +204,7 @@ def run_sync(
200204
previous_response_id=previous_response_id,
201205
conversation_id=conversation_id,
202206
session=session,
203-
**kwargs, # backwards compatibility
207+
**kwargs, # backwards compatibility
204208
)
205209
_output = None
206210
if thread_id:
@@ -211,7 +215,7 @@ def run_sync(
211215
observer.result = _output
212216

213217
return res
214-
218+
215219
@classmethod
216220
def run_streamed(
217221
cls,
@@ -225,15 +229,14 @@ def run_streamed(
225229
previous_response_id: Optional[str] = None,
226230
conversation_id: Optional[str] = None,
227231
session: Optional[Session] = None,
228-
229232
metrics: Optional[List[BaseMetric]] = None,
230233
metric_collection: Optional[str] = None,
231234
name: Optional[str] = None,
232235
tags: Optional[List[str]] = None,
233236
metadata: Optional[dict] = None,
234237
thread_id: Optional[str] = None,
235238
user_id: Optional[str] = None,
236-
**kwargs, # backwards compatibility
239+
**kwargs, # backwards compatibility
237240
) -> RunResultStreaming:
238241
is_agents_available()
239242
# Manually enter observer; we'll exit when streaming finishes
@@ -271,7 +274,7 @@ def run_streamed(
271274
previous_response_id=previous_response_id,
272275
conversation_id=conversation_id,
273276
session=session,
274-
**kwargs, # backwards compatibility
277+
**kwargs, # backwards compatibility
275278
)
276279

277280
# Runtime-patch stream_events so the observer closes only after streaming completes
@@ -290,10 +293,11 @@ async def _patched_stream_events(self: RunResultStreaming):
290293
observer.__exit__(None, None, None)
291294

292295
from types import MethodType as _MethodType
296+
293297
res.stream_events = _MethodType(_patched_stream_events, res)
294298

295299
return res
296-
300+
297301

298302
def update_trace_attributes(
299303
input: Any = None,
@@ -324,4 +328,4 @@ def update_trace_attributes(
324328
if metric_collection:
325329
current_trace.metric_collection = metric_collection
326330
if metrics:
327-
current_trace.metrics = metrics
331+
current_trace.metrics = metrics

0 commit comments

Comments
 (0)