Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
# ---------------------------------------------------------
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

from agent_framework import AgentProtocol, WorkflowBuilder

from azure.ai.agentserver.agentframework._version import VERSION
from azure.ai.agentserver.agentframework._agent_framework import AgentFrameworkCBAgent
from azure.ai.agentserver.agentframework._ai_agent_adapter import AgentFrameworkAIAgentAdapter
from azure.ai.agentserver.agentframework._workflow_agent_adapter import AgentFrameworkWorkflowAdapter
from azure.ai.agentserver.agentframework._foundry_tools import FoundryToolsChatMiddleware
from azure.ai.agentserver.core.application import PackageMetadata, set_current_app

Expand All @@ -15,12 +19,16 @@


def from_agent_framework(
agent,
agent: Union[AgentProtocol, WorkflowBuilder],
credentials: Optional["AsyncTokenCredential"] = None,
**kwargs: Any,
) -> "AgentFrameworkCBAgent":

return AgentFrameworkCBAgent(agent, credentials=credentials, **kwargs)
if isinstance(agent, WorkflowBuilder):
return AgentFrameworkWorkflowAdapter(workflow_builder=agent, credentials=credentials, **kwargs)
if isinstance(agent, AgentProtocol):
return AgentFrameworkAIAgentAdapter(agent, credentials=credentials, **kwargs)
raise TypeError("agent must be an instance of AgentProtocol or WorkflowBuilder")


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import os
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List

from agent_framework import AgentProtocol, AIFunction, CheckpointStorage, InMemoryCheckpointStorage, WorkflowCheckpoint
from agent_framework import AgentProtocol, AIFunction, AgentThread, WorkflowAgent
from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module
from agent_framework._workflows import get_checkpoint_summary
from opentelemetry import trace

from azure.ai.agentserver.core.tools import OAuthConsentRequiredError
Expand Down Expand Up @@ -75,9 +74,6 @@ class AgentFrameworkCBAgent(FoundryCBAgent):

