Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -576,15 +576,21 @@ async def parse_agent_output(
output.tool_calls.extend(cur_tool_calls) # type: ignore
await ctx.store.set("current_tool_calls", [])

if self.structured_output_fn is not None:
structured_output_fn = self.structured_output_fn
output_cls = self.output_cls
if structured_output_fn is None and output_cls is None:
structured_output_fn = agent.structured_output_fn
output_cls = agent.output_cls

if structured_output_fn is not None:
try:
if inspect.iscoroutinefunction(self.structured_output_fn):
output.structured_response = await self.structured_output_fn(
if inspect.iscoroutinefunction(structured_output_fn):
output.structured_response = await structured_output_fn(
messages
)
else:
output.structured_response = cast(
Dict[str, Any], self.structured_output_fn(messages)
Dict[str, Any], structured_output_fn(messages)
)
ctx.write_event_to_stream(
AgentStreamStructuredOutput(output=output.structured_response)
Expand All @@ -594,7 +600,7 @@ async def parse_agent_output(
f"There was a problem with the generation of the structured output: {e}",
stacklevel=2,
)
if self.output_cls is not None:
if output_cls is not None:
try:
llm_input = [*messages]
if agent.system_prompt:
Expand All @@ -603,7 +609,7 @@ async def parse_agent_output(
*llm_input,
]
output.structured_response = await generate_structured_response(
messages=llm_input, llm=agent.llm, output_cls=self.output_cls
messages=llm_input, llm=agent.llm, output_cls=output_cls
)
ctx.write_event_to_stream(
AgentStreamStructuredOutput(output=output.structured_response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,26 @@ async def test_structured_output_agentworkflow(
assert response.get_pydantic_model(Structure) == Structure(hello="hello", world=1)


@pytest.mark.asyncio
async def test_agentworkflow_uses_agent_output_cls(
function_agent_output_cls: FunctionAgent,
) -> None:
wf = AgentWorkflow(
agents=[function_agent_output_cls],
root_agent=function_agent_output_cls.name,
)
handler = wf.run(user_msg="test")
streaming_event = False
async for event in handler.stream_events():
if isinstance(event, AgentStreamStructuredOutput):
streaming_event = True

response = await handler
assert streaming_event
assert "Success with the FunctionAgent" in str(response.response)
assert response.get_pydantic_model(Structure) == Structure(hello="hello", world=1)


@pytest.mark.asyncio
async def test_structured_output_fn_agentworkflow(
function_agent_output_cls: FunctionAgent,
Expand All @@ -265,6 +285,26 @@ async def test_structured_output_fn_agentworkflow(
assert response.get_pydantic_model(Structure) == Structure(hello="bonjour", world=2)


@pytest.mark.asyncio
async def test_agentworkflow_uses_agent_structured_output_fn(
function_agent_struct_fn: FunctionAgent,
) -> None:
wf = AgentWorkflow(
agents=[function_agent_struct_fn],
root_agent=function_agent_struct_fn.name,
)
handler = wf.run(user_msg="test")
streaming_event = False
async for event in handler.stream_events():
if isinstance(event, AgentStreamStructuredOutput):
streaming_event = True

response = await handler
assert streaming_event
assert "Success with the FunctionAgent" in str(response.response)
assert response.get_pydantic_model(Structure) == Structure(hello="bonjour", world=2)


@pytest.mark.asyncio
async def test_astructured_output_fn_agentworkflow(
function_agent_output_cls: FunctionAgent,
Expand Down