Skip to content

Commit 31ea08e

Browse files
committed
.
1 parent 8de1117 commit 31ea08e

4 files changed

Lines changed: 99 additions & 8 deletions

File tree

deepeval/openai_agents/extractors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
SpanData,
3636
HandoffSpanData,
3737
CustomSpanData,
38+
GuardrailSpanData,
3839
)
3940
openai_agents_available = True
4041
except ImportError:
@@ -66,6 +67,8 @@ def custom_update_span_attributes(span: BaseSpan, span_data: SpanData):
6667
update_attributes_from_handoff_span_data(span, span_data)
6768
elif isinstance(span_data, CustomSpanData):
6869
update_attributes_from_custom_span_data(span, span_data)
70+
elif isinstance(span_data, GuardrailSpanData):
71+
update_attributes_from_guardrail_span_data(span, span_data)
6972

7073
########################################################
7174
### LLM Span ###########################################
@@ -213,6 +216,14 @@ def update_attributes_from_custom_span_data(
213216
span.name = custom_span_data.name
214217
span.metadata = {"data": custom_span_data.data}
215218

219+
def update_attributes_from_guardrail_span_data(
220+
span: BaseSpan,
221+
guardrail_span_data: GuardrailSpanData
222+
):
223+
# Update Span
224+
span.name = "Guardrail: " + guardrail_span_data.name
225+
span.metadata = {"data": guardrail_span_data.triggered, "type": guardrail_span_data.type}
226+
216227
########################################################
217228
### Parse Input Utils ##################################
218229
########################################################
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import json
5+
6+
from pydantic import BaseModel, Field
7+
8+
from agents import (
9+
Agent,
10+
GuardrailFunctionOutput,
11+
OutputGuardrailTripwireTriggered,
12+
RunContextWrapper,
13+
Runner,
14+
output_guardrail,
15+
)
16+
17+
"""
18+
This example shows how to use output guardrails.
19+
20+
Output guardrails are checks that run on the final output of an agent.
21+
They can be used to do things like:
22+
- Check if the output contains sensitive data
23+
- Check if the output is a valid response to the user's message
24+
25+
In this example, we'll use a (contrived) example where we check if the agent's response contains
26+
a phone number.
27+
"""
28+
29+
30+
# The agent's output type
31+
class MessageOutput(BaseModel):
32+
reasoning: str = Field(description="Thoughts on how to respond to the user's message")
33+
response: str = Field(description="The response to the user's message")
34+
user_name: str | None = Field(description="The name of the user who sent the message, if known")
35+
36+
37+
@output_guardrail
38+
async def sensitive_data_check(
39+
context: RunContextWrapper, agent: Agent, output: MessageOutput
40+
) -> GuardrailFunctionOutput:
41+
phone_number_in_response = "650" in output.response
42+
phone_number_in_reasoning = "650" in output.reasoning
43+
44+
return GuardrailFunctionOutput(
45+
output_info={
46+
"phone_number_in_response": phone_number_in_response,
47+
"phone_number_in_reasoning": phone_number_in_reasoning,
48+
},
49+
tripwire_triggered=phone_number_in_response or phone_number_in_reasoning,
50+
)
51+
52+
53+
agent = Agent(
54+
name="Assistant",
55+
instructions="You are a helpful assistant.",
56+
output_type=MessageOutput,
57+
output_guardrails=[sensitive_data_check],
58+
)
59+
60+
61+
async def output_guardrails_agent():
62+
# This should be ok
63+
await Runner.run(agent, "What's the capital of California?")
64+
print("First message passed")
65+
66+
# This should trip the guardrail
67+
try:
68+
result = await Runner.run(
69+
agent, "My phone number is 650-123-4567. Where do you think I live?"
70+
)
71+
print(
72+
f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}"
73+
)
74+
75+
except OutputGuardrailTripwireTriggered as e:
76+
print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}")
77+

tests/integrations/openai_agents/streaming_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ async def streaming_agent():
77
name="Joker",
88
instructions="You are a helpful assistant.",
99
)
10-
1110
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
1211
async for event in result.stream_events():
1312
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):

tests/integrations/openai_agents/test_openai_agent.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
# Import agents
1212
from tests.integrations.openai_agents.streaming_guardrails_agent import streaming_guardrails_agent
13+
from tests.integrations.openai_agents.output_guardrails_agent import output_guardrails_agent
1314
from tests.integrations.openai_agents.code_interpreter_agent import code_interpreter_agent
1415
from tests.integrations.openai_agents.customer_service_agent import customer_service_agent
1516
from tests.integrations.openai_agents.streaming_agent import streaming_agent
@@ -21,17 +22,20 @@
2122

2223
# Run agents
2324
if __name__ == "__main__":
24-
# asyncio.run(customer_service_agent())
25-
# asyncio.run(research_agent())
2625
# if not shutil.which("uvx"):
2726
# raise RuntimeError("uvx is not installed. Please install it with `pip install uvx`.")
2827
# asyncio.run(git_agent())
28+
29+
# asyncio.run(customer_service_agent())
30+
# asyncio.run(research_agent())
2931
# asyncio.run(code_interpreter_agent())
3032
# asyncio.run(remote_agent())
3133
# asyncio.run(streaming_agent())
34+
# asyncio.run(streaming_guardrails_agent())
35+
asyncio.run(output_guardrails_agent())
3236

33-
# Run streaming agent 10 times
34-
async def gather_streaming_agents():
35-
tasks = [streaming_agent() for _ in range(10)]
36-
await asyncio.gather(*tasks)
37-
asyncio.run(gather_streaming_agents())
37+
## Run streaming agent 10 times
38+
# async def gather_streaming_agents():
39+
# tasks = [streaming_agent() for _ in range(10)]
40+
# await asyncio.gather(*tasks)
41+
# asyncio.run(gather_streaming_agents())

0 commit comments

Comments
 (0)