def __init__(self, agent: AgentProtocol,
credentials: "Optional[AsyncTokenCredential]" = None,
*,
thread_repository: AgentThreadRepository = None,
checkpoint_repository: CheckpointRepository = None,
**kwargs: Any,
):
"""Initialize the AgentFrameworkCBAgent with an AgentProtocol or a factory function.
Expand All @@ -93,8 +89,6 @@ def __init__(self, agent: AgentProtocol,
super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg
self._agent: AgentProtocol = agent
self._hitl_helper = HumanInTheLoopHelper()
self._checkpoint_repository = checkpoint_repository
self._thread_repository = thread_repository

@property
def agent(self) -> "AgentProtocol":
Expand Down Expand Up @@ -229,153 +223,35 @@ async def agent_run( # pylint: disable=too-many-statements
OpenAIResponse,
AsyncGenerator[ResponseStreamEvent, Any],
]:
try:
logger.info(f"Starting agent_run with stream={context.stream}")
request_input = context.request.get("input")

agent_thread = None
checkpoint_storage = None
last_checkpoint = None
if self._thread_repository:
agent_thread = await self._thread_repository.get(context.conversation_id)
if agent_thread:
logger.info(f"Loaded agent thread for conversation: {context.conversation_id}")
else:
agent_thread = self.agent.get_new_thread()

if self._checkpoint_repository:
checkpoint_storage = await self._checkpoint_repository.get_or_create(context.conversation_id)
last_checkpoint = await self._get_latest_checkpoint(checkpoint_storage)
if last_checkpoint:
summary = get_checkpoint_summary(last_checkpoint)
if summary.status == "completed":
logger.warning("Last checkpoint is completed. Will not resume from it.")
last_checkpoint = None # Do not resume from completed checkpoints
if last_checkpoint:
await self._load_checkpoint(self.agent, last_checkpoint, checkpoint_storage)
logger.info(f"Loaded checkpoint with ID: {last_checkpoint.checkpoint_id}")

input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper)
message = await input_converter.transform_input(
request_input,
agent_thread=agent_thread,
checkpoint=last_checkpoint)
logger.debug(f"Transformed input message type: {type(message)}")

# Use split converters
if context.stream:
logger.info("Running agent in streaming mode")
streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper)

async def stream_updates():
try:
update_count = 0
try:
updates = self.agent.run_stream(
message,
thread=agent_thread,
checkpoint_storage=checkpoint_storage,
)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event

if agent_thread and self._thread_repository:
await self._thread_repository.set(context.conversation_id, agent_thread, checkpoint_storage)
logger.info(f"Saved agent thread for conversation: {context.conversation_id}")

logger.info("Streaming completed with %d updates", update_count)
except OAuthConsentRequiredError as e:
logger.info("OAuth consent required during streaming updates")
if update_count == 0:
async for event in self.respond_with_oauth_consent_astream(context, e):
yield event
else:
# If we've already emitted events, we cannot safely restart a new
# OAuth-consent stream (it would reset sequence numbers).
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=f"OAuth consent required: {e.consent_url}",
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True)

# Emit well-formed error events instead of terminating the stream.
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=str(e),
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
finally:
# No request-scoped resources to clean up here today.
# Keep this block as a hook for future request-scoped cleanup.
pass

return stream_updates()

# Non-streaming path
logger.info("Running agent in non-streaming mode")
non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context)
result = await self.agent.run(
message,
thread=agent_thread,
checkpoint_storage=checkpoint_storage)
logger.debug(f"Agent run completed, result type: {type(result)}")

if agent_thread and self._thread_repository:
await self._thread_repository.set(context.conversation_id, agent_thread)
logger.info(f"Saved agent thread for conversation: {context.conversation_id}")

transformed_result = non_streaming_converter.transform_output_for_response(result)
logger.info("Agent run and transformation completed successfully")
return transformed_result
except OAuthConsentRequiredError as e:
logger.info("OAuth consent required during agent run")
if context.stream:
# Yield OAuth consent response events
# Capture e in the closure by passing it as a default argument
async def oauth_consent_stream(error=e):
async for event in self.respond_with_oauth_consent_astream(context, error):
yield event
return oauth_consent_stream()
return await self.respond_with_oauth_consent(context, e)
finally:
pass

async def _get_latest_checkpoint(self,
checkpoint_storage: CheckpointStorage) -> Optional[Any]:
"""Load the latest checkpoint from the given storage.

:param checkpoint_storage: The checkpoint storage to load from.
:type checkpoint_storage: CheckpointStorage

:return: The latest checkpoint if available, None otherwise.
:rtype: Optional[Any]
raise NotImplementedError("This method is implemented in the base class.")

async def _load_agent_thread(self, context: AgentRunContext, agent: Union[AgentProtocol, WorkflowAgent]) -> Optional[AgentThread]:
"""Load the agent thread for a given conversation ID.

:param context: The agent run context.
:type context: AgentRunContext
:param agent: The agent instance.
:type agent: AgentProtocol | WorkflowAgent

:return: The loaded AgentThread if available, None otherwise.
:rtype: Optional[AgentThread]
"""
checkpoints = await checkpoint_storage.list_checkpoints()
if checkpoints:
latest_checkpoint = max(checkpoints, key=lambda cp: cp.timestamp)
return latest_checkpoint
if self._thread_repository:
agent_thread = await self._thread_repository.get(context.conversation_id)
if agent_thread:
logger.info(f"Loaded agent thread for conversation: {context.conversation_id}")
return agent_thread
return agent.get_new_thread()
return None

async def _load_checkpoint(self, agent: AgentProtocol,
checkpoint: WorkflowCheckpoint,
checkpoint_storage: CheckpointStorage) -> None:
"""Load the checkpoint data from the given WorkflowCheckpoint.
async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentThread) -> None:
"""Save the agent thread for a given conversation ID.

