diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py
index 14344a5..5440efd 100644
--- a/src/agent_workflow_server/agents/load.py
+++ b/src/agent_workflow_server/agents/load.py
@@ -6,7 +6,7 @@
 import logging
 import os
 import pkgutil
-from typing import Any, Dict, Hashable, List, Mapping, NamedTuple
+from typing import Any, Dict, Hashable, List, Mapping, NamedTuple, Optional
 
 import agent_workflow_server.agents.adapters
 from agent_workflow_server.agents.oas_generator import generate_agent_oapi
@@ -70,7 +70,9 @@ def _read_manifest(path: str) -> AgentACPDescriptor:
     return AgentACPDescriptor(**manifest_data)
 
 
-def _resolve_agent(name: str, path: str) -> AgentInfo:
+def _resolve_agent(
+    name: str, path: str, add_manifest_paths: Optional[List[str]] = None
+) -> AgentInfo:
     if ":" not in path:
         raise ValueError(
             f"""Invalid format for AGENTS_REF environment variable. \
@@ -127,8 +129,7 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
     # Load manifest. Check in paths below (in order)
     manifest_paths = [
         os.path.join(os.path.dirname(module.__file__), "manifest.json"),
-        os.environ.get("AGENT_MANIFEST_PATH", "manifest.json") or "manifest.json",
-    ]
+    ] + (add_manifest_paths or [])
 
     for manifest_path in manifest_paths:
         manifest = _read_manifest(manifest_path)
@@ -150,18 +151,20 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
     return AgentInfo(agent=agent, manifest=manifest, schema=schema)
 
 
-def load_agents():
-    # Simulate loading the config from environment variable
-
+def load_agents(
+    agents_ref: Optional[str] = None,
+    add_manifest_paths: Optional[List[str]] = None,
+):
+    """Register the agents listed in the ``agents_ref`` JSON object."""
     try:
-        config: Dict[str, str] = json.loads(os.getenv("AGENTS_REF", {}))
+        config: Dict[str, str] = json.loads(agents_ref) if agents_ref else {}
     except json.JSONDecodeError:
         raise ValueError("""Invalid format for AGENTS_REF environment variable. \
             Must be a dictionary of agent_id -> module:var pairs. \
             Example: {"agent1": "agent1_module:agent1_var", "agent2": "agent2_module:agent2_var"}""")
     for agent_id, agent_path in config.items():
         try:
-            agent = _resolve_agent(agent_id, agent_path)
+            agent = _resolve_agent(agent_id, agent_path, add_manifest_paths)
             AGENTS[agent_id] = agent
             logger.info(f"Registered Agent: '{agent_id}'", {"agent_id": agent_id})
         except Exception as e:
diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py
index c28d9f8..0cf9b9e 100644
--- a/src/agent_workflow_server/apis/stateless_runs.py
+++ b/src/agent_workflow_server/apis/stateless_runs.py
@@ -3,7 +3,7 @@
 
 # coding: utf-8
 
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, List, Optional, Union
 
 from fastapi import (
     APIRouter,
@@ -15,6 +15,7 @@
     Response,
     status,
 )
+from fastapi.responses import StreamingResponse
 from pydantic import Field, StrictBool, StrictStr
 from typing_extensions import Annotated
 
@@ -34,6 +35,10 @@
 from agent_workflow_server.generated.models.run_wait_response_stateless import (
     RunWaitResponseStateless,
 )
+from agent_workflow_server.generated.models.stream_event_payload import (
+    StreamEventPayload,
+)
+from agent_workflow_server.generated.models.streaming_mode import StreamingMode
 from agent_workflow_server.services.runs import Runs
 from agent_workflow_server.services.validation import (
     InvalidFormatException,
@@ -51,6 +56,25 @@ async def _validate_run_create_stateless(
     if run_create_stateless.agent_id is None:
         """Pre-process the RunCreateStateless object to set the agent_id if not provided."""
         run_create_stateless.agent_id = get_default_agent().agent_id
+    if run_create_stateless.stream_mode is not None:
+        # Server only supports VALUES streaming at the moment.
+        if (
+            isinstance(run_create_stateless.stream_mode.actual_instance, List)
+            and StreamingMode.VALUES
+            not in run_create_stateless.stream_mode.actual_instance
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_501_NOT_IMPLEMENTED,
+                detail='stream mode: "values" required',
+            )
+        elif (
+            not isinstance(run_create_stateless.stream_mode.actual_instance, List)
+            and StreamingMode.VALUES != run_create_stateless.stream_mode.actual_instance
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_501_NOT_IMPLEMENTED,
+                detail='stream mode: "values" required',
+            )
     return run_create_stateless
 
 
@@ -113,6 +137,23 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
     )
 
 
+async def _stream_sse_events(
+    stream: AsyncIterator[StreamEventPayload | None],
+) -> AsyncIterator[Union[str, bytes]]:
+    """Render events as SSE frames; a ``None`` event becomes a keep-alive comment."""
+    last_event_id = 0
+    async for event in stream:
+        if event is None:
+            yield ":\n\n"
+        else:
+            last_event_id += 1
+            yield f"""id: {last_event_id}
+event: agent_event
+data: {event.to_json()}
+
+"""
+
+
 @router.post(
     "/runs/{run_id}/cancel",
     responses={
@@ -165,7 +206,21 @@ async def create_and_stream_stateless_run_output(
     ] = Body(None, description=""),
 ) -> RunOutputStream:
     """Create a stateless run and join its output stream. See 'GET /runs/{run_id}/stream' for details on the return values."""
-    raise HTTPException(status_code=500, detail="Not implemented")
+    try:
+        new_run = await Runs.put(run_create_stateless)
+        return StreamingResponse(
+            _stream_sse_events(Runs.stream_events(new_run.run_id)),
+            media_type="text/event-stream",
+        )
+    except HTTPException:
+        raise
+    except TimeoutError:
+        return Response(status_code=status.HTTP_204_NO_CONTENT)
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Run create error",
+        )
 
 
 @router.post(
@@ -326,7 +381,26 @@ async def stream_stateless_run_output(
     ),
 ) -> RunOutputStream:
     """Join the output stream of an existing run. This endpoint streams output in real-time from a run. Only output produced after this endpoint is called will be streamed."""
-    raise HTTPException(status_code=500, detail="Not implemented")
+    try:
+        run = Runs.get(run_id)
+        if run is None:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail=f"Run with ID {run_id} not found",
+            )
+        return StreamingResponse(
+            _stream_sse_events(Runs.stream_events(run_id)),
+            media_type="text/event-stream",
+        )
+    except HTTPException:
+        raise
+    except TimeoutError:
+        return Response(status_code=status.HTTP_204_NO_CONTENT)
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Run with ID {run_id} error",
+        )
 
 
 @router.get(
diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py
index 5391fd3..826185d 100644
--- a/src/agent_workflow_server/main.py
+++ b/src/agent_workflow_server/main.py
@@ -1,6 +1,5 @@
 # Copyright AGNTCY Contributors (https://github.com/agntcy)
 # SPDX-License-Identifier: Apache-2.0
-
 import asyncio
 import logging
 import os
@@ -8,11 +7,9 @@
 import sys
 
 import uvicorn
-import uvicorn.logging
-from dotenv import load_dotenv
+from dotenv import find_dotenv, load_dotenv
 from fastapi import Depends, FastAPI
 
-import agent_workflow_server.logging.logger  # noqa: F401
 from agent_workflow_server.agents.load import load_agents
 from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter
 from agent_workflow_server.apis.agents import router as AgentsApiRouter
@@ -23,10 +20,12 @@
 from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter
 from agent_workflow_server.services.queue import start_workers
 
-load_dotenv()
+load_dotenv(dotenv_path=find_dotenv(usecwd=True))
 
 DEFAULT_HOST = "127.0.0.1"
 DEFAULT_PORT = 8000
+DEFAULT_NUM_WORKERS = 5
+DEFAULT_AGENT_MANIFEST_PATH = "manifest.json"
 
 logger = logging.getLogger(__name__)
 
@@ -59,14 +58,20 @@ def start():
     try:
         signal.signal(signal.SIGINT, signal_handler)
         signal.signal(signal.SIGTERM, signal_handler)
-        load_agents()
-        n_workers = int(os.environ.get("NUM_WORKERS", 5))
+
+        agents_ref = os.getenv("AGENTS_REF", None)
+        agent_manifest_path = os.getenv(
+            "AGENT_MANIFEST_PATH", DEFAULT_AGENT_MANIFEST_PATH
+        )
+        load_agents(agents_ref, [agent_manifest_path])
+        n_workers = int(os.environ.get("NUM_WORKERS", DEFAULT_NUM_WORKERS))
 
         loop = asyncio.get_event_loop()
         loop.create_task(start_workers(n_workers))
 
+        # use module import method to support reload argument
         config = uvicorn.Config(
-            app,
+            "agent_workflow_server.main:app",
             host=os.getenv("API_HOST", DEFAULT_HOST) or DEFAULT_HOST,
             port=int(os.getenv("API_PORT", DEFAULT_PORT)) or DEFAULT_PORT,
             loop="asyncio",
diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py
index 75075f4..a7daa8c 100644
--- a/src/agent_workflow_server/services/queue.py
+++ b/src/agent_workflow_server/services/queue.py
@@ -101,7 +101,6 @@ async def worker(worker_id: int):
                 last_message = None
                 async for message in stream:
                     message.data = make_serializable(message.data)
-                    await Runs.Stream.publish(run_id, message)
                     last_message = message
                     if last_message.type == "interrupt":
                         log_run(
@@ -110,7 +109,10 @@ async def worker(worker_id: int):
                             "interrupted",
                             message_data=json.dumps(message.data),
                         )
+                        await Runs.Stream.publish(run_id, message)
                         break
+                    else:
+                        await Runs.Stream.publish(run_id, message)
             except Exception as error:
                 await Runs.Stream.publish(
                     run_id, Message(type="message", data=(str(error)))
@@ -134,7 +136,6 @@ async def worker(worker_id: int):
                 )
                 validate_output(run_id, run["agent_id"], last_message.data)
                 DB.add_run_output(run_id, last_message.data)
-                await Runs.Stream.publish(run_id, Message(type="control", data="done"))
                 if last_message.type == "interrupt":
                     interrupt = Interrupt(
                         event=last_message.event, ai_data=last_message.data
@@ -144,12 +145,13 @@ async def worker(worker_id: int):
                 else:
                     await Runs.set_status(run_id, "success")
                     log_run(worker_id, run_id, "succeeded", **run_stats(run_info))
+                await Runs.Stream.publish(run_id, Message(type="control", data="done"))
 
             except InvalidFormatException as error:
+                log_run(worker_id, run_id, "failed")
                 await Runs.Stream.publish(
                     run_id, Message(type="message", data=str(error))
                 )
-                log_run(worker_id, run_id, "failed")
                 raise RunError(str(error))
 
             except AttemptsExceededError:
diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py
index d970b34..56ec23d 100644
--- a/src/agent_workflow_server/services/runs.py
+++ b/src/agent_workflow_server/services/runs.py
@@ -6,7 +6,7 @@
 from collections import defaultdict
 from datetime import datetime
 from itertools import islice
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
 from uuid import uuid4
 
 from agent_workflow_server.generated.models.run_create_stateless import (
@@ -18,6 +18,12 @@
 from agent_workflow_server.generated.models.run_stateless import (
     RunStateless as ApiRun,
 )
+from agent_workflow_server.generated.models.stream_event_payload import (
+    StreamEventPayload,
+)
+from agent_workflow_server.generated.models.value_run_result_update import (
+    ValueRunResultUpdate,
+)
 from agent_workflow_server.storage.models import Run, RunInfo, RunStatus
 from agent_workflow_server.storage.storage import DB
 
@@ -247,6 +253,38 @@ async def wait_for_output(run_id: str, timeout: float = None):
 
         return None, None
 
+    @staticmethod
+    async def stream_events(run_id: str) -> AsyncIterator[StreamEventPayload | None]:
+        """Yield one "values" update per message; yield ``None`` on a timeout."""
+        async for message in Runs.Stream.join(run_id):
+            msg_data = message.data
+
+            if message.type == "control":
+                if message.data == "done":
+                    break
+                elif message.data == "timeout":
+                    yield None
+                    continue
+                else:
+                    logger.error(
+                        f'received unknown control message "{message.data}" in stream events for run: {run_id}'
+                    )
+                    continue
+
+            # We need to get the latest value to return
+            run = DB.get_run(run_id)
+            if run is None:
+                raise ValueError(f"Run {run_id} not found")
+
+            yield StreamEventPayload(
+                ValueRunResultUpdate(
+                    type="values",
+                    run_id=run["run_id"],
+                    status=run["status"],
+                    values=msg_data,
+                )
+            )
+
     class Stream:
         @staticmethod
         async def publish(run_id: str, message: Message) -> None:
@@ -263,11 +301,22 @@ async def join(
             run_id: str,
         ) -> AsyncGenerator[Message, None]:
             queue = await Runs.Stream.subscribe(run_id)
+
+            # Check after subscribe whether the run is completed to
+            # avoid race condition.
+            run = DB.get_run(run_id)
+            if run is None:
+                raise ValueError(f"Run {run_id} not found")
+            if run["status"] != "pending" and queue.empty():
+                return
+
             while True:
                 try:
-                    message: Message = await asyncio.wait_for(queue.get(), timeout=1)
-                    if message.topic == "control" and message.data == "done":
-                        break
+                    message: Message = await asyncio.wait_for(queue.get(), timeout=10)
                     yield message
-                except TimeoutError as error:
+                    if message.type == "control" and message.data == "done":
+                        break
+                # asyncio.TimeoutError is not the builtin TimeoutError before 3.11.
+                except asyncio.TimeoutError as error:
                     logger.error(f"Timeout waiting for run {run_id}: {error}")
+                    yield Message(type="control", data="timeout")
diff --git a/tests/test_load.py b/tests/test_load.py
index 38a60dd..f8f6778 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -1,5 +1,6 @@
 # Copyright AGNTCY Contributors (https://github.com/agntcy)
 # SPDX-License-Identifier: Apache-2.0
+import os
 
 import pytest
 from pytest_mock import MockerFixture
@@ -21,10 +22,16 @@
 )
 
 
+def _env_load_agents():
+    agent_ref = os.getenv("AGENTS_REF")
+    manifest_path = os.getenv("AGENT_MANIFEST_PATH", "manifest.json")
+    load_agents(agents_ref=agent_ref, add_manifest_paths=[manifest_path])
+
+
 def test_load_agents(mocker: MockerFixture):
     mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()])
 
-    load_agents()
+    _env_load_agents()
 
     assert len(AGENTS) == 1
     assert isinstance(AGENTS[MOCK_AGENT_ID].agent, MockAgent)
@@ -37,7 +44,7 @@ def test_load_agents(mocker: MockerFixture):
 def test_get_agent_info(mocker: MockerFixture, agent_id: str, expected: bool):
     mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()])
 
-    load_agents()
+    _env_load_agents()
 
     assert len(AGENTS) == 1
 
@@ -55,7 +62,7 @@ def test_get_agent_info(mocker: MockerFixture, agent_id: str, expected: bool):
 def test_get_agent(mocker: MockerFixture, agent_id: str, expected: bool):
     mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()])
 
-    load_agents()
+    _env_load_agents()
 
     assert len(AGENTS) == 1
 
@@ -87,7 +94,7 @@ def test_search_agents(
 ):
     mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()])
 
-    load_agents()
+    _env_load_agents()
 
     assert len(AGENTS) == 1
 