Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -576,15 +576,19 @@ async def parse_agent_output(
output.tool_calls.extend(cur_tool_calls) # type: ignore
await ctx.store.set("current_tool_calls", [])

if self.structured_output_fn is not None:
structured_output_fn = self.structured_output_fn
output_cls = self.output_cls
if output_cls is None and structured_output_fn is None:
structured_output_fn = agent.structured_output_fn
output_cls = agent.output_cls

if structured_output_fn is not None:
try:
if inspect.iscoroutinefunction(self.structured_output_fn):
output.structured_response = await self.structured_output_fn(
messages
)
if inspect.iscoroutinefunction(structured_output_fn):
output.structured_response = await structured_output_fn(messages)
else:
output.structured_response = cast(
Dict[str, Any], self.structured_output_fn(messages)
Dict[str, Any], structured_output_fn(messages)
)
ctx.write_event_to_stream(
AgentStreamStructuredOutput(output=output.structured_response)
Expand All @@ -594,7 +598,7 @@ async def parse_agent_output(
f"There was a problem with the generation of the structured output: {e}",
stacklevel=2,
)
if self.output_cls is not None:
if output_cls is not None:
try:
llm_input = [*messages]
if agent.system_prompt:
Expand All @@ -603,7 +607,7 @@ async def parse_agent_output(
*llm_input,
]
output.structured_response = await generate_structured_response(
messages=llm_input, llm=agent.llm, output_cls=self.output_cls
messages=llm_input, llm=agent.llm, output_cls=output_cls
)
ctx.write_event_to_stream(
AgentStreamStructuredOutput(output=output.structured_response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,26 @@ async def test_structured_output_fn_agentworkflow(
assert response.get_pydantic_model(Structure) == Structure(hello="bonjour", world=2)


@pytest.mark.asyncio
async def test_agent_structured_output_fn_agentworkflow(
function_agent_struct_fn: FunctionAgent,
) -> None:
wf = AgentWorkflow(
agents=[function_agent_struct_fn],
root_agent=function_agent_struct_fn.name,
)
handler = wf.run(user_msg="test")
streaming_event = False
async for event in handler.stream_events():
if isinstance(event, AgentStreamStructuredOutput):
streaming_event = True

response = await handler
assert streaming_event
assert "Success with the FunctionAgent" in str(response.response)
assert response.get_pydantic_model(Structure) == Structure(hello="bonjour", world=2)


@pytest.mark.asyncio
async def test_astructured_output_fn_agentworkflow(
function_agent_output_cls: FunctionAgent,
Expand Down