:param checkpoint: The WorkflowCheckpoint to load data from.
:type checkpoint: WorkflowCheckpoint
:param context: The agent run context.
:type context: AgentRunContext
:param agent_thread: The agent thread to save.
:type agent_thread: AgentThread
"""
await agent.run(checkpoint_id=checkpoint.checkpoint_id,
checkpoint_storage=checkpoint_storage)
if agent_thread and self._thread_repository:
await self._thread_repository.set(context.conversation_id, agent_thread)
logger.info(f"Saved agent thread for conversation: {context.conversation_id}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=logging-fstring-interpolation,no-name-in-module,no-member,do-not-import-asyncio
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Union

from agent_framework import AgentProtocol

from azure.ai.agentserver.core import AgentRunContext
from azure.ai.agentserver.core.tools import OAuthConsentRequiredError
from azure.ai.agentserver.core.logger import get_logger
from azure.ai.agentserver.core.models import (
Response as OpenAIResponse,
ResponseStreamEvent,
)
from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent

from .models.agent_framework_input_converters import AgentFrameworkInputConverter
from .models.agent_framework_output_non_streaming_converter import (
AgentFrameworkOutputNonStreamingConverter,
)
from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter
from ._agent_framework import AgentFrameworkCBAgent
from .persistence import AgentThreadRepository

logger = get_logger()

class AgentFrameworkAIAgentAdapter(AgentFrameworkCBAgent):
def __init__(self, agent: AgentProtocol,
*,
thread_repository: Optional[AgentThreadRepository]=None,
**kwargs) -> None:
super().__init__(agent=agent, **kwargs)
self._agent = agent
self._thread_repository = thread_repository

async def agent_run( # pylint: disable=too-many-statements
self, context: AgentRunContext
) -> Union[
OpenAIResponse,
AsyncGenerator[ResponseStreamEvent, Any],
]:
try:
logger.info(f"Starting agent_run with stream={context.stream}")
request_input = context.request.get("input")

agent_thread = self._load_agent_thread(context, self._agent)

input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper)
message = await input_converter.transform_input(
request_input,
agent_thread=agent_thread)
logger.debug(f"Transformed input message type: {type(message)}")

# Use split converters
if context.stream:
logger.info("Running agent in streaming mode")
streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper)

async def stream_updates():
try:
update_count = 0
try:
updates = self.agent.run_stream(
message,
thread=agent_thread,
)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event

await self._save_agent_thread(context, agent_thread)

logger.info("Streaming completed with %d updates", update_count)
except OAuthConsentRequiredError as e:
logger.info("OAuth consent required during streaming updates")
if update_count == 0:
async for event in self.respond_with_oauth_consent_astream(context, e):
yield event
else:
# If we've already emitted events, we cannot safely restart a new
# OAuth-consent stream (it would reset sequence numbers).
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=f"OAuth consent required: {e.consent_url}",
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True)

# Emit well-formed error events instead of terminating the stream.
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=str(e),
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
finally:
# No request-scoped resources to clean up here today.
# Keep this block as a hook for future request-scoped cleanup.
pass

return stream_updates()

# Non-streaming path
logger.info("Running agent in non-streaming mode")
result = await self.agent.run(
message,
thread=agent_thread)
logger.debug(f"Agent run completed, result type: {type(result)}")
await self._save_agent_thread(context, agent_thread)

non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper)
transformed_result = non_streaming_converter.transform_output_for_response(result)
logger.info("Agent run and transformation completed successfully")
return transformed_result
except OAuthConsentRequiredError as e:
logger.info("OAuth consent required during agent run")
if context.stream:
# Yield OAuth consent response events
# Capture e in the closure by passing it as a default argument
async def oauth_consent_stream(error=e):
async for event in self.respond_with_oauth_consent_astream(context, error):
yield event
return oauth_consent_stream()
return await self.respond_with_oauth_consent(context, e)
finally:
pass
Loading