
feat: Support streaming #57

Merged · 8 commits · Apr 29, 2025
Changes from 7 commits
17 changes: 8 additions & 9 deletions src/agent_workflow_server/agents/load.py
@@ -6,7 +6,7 @@
import logging
import os
import pkgutil
from typing import Any, Dict, Hashable, List, Mapping, NamedTuple
from typing import Any, Dict, Hashable, List, Mapping, NamedTuple, Optional

import agent_workflow_server.agents.adapters
from agent_workflow_server.agents.oas_generator import generate_agent_oapi
@@ -70,7 +70,9 @@ def _read_manifest(path: str) -> AgentACPDescriptor:
return AgentACPDescriptor(**manifest_data)


def _resolve_agent(name: str, path: str) -> AgentInfo:
def _resolve_agent(
name: str, path: str, add_manifest_paths: List[str] = []
) -> AgentInfo:
if ":" not in path:
raise ValueError(
f"""Invalid format for AGENTS_REF environment variable. \
@@ -127,8 +129,7 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
# Load manifest. Check in paths below (in order)
manifest_paths = [
os.path.join(os.path.dirname(module.__file__), "manifest.json"),
os.environ.get("AGENT_MANIFEST_PATH", "manifest.json") or "manifest.json",
]
] + add_manifest_paths

for manifest_path in manifest_paths:
manifest = _read_manifest(manifest_path)
@@ -150,18 +151,16 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
return AgentInfo(agent=agent, manifest=manifest, schema=schema)


def load_agents():
# Simulate loading the config from environment variable

def load_agents(agents_ref: Optional[str] = None, add_manifest_paths: List[str] = []):
try:
config: Dict[str, str] = json.loads(os.getenv("AGENTS_REF", {}))
config: Dict[str, str] = json.loads(agents_ref) if agents_ref else {}
except json.JSONDecodeError:
raise ValueError("""Invalid format for AGENTS_REF environment variable. \
Must be a dictionary of agent_id -> module:var pairs. \
Example: {"agent1": "agent1_module:agent1_var", "agent2": "agent2_module:agent2_var"}""")
for agent_id, agent_path in config.items():
try:
agent = _resolve_agent(agent_id, agent_path)
agent = _resolve_agent(agent_id, agent_path, add_manifest_paths)
AGENTS[agent_id] = agent
logger.info(f"Registered Agent: '{agent_id}'", {"agent_id": agent_id})
except Exception as e:
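With this change, `load_agents` no longer reads `AGENTS_REF` and `AGENT_MANIFEST_PATH` itself; the caller passes them in. A minimal caller sketch (not part of this diff; the module path and manifest location below are placeholders):

```python
import json

from agent_workflow_server.agents.load import AGENTS, load_agents

# Placeholder agent reference and extra manifest search path, for illustration only.
agents_ref = json.dumps({"agent1": "agent1_module:agent1_var"})
load_agents(agents_ref, add_manifest_paths=["./manifest.json"])

print(list(AGENTS.keys()))  # IDs of the agents that were registered
```
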
59 changes: 56 additions & 3 deletions src/agent_workflow_server/apis/stateless_runs.py
@@ -3,7 +3,7 @@

# coding: utf-8

from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from fastapi import (
APIRouter,
@@ -15,6 +15,7 @@
Response,
status,
)
from fastapi.responses import StreamingResponse
from pydantic import Field, StrictBool, StrictStr
from typing_extensions import Annotated

@@ -34,6 +35,9 @@
from agent_workflow_server.generated.models.run_wait_response_stateless import (
RunWaitResponseStateless,
)
from agent_workflow_server.generated.models.stream_event_payload import (
StreamEventPayload,
)
from agent_workflow_server.services.runs import Runs
from agent_workflow_server.services.validation import (
InvalidFormatException,
@@ -113,6 +117,22 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
)


async def _stream_sse_events(
stream: AsyncIterator[StreamEventPayload | None],
) -> AsyncIterator[Union[str, bytes]]:
last_event_id = 0
async for event in stream:
if event is None:
yield ":"
else:
last_event_id += 1
yield f"""id: {last_event_id}
event: agent_event
data: {event.to_json()}

"""


@router.post(
"/runs/{run_id}/cancel",
responses={
@@ -165,7 +185,21 @@ async def create_and_stream_stateless_run_output(
] = Body(None, description=""),
) -> RunOutputStream:
"""Create a stateless run and join its output stream. See 'GET /runs/{run_id}/stream' for details on the return values."""
raise HTTPException(status_code=500, detail="Not implemented")
try:
new_run = await Runs.put(run_create_stateless)
return StreamingResponse(
_stream_sse_events(Runs.stream_events(new_run.run_id)),
media_type="text/event-stream",
)
except HTTPException:
raise
except TimeoutError:
return Response(status_code=status.HTTP_204_NO_CONTENT)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Run create error",
)


@router.post(
@@ -326,7 +360,26 @@ async def stream_stateless_run_output(
),
) -> RunOutputStream:
"""Join the output stream of an existing run. This endpoint streams output in real-time from a run. Only output produced after this endpoint is called will be streamed."""
raise HTTPException(status_code=500, detail="Not implemented")
try:
run = Runs.get(run_id)
if run is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Run with ID {run_id} not found",
)
return StreamingResponse(
_stream_sse_events(Runs.stream_events(run_id)),
media_type="text/event-stream",
)
except HTTPException:
raise
except TimeoutError:
return Response(status_code=status.HTTP_204_NO_CONTENT)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Run with ID {run_id} error",
)


@router.get(
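For reference, a hedged client-side sketch of consuming the new SSE stream with `httpx` (not part of this PR; the base URL, endpoint path, and request-body fields are assumptions):

```python
import asyncio

import httpx


async def consume_run_stream() -> None:
    # Assumed RunCreateStateless body; adjust to the agent's actual input schema.
    body = {"agent_id": "agent1", "input": {}}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://127.0.0.1:8000/runs/stream", json=body
        ) as response:
            async for line in response.aiter_lines():
                # Frames arrive as "id: ...", "event: agent_event" and
                # "data: {...}" lines separated by blank lines; a lone ":" is
                # the keep-alive comment emitted when the stream times out.
                if line:
                    print(line)


asyncio.run(consume_run_stream())
```
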
46 changes: 36 additions & 10 deletions src/agent_workflow_server/main.py
@@ -1,18 +1,17 @@
# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import logging
import os
import pathlib
import signal
import sys

import uvicorn
import uvicorn.logging
from dotenv import load_dotenv
from dotenv import find_dotenv, load_dotenv
from fastapi import Depends, FastAPI

import agent_workflow_server.logging.logger # noqa: F401
from agent_workflow_server.agents.load import load_agents
from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter
from agent_workflow_server.apis.agents import router as AgentsApiRouter
@@ -23,7 +22,7 @@
from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter
from agent_workflow_server.services.queue import start_workers

load_dotenv()
load_dotenv(dotenv_path=find_dotenv(usecwd=True))

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000
@@ -55,20 +54,47 @@ def signal_handler(sig, frame):
sys.exit(0)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Agent Workflow Server")
parser.add_argument("--host", default=os.getenv("API_HOST", DEFAULT_HOST))
parser.add_argument(
"--port", type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT))
)
parser.add_argument(
"--num-workers", type=int, default=int(os.environ.get("NUM_WORKERS", 5))
)
parser.add_argument(
"--agent-manifest-path",
action="append",
type=pathlib.Path,
default=[os.getenv("AGENT_MANIFEST_PATH", "manifest.json")],
)
parser.add_argument("--agents-ref", default=os.getenv("AGENTS_REF", None))
parser.add_argument(
"--log-level", default=os.environ.get("NUM_WORKERS", logging.INFO)
)
return parser.parse_args()


def start():
try:
args = parse_args()
logging.basicConfig(level=args.log_level.upper())

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
load_agents()
n_workers = int(os.environ.get("NUM_WORKERS", 5))

load_agents(args.agents_ref, args.agent_manifest_path)
n_workers = args.num_workers

loop = asyncio.get_event_loop()
loop.create_task(start_workers(n_workers))

# use module import method to support reload argument
config = uvicorn.Config(
app,
host=os.getenv("API_HOST", DEFAULT_HOST) or DEFAULT_HOST,
port=int(os.getenv("API_PORT", DEFAULT_PORT)) or DEFAULT_PORT,
"agent_workflow_server.main:app",
host=args.host,
port=args.port,
loop="asyncio",
)
server = uvicorn.Server(config)
8 changes: 5 additions & 3 deletions src/agent_workflow_server/services/queue.py
@@ -101,7 +101,6 @@ async def worker(worker_id: int):
last_message = None
async for message in stream:
message.data = make_serializable(message.data)
await Runs.Stream.publish(run_id, message)
last_message = message
if last_message.type == "interrupt":
log_run(
@@ -110,7 +109,10 @@ async def worker(worker_id: int):
"interrupted",
message_data=json.dumps(message.data),
)
await Runs.Stream.publish(run_id, message)
break
else:
await Runs.Stream.publish(run_id, message)
except Exception as error:
await Runs.Stream.publish(
run_id, Message(type="message", data=(str(error)))
@@ -134,7 +136,6 @@ async def worker(worker_id: int):
)
validate_output(run_id, run["agent_id"], last_message.data)
DB.add_run_output(run_id, last_message.data)
await Runs.Stream.publish(run_id, Message(type="control", data="done"))
if last_message.type == "interrupt":
interrupt = Interrupt(
event=last_message.event, ai_data=last_message.data
@@ -144,12 +145,13 @@ async def worker(worker_id: int):
else:
await Runs.set_status(run_id, "success")
log_run(worker_id, run_id, "succeeded", **run_stats(run_info))
await Runs.Stream.publish(run_id, Message(type="control", data="done"))

except InvalidFormatException as error:
log_run(worker_id, run_id, "failed")
await Runs.Stream.publish(
run_id, Message(type="message", data=str(error))
)
log_run(worker_id, run_id, "failed")
raise RunError(str(error))

except AttemptsExceededError:
55 changes: 51 additions & 4 deletions src/agent_workflow_server/services/runs.py
@@ -6,7 +6,7 @@
from collections import defaultdict
from datetime import datetime
from itertools import islice
from typing import Any, AsyncGenerator, Dict, List, Optional
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
from uuid import uuid4

from agent_workflow_server.generated.models.run_create_stateless import (
@@ -18,6 +18,12 @@
from agent_workflow_server.generated.models.run_stateless import (
RunStateless as ApiRun,
)
from agent_workflow_server.generated.models.stream_event_payload import (
StreamEventPayload,
)
from agent_workflow_server.generated.models.value_run_result_update import (
ValueRunResultUpdate,
)
from agent_workflow_server.storage.models import Run, RunInfo, RunStatus
from agent_workflow_server.storage.storage import DB

@@ -247,6 +253,37 @@ async def wait_for_output(run_id: str, timeout: float = None):

return None, None

@staticmethod
async def stream_events(run_id: str) -> AsyncIterator[StreamEventPayload | None]:
async for message in Runs.Stream.join(run_id):
msg_data = message.data

if message.type == "control":
if message.data == "done":
break
elif message.data == "timeout":
yield None
continue
else:
logger.error(
f'received unknown control message "{message.data}" in stream events for run: {run_id}'
)
continue

# We need to get the latest value to return
run = DB.get_run(run_id)
if run is None:
raise ValueError(f"Run {run_id} not found")

yield StreamEventPayload(
ValueRunResultUpdate(
type="values",
run_id=run["run_id"],
status=run["status"],
values=msg_data,
)
)

class Stream:
@staticmethod
async def publish(run_id: str, message: Message) -> None:
@@ -263,11 +300,21 @@ async def join(
run_id: str,
) -> AsyncGenerator[Message, None]:
queue = await Runs.Stream.subscribe(run_id)

# Check after subscribing whether the run has already completed, to
# avoid a race condition.
run = DB.get_run(run_id)
if run is None:
raise ValueError(f"Run {run_id} not found")
if run["status"] != "pending" and queue.empty():
return

while True:
try:
message: Message = await asyncio.wait_for(queue.get(), timeout=1)
if message.topic == "control" and message.data == "done":
break
message: Message = await asyncio.wait_for(queue.get(), timeout=10)
yield message
if message.type == "control" and message.data == "done":
break
except TimeoutError as error:
logger.error(f"Timeout waiting for run {run_id}: {error}")
yield Message(type="control", data="timeout")
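
A small client-side parsing sketch for the `data:` line of an `agent_event` frame (not part of this PR). The field names follow the `ValueRunResultUpdate` constructed above; the exact JSON shape depends on the generated `StreamEventPayload` serializer, so treat this as an assumption:

```python
import json
from typing import Any, Dict


def parse_agent_event(data_line: str) -> Dict[str, Any]:
    """Parse the 'data: {...}' line of an agent_event SSE frame."""
    payload = json.loads(data_line.removeprefix("data:").strip())
    # Expected keys, mirroring ValueRunResultUpdate: type ("values"),
    # run_id, status, and values (the serialized message data).
    return payload


# Illustrative frame with a placeholder run_id.
frame = 'data: {"type": "values", "run_id": "run-123", "status": "pending", "values": {}}'
event = parse_agent_event(frame)
print(event["run_id"], event["status"], event["values"])
```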