From c7fdb06612a567a1ea43dc880800dfd1e3a3ecf6 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:28:50 +0200 Subject: [PATCH 1/8] feat: add streaming support. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- src/agent_workflow_server/agents/load.py | 13 ++-- .../apis/stateless_runs.py | 18 ++++- src/agent_workflow_server/main.py | 67 ++++++++++++------- src/agent_workflow_server/services/runs.py | 34 +++++++++- 4 files changed, 96 insertions(+), 36 deletions(-) diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index 14344a5..a528f8a 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -6,7 +6,7 @@ import logging import os import pkgutil -from typing import Any, Dict, Hashable, List, Mapping, NamedTuple +from typing import Any, Dict, Hashable, List, Mapping, NamedTuple, Optional import agent_workflow_server.agents.adapters from agent_workflow_server.agents.oas_generator import generate_agent_oapi @@ -70,7 +70,7 @@ def _read_manifest(path: str) -> AgentACPDescriptor: return AgentACPDescriptor(**manifest_data) -def _resolve_agent(name: str, path: str) -> AgentInfo: +def _resolve_agent(name: str, path: str, add_manifest_paths: List[str] = []) -> AgentInfo: if ":" not in path: raise ValueError( f"""Invalid format for AGENTS_REF environment variable. \ @@ -127,8 +127,7 @@ def _resolve_agent(name: str, path: str) -> AgentInfo: # Load manifest. Check in paths below (in order) manifest_paths = [ os.path.join(os.path.dirname(module.__file__), "manifest.json"), - os.environ.get("AGENT_MANIFEST_PATH", "manifest.json") or "manifest.json", - ] + ] + add_manifest_paths for manifest_path in manifest_paths: manifest = _read_manifest(manifest_path) @@ -150,18 +149,18 @@ def _resolve_agent(name: str, path: str) -> AgentInfo: return AgentInfo(agent=agent, manifest=manifest, schema=schema) -def load_agents(): +def load_agents(agents_ref: Optional[str] = None, add_manifest_paths: List[str] = []): # Simulate loading the config from environment variable try: - config: Dict[str, str] = json.loads(os.getenv("AGENTS_REF", {})) + config: Dict[str, str] = json.loads(agents_ref) if agents_ref else {} except json.JSONDecodeError: raise ValueError("""Invalid format for AGENTS_REF environment variable. \ Must be a dictionary of agent_id -> module:var pairs. \ Example: {"agent1": "agent1_module:agent1_var", "agent2": "agent2_module:agent2_var"}""") for agent_id, agent_path in config.items(): try: - agent = _resolve_agent(agent_id, agent_path) + agent = _resolve_agent(agent_id, agent_path, add_manifest_paths) AGENTS[agent_id] = agent logger.info(f"Registered Agent: '{agent_id}'", {"agent_id": agent_id}) except Exception as e: diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index c28d9f8..ef22782 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -15,6 +15,7 @@ Response, status, ) +from fastapi.responses import StreamingResponse from pydantic import Field, StrictBool, StrictStr from typing_extensions import Annotated @@ -165,7 +166,14 @@ async def create_and_stream_stateless_run_output( ] = Body(None, description=""), ) -> RunOutputStream: """Create a stateless run and join its output stream. See 'GET /runs/{run_id}/stream' for details on the return values.""" - raise HTTPException(status_code=500, detail="Not implemented") + try: + new_run = await Runs.put(run_create_stateless) + return StreamingResponse(Runs.stream_events(new_run.run_id), media_type="text/event-stream") + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, # FIXME: use exception type to signal status + detail=f"Run create stream error: {str(exc)}", + ) @router.post( @@ -326,7 +334,13 @@ async def stream_stateless_run_output( ), ) -> RunOutputStream: """Join the output stream of an existing run. This endpoint streams output in real-time from a run. Only output produced after this endpoint is called will be streamed.""" - raise HTTPException(status_code=500, detail="Not implemented") + try: + return StreamingResponse(Runs.stream_events(run_id), media_type="text/event-stream") + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, # FIXME: use exception type to signal status + detail=f"Run with ID {run_id} error: {str(exc)}", + ) @router.get( diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index 5391fd3..d54f939 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -1,9 +1,11 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 - +import argparse import asyncio +import json import logging import os +import pathlib import signal import sys @@ -30,46 +32,61 @@ logger = logging.getLogger(__name__) -app = FastAPI( - title="Agent Workflow Server", - version="0.1", -) - -setup_api_key_auth(app) - -app.include_router( - router=AgentsApiRouter, - dependencies=[Depends(authentication_with_api_key)], -) -app.include_router( - router=PublicAgentsApiRouter, -) -app.include_router( - router=StatelessRunsApiRouter, - dependencies=[Depends(authentication_with_api_key)], -) - +def create_app() -> FastAPI: + app = FastAPI( + title="Agent Workflow Server", + version="0.1", + ) + + setup_api_key_auth(app) + + app.include_router( + router=AgentsApiRouter, + dependencies=[Depends(authentication_with_api_key)], + ) + app.include_router( + router=PublicAgentsApiRouter, + ) + app.include_router( + router=StatelessRunsApiRouter, + dependencies=[Depends(authentication_with_api_key)], + ) + return app def signal_handler(sig, frame): logger.warning(f"Received {signal.Signals(sig).name}. Exiting...") sys.exit(0) +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Agent Workflow Server") + parser.add_argument('--host', default=os.getenv("API_HOST", DEFAULT_HOST)) + parser.add_argument('--port', type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT))) + parser.add_argument('--num-workers', type=int, default=int(os.environ.get("NUM_WORKERS", 5))) + parser.add_argument('--agent-manifest-path', action='append', type=pathlib.Path, default=[os.getenv('AGENT_MANIFEST_PATH', "manifest.json")]) + parser.add_argument('--agents-ref', default=os.getenv('AGENTS_REF', None)) + parser.add_argument('--reload', action='store_true') + #parser.add_argument('--log-level', default=logging.DEBUG) + return parser.parse_args() def start(): try: + args = parse_args() + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - load_agents() - n_workers = int(os.environ.get("NUM_WORKERS", 5)) + + load_agents(args.agents_ref, args.agent_manifest_path) + n_workers = args.num_workers loop = asyncio.get_event_loop() loop.create_task(start_workers(n_workers)) config = uvicorn.Config( - app, - host=os.getenv("API_HOST", DEFAULT_HOST) or DEFAULT_HOST, - port=int(os.getenv("API_PORT", DEFAULT_PORT)) or DEFAULT_PORT, + create_app(), + host=args.host, + port=args.port, loop="asyncio", + reload=args.reload, ) server = uvicorn.Server(config) loop.run_until_complete(server.serve()) diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index d970b34..fde593b 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -6,9 +6,15 @@ from collections import defaultdict from datetime import datetime from itertools import islice -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, AsyncIterator, Union from uuid import uuid4 +from agent_workflow_server.generated.models.run_output_stream import ( + RunOutputStream, +) +from agent_workflow_server.generated.models.value_run_result_update import ( + ValueRunResultUpdate, +) from agent_workflow_server.generated.models.run_create_stateless import ( RunCreateStateless as ApiRunCreate, ) @@ -247,6 +253,26 @@ async def wait_for_output(run_id: str, timeout: float = None): return None, None + @staticmethod + async def stream_events(run_id: str) -> AsyncIterator[Union[str,bytes]]: + async for message in Runs.Stream.join(run_id): + if message.type == "control": + continue + run = DB.get_run(run_id) + if run is None: + raise ValueError(f"Run {run_id} not found") + ret_obj = RunOutputStream( + id=str(uuid4()), # fresh event id + event="agent_event", + data={ + "type": "values", + "run_id": run["run_id"], + "status": run["status"], + "values": message.data, + }, + ) + yield ret_obj.model_dump_json(exclude_none=True) + class Stream: @staticmethod async def publish(run_id: str, message: Message) -> None: @@ -262,11 +288,15 @@ async def subscribe(run_id: str) -> asyncio.Queue: async def join( run_id: str, ) -> AsyncGenerator[Message, None]: + run = DB.get_run(run_id) + if run is None: + raise ValueError(f"Run {run_id} not found") + queue = await Runs.Stream.subscribe(run_id) while True: try: message: Message = await asyncio.wait_for(queue.get(), timeout=1) - if message.topic == "control" and message.data == "done": + if message.type == "control" and message.data == "done": break yield message except TimeoutError as error: From 33ce6647ce3f7f6b667efb59cd05fa0561b69940 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:37:18 +0200 Subject: [PATCH 2/8] chore: updated streaming support Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- .../agents/adapters/langgraph.py | 5 ++ .../apis/stateless_runs.py | 45 +++++++++++++--- src/agent_workflow_server/main.py | 51 +++++++++---------- src/agent_workflow_server/services/queue.py | 7 +-- src/agent_workflow_server/services/runs.py | 48 +++++++++++------ 5 files changed, 103 insertions(+), 53 deletions(-) diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py index e90f7d1..1bd139f 100644 --- a/src/agent_workflow_server/agents/adapters/langgraph.py +++ b/src/agent_workflow_server/agents/adapters/langgraph.py @@ -8,6 +8,7 @@ from langgraph.graph.graph import CompiledGraph, Graph from langgraph.graph.state import CompiledStateGraph from langgraph.types import Command +from langgraph.graph import StateGraph from agent_workflow_server.agents.base import BaseAdapter, BaseAgent from agent_workflow_server.services.message import Message @@ -47,6 +48,10 @@ async def astream(self, run: Run): if "interrupt" in run and "user_data" in run["interrupt"]: input = Command(resume=run["interrupt"]["user_data"]) + # If we have a StateGraph, we can validate the schema + if isinstance(self.agent.builder, StateGraph): + input = self.agent.builder.schema.model_validate(input) + async for event in self.agent.astream( input=input, config=RunnableConfig( diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index ef22782..1c84c75 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -3,7 +3,7 @@ # coding: utf-8 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, AsyncIterator, Union from fastapi import ( APIRouter, @@ -36,6 +36,7 @@ RunWaitResponseStateless, ) from agent_workflow_server.services.runs import Runs +from agent_workflow_server.generated.models.stream_event_payload import StreamEventPayload from agent_workflow_server.services.validation import ( InvalidFormatException, ) @@ -114,6 +115,20 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless: ) +async def __stream_sse_events(stream: AsyncIterator[StreamEventPayload | None]) -> AsyncIterator[Union[str,bytes]]: + last_event_id = 0 + async for event in stream: + if event is None: + yield ":" + else: + last_event_id += 1 + yield f"""id: {last_event_id} +event: agent_event +data: {event.to_json()} + +""" + + @router.post( "/runs/{run_id}/cancel", responses={ @@ -168,11 +183,15 @@ async def create_and_stream_stateless_run_output( """Create a stateless run and join its output stream. See 'GET /runs/{run_id}/stream' for details on the return values.""" try: new_run = await Runs.put(run_create_stateless) - return StreamingResponse(Runs.stream_events(new_run.run_id), media_type="text/event-stream") + return StreamingResponse(__stream_sse_events(Runs.stream_events(new_run.run_id)), media_type="text/event-stream") + except HTTPException: + raise + except TimeoutError as terr: + return Response(status_code=status.HTTP_204_NO_CONTENT) except Exception as exc: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, # FIXME: use exception type to signal status - detail=f"Run create stream error: {str(exc)}", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Run create error", ) @@ -335,11 +354,21 @@ async def stream_stateless_run_output( ) -> RunOutputStream: """Join the output stream of an existing run. This endpoint streams output in real-time from a run. Only output produced after this endpoint is called will be streamed.""" try: - return StreamingResponse(Runs.stream_events(run_id), media_type="text/event-stream") - except Exception as exc: + run = Runs.get(run_id) + if run is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Run with ID {run_id} not found", + ) + return StreamingResponse(__stream_sse_events(Runs.stream_events(run_id)), media_type="text/event-stream") + except HTTPException: + raise + except TimeoutError: + return Response(status_code=status.HTTP_204_NO_CONTENT) + except Exception: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, # FIXME: use exception type to signal status - detail=f"Run with ID {run_id} error: {str(exc)}", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Run with ID {run_id} error", ) diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index d54f939..abb7681 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import argparse import asyncio -import json import logging import os import pathlib @@ -10,11 +9,9 @@ import sys import uvicorn -import uvicorn.logging -from dotenv import load_dotenv +from dotenv import load_dotenv, find_dotenv from fastapi import Depends, FastAPI -import agent_workflow_server.logging.logger # noqa: F401 from agent_workflow_server.agents.load import load_agents from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter from agent_workflow_server.apis.agents import router as AgentsApiRouter @@ -25,33 +22,31 @@ from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter from agent_workflow_server.services.queue import start_workers -load_dotenv() +load_dotenv(dotenv_path=find_dotenv(usecwd=True)) DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 logger = logging.getLogger(__name__) -def create_app() -> FastAPI: - app = FastAPI( - title="Agent Workflow Server", - version="0.1", - ) - - setup_api_key_auth(app) - - app.include_router( - router=AgentsApiRouter, - dependencies=[Depends(authentication_with_api_key)], - ) - app.include_router( - router=PublicAgentsApiRouter, - ) - app.include_router( - router=StatelessRunsApiRouter, - dependencies=[Depends(authentication_with_api_key)], - ) - return app +app = FastAPI( + title="Agent Workflow Server", + version="0.1", +) + +setup_api_key_auth(app) + +app.include_router( + router=AgentsApiRouter, + dependencies=[Depends(authentication_with_api_key)], +) +app.include_router( + router=PublicAgentsApiRouter, +) +app.include_router( + router=StatelessRunsApiRouter, + dependencies=[Depends(authentication_with_api_key)], +) def signal_handler(sig, frame): logger.warning(f"Received {signal.Signals(sig).name}. Exiting...") @@ -65,12 +60,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--agent-manifest-path', action='append', type=pathlib.Path, default=[os.getenv('AGENT_MANIFEST_PATH', "manifest.json")]) parser.add_argument('--agents-ref', default=os.getenv('AGENTS_REF', None)) parser.add_argument('--reload', action='store_true') - #parser.add_argument('--log-level', default=logging.DEBUG) + parser.add_argument('--log-level', default=os.environ.get("NUM_WORKERS", logging.INFO)) return parser.parse_args() def start(): try: args = parse_args() + logging.basicConfig(level=args.log_level.upper()) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -81,8 +77,9 @@ def start(): loop = asyncio.get_event_loop() loop.create_task(start_workers(n_workers)) + # use module import method to support reload argument config = uvicorn.Config( - create_app(), + "agent_workflow_server.main:app", host=args.host, port=args.port, loop="asyncio", diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py index 75075f4..58b4d48 100644 --- a/src/agent_workflow_server/services/queue.py +++ b/src/agent_workflow_server/services/queue.py @@ -101,7 +101,6 @@ async def worker(worker_id: int): last_message = None async for message in stream: message.data = make_serializable(message.data) - await Runs.Stream.publish(run_id, message) last_message = message if last_message.type == "interrupt": log_run( @@ -110,6 +109,8 @@ async def worker(worker_id: int): "interrupted", message_data=json.dumps(message.data), ) + await Runs.Stream.publish(run_id, message) + if last_message.type == "interrupt": break except Exception as error: await Runs.Stream.publish( @@ -134,7 +135,6 @@ async def worker(worker_id: int): ) validate_output(run_id, run["agent_id"], last_message.data) DB.add_run_output(run_id, last_message.data) - await Runs.Stream.publish(run_id, Message(type="control", data="done")) if last_message.type == "interrupt": interrupt = Interrupt( event=last_message.event, ai_data=last_message.data @@ -144,12 +144,13 @@ async def worker(worker_id: int): else: await Runs.set_status(run_id, "success") log_run(worker_id, run_id, "succeeded", **run_stats(run_info)) + await Runs.Stream.publish(run_id, Message(type="control", data="done")) except InvalidFormatException as error: + log_run(worker_id, run_id, "failed") await Runs.Stream.publish( run_id, Message(type="message", data=str(error)) ) - log_run(worker_id, run_id, "failed") raise RunError(str(error)) except AttemptsExceededError: diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index fde593b..d18ee58 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -24,6 +24,9 @@ from agent_workflow_server.generated.models.run_stateless import ( RunStateless as ApiRun, ) +from agent_workflow_server.generated.models.stream_event_payload import ( + StreamEventPayload +) from agent_workflow_server.storage.models import Run, RunInfo, RunStatus from agent_workflow_server.storage.storage import DB @@ -254,24 +257,33 @@ async def wait_for_output(run_id: str, timeout: float = None): return None, None @staticmethod - async def stream_events(run_id: str) -> AsyncIterator[Union[str,bytes]]: + async def stream_events(run_id: str) -> AsyncIterator[StreamEventPayload | None]: async for message in Runs.Stream.join(run_id): + msg_data = message.data + if message.type == "control": - continue + if message.data == "done": + break + elif message.data == "timeout": + yield None + continue + else: + logger.error(f"received control message \"{message.data}\" in stream events for run: {run_id}") + continue + + # We need to get the latest value to return run = DB.get_run(run_id) if run is None: raise ValueError(f"Run {run_id} not found") - ret_obj = RunOutputStream( - id=str(uuid4()), # fresh event id - event="agent_event", - data={ - "type": "values", - "run_id": run["run_id"], - "status": run["status"], - "values": message.data, - }, + + yield StreamEventPayload( + ValueRunResultUpdate( + type="values", + run_id=run["run_id"], + status=run["status"], + values=msg_data, + ) ) - yield ret_obj.model_dump_json(exclude_none=True) class Stream: @staticmethod @@ -288,16 +300,22 @@ async def subscribe(run_id: str) -> asyncio.Queue: async def join( run_id: str, ) -> AsyncGenerator[Message, None]: + queue = await Runs.Stream.subscribe(run_id) + + # Check after subscribe whether the run is completed to + # avoid race condition. run = DB.get_run(run_id) if run is None: raise ValueError(f"Run {run_id} not found") + if run["status"] != "pending" and queue.empty(): + return - queue = await Runs.Stream.subscribe(run_id) while True: try: - message: Message = await asyncio.wait_for(queue.get(), timeout=1) + message: Message = await asyncio.wait_for(queue.get(), timeout=10) + yield message if message.type == "control" and message.data == "done": break - yield message except TimeoutError as error: logger.error(f"Timeout waiting for run {run_id}: {error}") + yield Message(type="control", data="timeout") From 753595273cf81d68d1c2471cb37709177341b00c Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:43:40 +0200 Subject: [PATCH 3/8] fix: fix formatting and linting errors. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- .../agents/adapters/langgraph.py | 5 ---- src/agent_workflow_server/agents/load.py | 4 ++- .../apis/stateless_runs.py | 26 +++++++++++----- src/agent_workflow_server/main.py | 30 ++++++++++++++----- src/agent_workflow_server/services/runs.py | 23 +++++++------- 5 files changed, 54 insertions(+), 34 deletions(-) diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py index 1bd139f..e90f7d1 100644 --- a/src/agent_workflow_server/agents/adapters/langgraph.py +++ b/src/agent_workflow_server/agents/adapters/langgraph.py @@ -8,7 +8,6 @@ from langgraph.graph.graph import CompiledGraph, Graph from langgraph.graph.state import CompiledStateGraph from langgraph.types import Command -from langgraph.graph import StateGraph from agent_workflow_server.agents.base import BaseAdapter, BaseAgent from agent_workflow_server.services.message import Message @@ -48,10 +47,6 @@ async def astream(self, run: Run): if "interrupt" in run and "user_data" in run["interrupt"]: input = Command(resume=run["interrupt"]["user_data"]) - # If we have a StateGraph, we can validate the schema - if isinstance(self.agent.builder, StateGraph): - input = self.agent.builder.schema.model_validate(input) - async for event in self.agent.astream( input=input, config=RunnableConfig( diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index a528f8a..01cb06b 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -70,7 +70,9 @@ def _read_manifest(path: str) -> AgentACPDescriptor: return AgentACPDescriptor(**manifest_data) -def _resolve_agent(name: str, path: str, add_manifest_paths: List[str] = []) -> AgentInfo: +def _resolve_agent( + name: str, path: str, add_manifest_paths: List[str] = [] +) -> AgentInfo: if ":" not in path: raise ValueError( f"""Invalid format for AGENTS_REF environment variable. \ diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index 1c84c75..33289f9 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -3,7 +3,7 @@ # coding: utf-8 -from typing import Any, Dict, List, Optional, AsyncIterator, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Union from fastapi import ( APIRouter, @@ -35,8 +35,10 @@ from agent_workflow_server.generated.models.run_wait_response_stateless import ( RunWaitResponseStateless, ) +from agent_workflow_server.generated.models.stream_event_payload import ( + StreamEventPayload, +) from agent_workflow_server.services.runs import Runs -from agent_workflow_server.generated.models.stream_event_payload import StreamEventPayload from agent_workflow_server.services.validation import ( InvalidFormatException, ) @@ -115,7 +117,9 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless: ) -async def __stream_sse_events(stream: AsyncIterator[StreamEventPayload | None]) -> AsyncIterator[Union[str,bytes]]: +async def __stream_sse_events( + stream: AsyncIterator[StreamEventPayload | None], +) -> AsyncIterator[Union[str, bytes]]: last_event_id = 0 async for event in stream: if event is None: @@ -183,15 +187,18 @@ async def create_and_stream_stateless_run_output( """Create a stateless run and join its output stream. See 'GET /runs/{run_id}/stream' for details on the return values.""" try: new_run = await Runs.put(run_create_stateless) - return StreamingResponse(__stream_sse_events(Runs.stream_events(new_run.run_id)), media_type="text/event-stream") + return StreamingResponse( + __stream_sse_events(Runs.stream_events(new_run.run_id)), + media_type="text/event-stream", + ) except HTTPException: raise - except TimeoutError as terr: + except TimeoutError: return Response(status_code=status.HTTP_204_NO_CONTENT) - except Exception as exc: + except Exception: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Run create error", + detail="Run create error", ) @@ -360,7 +367,10 @@ async def stream_stateless_run_output( status_code=status.HTTP_404_NOT_FOUND, detail=f"Run with ID {run_id} not found", ) - return StreamingResponse(__stream_sse_events(Runs.stream_events(run_id)), media_type="text/event-stream") + return StreamingResponse( + __stream_sse_events(Runs.stream_events(run_id)), + media_type="text/event-stream", + ) except HTTPException: raise except TimeoutError: diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index abb7681..5f4d7af 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -9,7 +9,7 @@ import sys import uvicorn -from dotenv import load_dotenv, find_dotenv +from dotenv import find_dotenv, load_dotenv from fastapi import Depends, FastAPI from agent_workflow_server.agents.load import load_agents @@ -48,21 +48,35 @@ dependencies=[Depends(authentication_with_api_key)], ) + def signal_handler(sig, frame): logger.warning(f"Received {signal.Signals(sig).name}. Exiting...") sys.exit(0) + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Agent Workflow Server") - parser.add_argument('--host', default=os.getenv("API_HOST", DEFAULT_HOST)) - parser.add_argument('--port', type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT))) - parser.add_argument('--num-workers', type=int, default=int(os.environ.get("NUM_WORKERS", 5))) - parser.add_argument('--agent-manifest-path', action='append', type=pathlib.Path, default=[os.getenv('AGENT_MANIFEST_PATH', "manifest.json")]) - parser.add_argument('--agents-ref', default=os.getenv('AGENTS_REF', None)) - parser.add_argument('--reload', action='store_true') - parser.add_argument('--log-level', default=os.environ.get("NUM_WORKERS", logging.INFO)) + parser.add_argument("--host", default=os.getenv("API_HOST", DEFAULT_HOST)) + parser.add_argument( + "--port", type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT)) + ) + parser.add_argument( + "--num-workers", type=int, default=int(os.environ.get("NUM_WORKERS", 5)) + ) + parser.add_argument( + "--agent-manifest-path", + action="append", + type=pathlib.Path, + default=[os.getenv("AGENT_MANIFEST_PATH", "manifest.json")], + ) + parser.add_argument("--agents-ref", default=os.getenv("AGENTS_REF", None)) + parser.add_argument("--reload", action="store_true") + parser.add_argument( + "--log-level", default=os.environ.get("NUM_WORKERS", logging.INFO) + ) return parser.parse_args() + def start(): try: args = parse_args() diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index d18ee58..5e9f977 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -6,15 +6,9 @@ from collections import defaultdict from datetime import datetime from itertools import islice -from typing import Any, AsyncGenerator, Dict, List, Optional, AsyncIterator, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional from uuid import uuid4 -from agent_workflow_server.generated.models.run_output_stream import ( - RunOutputStream, -) -from agent_workflow_server.generated.models.value_run_result_update import ( - ValueRunResultUpdate, -) from agent_workflow_server.generated.models.run_create_stateless import ( RunCreateStateless as ApiRunCreate, ) @@ -25,7 +19,10 @@ RunStateless as ApiRun, ) from agent_workflow_server.generated.models.stream_event_payload import ( - StreamEventPayload + StreamEventPayload, +) +from agent_workflow_server.generated.models.value_run_result_update import ( + ValueRunResultUpdate, ) from agent_workflow_server.storage.models import Run, RunInfo, RunStatus from agent_workflow_server.storage.storage import DB @@ -268,9 +265,11 @@ async def stream_events(run_id: str) -> AsyncIterator[StreamEventPayload | None] yield None continue else: - logger.error(f"received control message \"{message.data}\" in stream events for run: {run_id}") + logger.error( + f'received control message "{message.data}" in stream events for run: {run_id}' + ) continue - + # We need to get the latest value to return run = DB.get_run(run_id) if run is None: @@ -302,13 +301,13 @@ async def join( ) -> AsyncGenerator[Message, None]: queue = await Runs.Stream.subscribe(run_id) - # Check after subscribe whether the run is completed to + # Check after subscribe whether the run is completed to # avoid race condition. run = DB.get_run(run_id) if run is None: raise ValueError(f"Run {run_id} not found") if run["status"] != "pending" and queue.empty(): - return + return while True: try: From d722aaef9ad50953020d44584bca1cc26cc964dc Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Thu, 24 Apr 2025 20:14:55 +0200 Subject: [PATCH 4/8] fix: update tests for parameterized load_agents. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- tests/test_load.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_load.py b/tests/test_load.py index 38a60dd..e6d8799 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,6 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 - +import os import pytest from pytest_mock import MockerFixture @@ -20,11 +20,16 @@ MockAgent, ) +def _env_load_agents(): + agent_ref = os.getenv("AGENTS_REF") + manifest_path = os.getenv("AGENT_MANIFEST_PATH") + load_agents(agents_ref=agent_ref, add_manifest_paths=[manifest_path]) + def test_load_agents(mocker: MockerFixture): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) - load_agents() + _env_load_agents() assert len(AGENTS) == 1 assert isinstance(AGENTS[MOCK_AGENT_ID].agent, MockAgent) @@ -37,7 +42,7 @@ def test_load_agents(mocker: MockerFixture): def test_get_agent_info(mocker: MockerFixture, agent_id: str, expected: bool): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) - load_agents() + _env_load_agents() assert len(AGENTS) == 1 @@ -55,7 +60,7 @@ def test_get_agent_info(mocker: MockerFixture, agent_id: str, expected: bool): def test_get_agent(mocker: MockerFixture, agent_id: str, expected: bool): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) - load_agents() + _env_load_agents() assert len(AGENTS) == 1 @@ -87,7 +92,7 @@ def test_search_agents( ): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) - load_agents() + _env_load_agents() assert len(AGENTS) == 1 From 08db807eeb058a0190cac0ef67913e490da80f37 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Thu, 24 Apr 2025 20:24:42 +0200 Subject: [PATCH 5/8] fix: format Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- tests/test_load.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_load.py b/tests/test_load.py index e6d8799..f8f6778 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,7 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 import os + import pytest from pytest_mock import MockerFixture @@ -20,6 +21,7 @@ MockAgent, ) + def _env_load_agents(): agent_ref = os.getenv("AGENTS_REF") manifest_path = os.getenv("AGENT_MANIFEST_PATH") From fb03e210a94cb0245cca4d3739e16e5beaffeba1 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Fri, 25 Apr 2025 09:56:07 +0200 Subject: [PATCH 6/8] chore: remove reload arg that is not working. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- src/agent_workflow_server/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index 5f4d7af..9682ce3 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -70,7 +70,6 @@ def parse_args() -> argparse.Namespace: default=[os.getenv("AGENT_MANIFEST_PATH", "manifest.json")], ) parser.add_argument("--agents-ref", default=os.getenv("AGENTS_REF", None)) - parser.add_argument("--reload", action="store_true") parser.add_argument( "--log-level", default=os.environ.get("NUM_WORKERS", logging.INFO) ) @@ -97,7 +96,6 @@ def start(): host=args.host, port=args.port, loop="asyncio", - reload=args.reload, ) server = uvicorn.Server(config) loop.run_until_complete(server.serve()) From e40800b21d1df34e72c81c7a80d1dbf4aeb5ce81 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:59:12 +0200 Subject: [PATCH 7/8] chore: updates for review comments. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- src/agent_workflow_server/agents/load.py | 2 -- src/agent_workflow_server/apis/stateless_runs.py | 6 +++--- src/agent_workflow_server/services/queue.py | 5 +++-- src/agent_workflow_server/services/runs.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index 01cb06b..5440efd 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -152,8 +152,6 @@ def _resolve_agent( def load_agents(agents_ref: Optional[str] = None, add_manifest_paths: List[str] = []): - # Simulate loading the config from environment variable - try: config: Dict[str, str] = json.loads(agents_ref) if agents_ref else {} except json.JSONDecodeError: diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index 33289f9..3110509 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -117,7 +117,7 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless: ) -async def __stream_sse_events( +async def _stream_sse_events( stream: AsyncIterator[StreamEventPayload | None], ) -> AsyncIterator[Union[str, bytes]]: last_event_id = 0 @@ -188,7 +188,7 @@ async def create_and_stream_stateless_run_output( try: new_run = await Runs.put(run_create_stateless) return StreamingResponse( - __stream_sse_events(Runs.stream_events(new_run.run_id)), + _stream_sse_events(Runs.stream_events(new_run.run_id)), media_type="text/event-stream", ) except HTTPException: @@ -368,7 +368,7 @@ async def stream_stateless_run_output( detail=f"Run with ID {run_id} not found", ) return StreamingResponse( - __stream_sse_events(Runs.stream_events(run_id)), + _stream_sse_events(Runs.stream_events(run_id)), media_type="text/event-stream", ) except HTTPException: diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py index 58b4d48..a7daa8c 100644 --- a/src/agent_workflow_server/services/queue.py +++ b/src/agent_workflow_server/services/queue.py @@ -109,9 +109,10 @@ async def worker(worker_id: int): "interrupted", message_data=json.dumps(message.data), ) - await Runs.Stream.publish(run_id, message) - if last_message.type == "interrupt": + await Runs.Stream.publish(run_id, message) break + else: + await Runs.Stream.publish(run_id, message) except Exception as error: await Runs.Stream.publish( run_id, Message(type="message", data=(str(error))) diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index 5e9f977..56ec23d 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -266,7 +266,7 @@ async def stream_events(run_id: str) -> AsyncIterator[StreamEventPayload | None] continue else: logger.error( - f'received control message "{message.data}" in stream events for run: {run_id}' + f'received unknown control message "{message.data}" in stream events for run: {run_id}' ) continue From 12fac3ccf2ea82cccd42d2a040a274b58e4c10e5 Mon Sep 17 00:00:00 2001 From: Jeff Napper <103025963+jnapper7@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:47:19 +0200 Subject: [PATCH 8/8] fix: add check for supported stream and revert CLI support. Signed-off-by: Jeff Napper <103025963+jnapper7@users.noreply.github.com> --- .../apis/stateless_runs.py | 20 +++++++++ src/agent_workflow_server/main.py | 41 +++++-------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index 3110509..0cf9b9e 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -38,6 +38,7 @@ from agent_workflow_server.generated.models.stream_event_payload import ( StreamEventPayload, ) +from agent_workflow_server.generated.models.streaming_mode import StreamingMode from agent_workflow_server.services.runs import Runs from agent_workflow_server.services.validation import ( InvalidFormatException, @@ -55,6 +56,25 @@ async def _validate_run_create_stateless( if run_create_stateless.agent_id is None: """Pre-process the RunCreateStateless object to set the agent_id if not provided.""" run_create_stateless.agent_id = get_default_agent().agent_id + if run_create_stateless.stream_mode is not None: + # Server only supports VALUES streaming at the moment. + if ( + isinstance(run_create_stateless.stream_mode.actual_instance, List) + and StreamingMode.VALUES + not in run_create_stateless.stream_mode.actual_instance + ): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail='stream mode: "values" required', + ) + elif ( + not isinstance(run_create_stateless.stream_mode.actual_instance, List) + and StreamingMode.VALUES != run_create_stateless.stream_mode.actual_instance + ): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail='stream mode: "values" required', + ) return run_create_stateless diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index 9682ce3..826185d 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -1,10 +1,8 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 -import argparse import asyncio import logging import os -import pathlib import signal import sys @@ -26,6 +24,8 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 +DEFAULT_NUM_WORKERS = 5 +DEFAULT_AGENT_MANIFEST_PATH = "manifest.json" logger = logging.getLogger(__name__) @@ -54,38 +54,17 @@ def signal_handler(sig, frame): sys.exit(0) -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Agent Workflow Server") - parser.add_argument("--host", default=os.getenv("API_HOST", DEFAULT_HOST)) - parser.add_argument( - "--port", type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT)) - ) - parser.add_argument( - "--num-workers", type=int, default=int(os.environ.get("NUM_WORKERS", 5)) - ) - parser.add_argument( - "--agent-manifest-path", - action="append", - type=pathlib.Path, - default=[os.getenv("AGENT_MANIFEST_PATH", "manifest.json")], - ) - parser.add_argument("--agents-ref", default=os.getenv("AGENTS_REF", None)) - parser.add_argument( - "--log-level", default=os.environ.get("NUM_WORKERS", logging.INFO) - ) - return parser.parse_args() - - def start(): try: - args = parse_args() - logging.basicConfig(level=args.log_level.upper()) - signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - load_agents(args.agents_ref, args.agent_manifest_path) - n_workers = args.num_workers + agents_ref = os.getenv("AGENTS_REF", None) + agent_manifest_path = os.getenv( + "AGENT_MANIFEST_PATH", DEFAULT_AGENT_MANIFEST_PATH + ) + load_agents(agents_ref, agent_manifest_path) + n_workers = int(os.environ.get("NUM_WORKERS", DEFAULT_NUM_WORKERS)) loop = asyncio.get_event_loop() loop.create_task(start_workers(n_workers)) @@ -93,8 +72,8 @@ def start(): # use module import method to support reload argument config = uvicorn.Config( "agent_workflow_server.main:app", - host=args.host, - port=args.port, + host=os.getenv("API_HOST", DEFAULT_HOST) or DEFAULT_HOST, + port=int(os.getenv("API_PORT", DEFAULT_PORT)) or DEFAULT_PORT, loop="asyncio", ) server = uvicorn.Server(config)