From a54263cf05ea188903f6ecb9c0ca35696685a840 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Wed, 26 Mar 2025 15:57:57 +0100 Subject: [PATCH 1/9] fix: better import error --- src/agent_workflow_server/agents/load.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index 9e28069..efa70ee 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -67,11 +67,14 @@ def _resolve_agent(name: str, path: str) -> AgentInfo: module_name = module_or_file try: module = importlib.import_module(module_name) - except ModuleNotFoundError: - raise ModuleNotFoundError( - f"""Failed to load agent module {module_name}. \ -Check if it\'s installed and that module name in 'AGENTS_REF' env variable is correct.""" - ) + except ImportError as e: + if any(part in str(e) for part in module_name.split(".")): + raise ImportError( + f"""Failed to load agent module {module_name}. \ +Check that it is installed and that the module name in 'AGENTS_REF' env variable is correct.""" + ) from e + else: + raise e else: # It's a path to a file, try to load it as a module file = module_or_file From 32fba6545ba0cdf239a4429eef5d182943287ba6 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Mon, 31 Mar 2025 12:56:16 +0200 Subject: [PATCH 2/9] feat: support interrupts (1/2) and different fixes and better logging --- .../agents/adapters/langgraph.py | 17 +++++++- .../agents/adapters/llamaindex.py | 13 ++++-- src/agent_workflow_server/agents/base.py | 7 ++-- src/agent_workflow_server/agents/load.py | 7 +++- .../apis/stateless_runs.py | 17 ++++++-- .../custom_logger.py => logging/logger.py} | 4 ++ src/agent_workflow_server/main.py | 4 +- src/agent_workflow_server/services/message.py | 9 +++-- src/agent_workflow_server/services/queue.py | 40 +++++++++++++++---- src/agent_workflow_server/services/runs.py | 25 +++++++++++- src/agent_workflow_server/services/stream.py | 18 ++++++--- .../services/validation.py | 2 + src/agent_workflow_server/storage/storage.py | 4 +- src/agent_workflow_server/utils/tools.py | 13 ++++-- 14 files changed, 139 insertions(+), 41 deletions(-) rename src/agent_workflow_server/{logger/custom_logger.py => logging/logger.py} (71%) diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py index 626f71c..6dc11a3 100644 --- a/src/agent_workflow_server/agents/adapters/langgraph.py +++ b/src/agent_workflow_server/agents/adapters/langgraph.py @@ -4,9 +4,11 @@ from typing import Optional from langchain_core.runnables import RunnableConfig +from langgraph.constants import INTERRUPT from langgraph.graph.graph import CompiledGraph, Graph from agent_workflow_server.agents.base import BaseAdapter, BaseAgent +from agent_workflow_server.services.message import Message from agent_workflow_server.storage.models import Config @@ -33,6 +35,17 @@ async def astream(self, input: dict, config: Optional[Config]): ) if config else None, - stream_mode="values", ): - yield event + for k, v in event.items(): + if k == INTERRUPT: + yield Message( + type="interrupt", + event=k, + data=v[0].value, + ) + else: + yield Message( + type="message", + event=k, + data=v, + ) diff --git a/src/agent_workflow_server/agents/adapters/llamaindex.py b/src/agent_workflow_server/agents/adapters/llamaindex.py index 0483e4a..b96684c 100644 --- a/src/agent_workflow_server/agents/adapters/llamaindex.py +++ b/src/agent_workflow_server/agents/adapters/llamaindex.py @@ -7,6 +7,7 @@ from llama_index.core.workflow import Workflow from agent_workflow_server.agents.base import BaseAdapter, BaseAgent +from agent_workflow_server.services.message import Message from agent_workflow_server.storage.models import Config @@ -25,9 +26,15 @@ class LlamaIndexAgent(BaseAgent): def __init__(self, agent: Workflow): self.agent = agent - async def astream(self, input: dict, config: Config): + async def astream(self, input: dict, config: Optional[Config]): handler = self.agent.run(**input) async for event in handler.stream_events(): - yield event + yield Message( + type="message", + data=event, + ) final_result = await handler - yield final_result + yield Message( + type="message", + data=final_result, + ) diff --git a/src/agent_workflow_server/agents/base.py b/src/agent_workflow_server/agents/base.py index 2f315f8..2576b51 100644 --- a/src/agent_workflow_server/agents/base.py +++ b/src/agent_workflow_server/agents/base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Dict, Optional +from agent_workflow_server.services.message import Message from agent_workflow_server.storage.models import Config @@ -11,9 +12,9 @@ class BaseAgent(ABC): @abstractmethod async def astream( self, input: Optional[Dict[str, Any]], config: Optional[Config] - ) -> AsyncGenerator[Any, None]: - """Invokes the agent with the given input and configuration and streams (returns) events asynchronously. - The last event includes the final result.""" + ) -> AsyncGenerator[Message, None]: + """Invokes the agent with the given input and configuration and streams (returns) `Message`s asynchronously. + The last `Message` includes the final result.""" pass diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index 60ab8f7..b7e4ed7 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -59,7 +59,12 @@ def _load_adapters() -> List[BaseAdapter]: def _read_manifest(path: str) -> AgentACPDescriptor: if os.path.isfile(path): with open(path, "r") as file: - manifest_data = json.load(file) + try: + manifest_data = json.load(file) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON format in manifest file: {path}. Error: {e}" + ) # print full path logger.info(f"Loaded Agent Manifest from {os.path.abspath(path)}") return AgentACPDescriptor(**manifest_data) diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index b244dc9..463c52d 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -23,6 +23,7 @@ ) from agent_workflow_server.generated.models.run_output import ( RunError, + RunInterrupt, RunOutput, RunResult, ) @@ -40,8 +41,6 @@ validate_run_create as validate, ) -from ..utils.tools import serialize_to_dict - router = APIRouter() @@ -56,6 +55,11 @@ async def _validate_run_create( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e), ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless: @@ -70,10 +74,15 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless: if run is None: return Response(status_code=status.HTTP_404_NOT_FOUND) if run.status == "success" and run_output is not None: + return RunWaitResponseStateless( + run=run, + output=RunOutput(RunResult(type="result", values=run_output)), + ) + elif run.status == "interrupted": return RunWaitResponseStateless( run=run, output=RunOutput( - RunResult(type="result", values=serialize_to_dict(run_output)) + RunInterrupt(type="interrupt", interrupt={"default": run_output}) ), ) else: @@ -244,7 +253,7 @@ async def resume_stateless_run( body: Dict[str, Any] = Body(None, description=""), ) -> RunStateless: """Provide the needed input to a run to resume its execution. Can only be called for runs that are in the interrupted state Schema of the provided input must match with the schema specified in the agent specs under interrupts for the interrupt type the agent generated for this specific interruption.""" - raise HTTPException(status_code=500, detail="Not implemented") + return await Runs.resume(run_id, body) @router.post( diff --git a/src/agent_workflow_server/logger/custom_logger.py b/src/agent_workflow_server/logging/logger.py similarity index 71% rename from src/agent_workflow_server/logger/custom_logger.py rename to src/agent_workflow_server/logging/logger.py index 92c6f78..d6d3dde 100644 --- a/src/agent_workflow_server/logger/custom_logger.py +++ b/src/agent_workflow_server/logging/logger.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging +import os from uvicorn.logging import ColourizedFormatter @@ -12,3 +13,6 @@ handler.setFormatter(colorformatter) CustomLoggerHandler = handler + +LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper() +logging.basicConfig(level=LOG_LEVEL, handlers=[CustomLoggerHandler], force=True) diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py index 791456a..5391fd3 100644 --- a/src/agent_workflow_server/main.py +++ b/src/agent_workflow_server/main.py @@ -12,6 +12,7 @@ from dotenv import load_dotenv from fastapi import Depends, FastAPI +import agent_workflow_server.logging.logger # noqa: F401 from agent_workflow_server.agents.load import load_agents from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter from agent_workflow_server.apis.agents import router as AgentsApiRouter @@ -20,7 +21,6 @@ setup_api_key_auth, ) from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter -from agent_workflow_server.logger.custom_logger import CustomLoggerHandler from agent_workflow_server.services.queue import start_workers load_dotenv() @@ -28,8 +28,6 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 -logging.basicConfig(level=logging.INFO, handlers=[CustomLoggerHandler], force=True) - logger = logging.getLogger(__name__) app = FastAPI( diff --git a/src/agent_workflow_server/services/message.py b/src/agent_workflow_server/services/message.py index 27f048c..d9c5269 100644 --- a/src/agent_workflow_server/services/message.py +++ b/src/agent_workflow_server/services/message.py @@ -1,10 +1,13 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from typing import Any, Literal, Optional + +type MessageType = Literal["control", "message", "interrupt"] class Message: - def __init__(self, topic: str, data: Any): - self.topic = topic + def __init__(self, type: MessageType, data: Any, event: Optional[str] = None): + self.type = type self.data = data + self.event = event diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py index ca1fd9e..ef321df 100644 --- a/src/agent_workflow_server/services/queue.py +++ b/src/agent_workflow_server/services/queue.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import logging -import traceback from datetime import datetime from typing import Literal @@ -13,6 +13,7 @@ ) from agent_workflow_server.storage.models import RunInfo from agent_workflow_server.storage.storage import DB +from agent_workflow_server.utils.tools import make_serializable from .message import Message from .runs import RUNS_QUEUE, Runs @@ -43,13 +44,22 @@ async def start_workers(n_workers: int): def log_run( worker_id: int, run_id: str, - info: Literal["started", "succeeded", "failed", "exeeded attempts"], + info: Literal[ + "got message", + "started", + "interrupted", + "succeeded", + "failed", + "exeeded attempts", + ], **kwargs, ): log_methods = { + "got message": logger.debug, "started": logger.info, + "interrupted": logger.info, "succeeded": logger.info, - "failed": logger.warning, + "failed": logger.exception, "exeeded attempts": logger.error, } log_message = f"(Worker {worker_id}) Background Run {run_id} {info}" @@ -90,11 +100,21 @@ async def worker(worker_id: int): stream = stream_run(run) last_message = None async for message in stream: + message.data = make_serializable(message.data) await Runs.Stream.publish(run_id, message) last_message = message + if last_message.type == "interrupt": + await Runs.set_status(run_id, "interrupted") + log_run( + worker_id, + run_id, + "interrupted", + message_data=json.dumps(message.data), + ) + break except Exception as error: await Runs.Stream.publish( - run_id, Message(topic="error", data=(str(error))) + run_id, Message(type="message", data=(str(error))) ) raise RunError(error) @@ -108,15 +128,20 @@ async def worker(worker_id: int): try: validate_output(run_id, run["agent_id"], last_message.data) - + log_run( + worker_id, + run_id, + "got message", + message_data=json.dumps(last_message.data), + ) DB.add_run_output(run_id, last_message.data) - await Runs.Stream.publish(run_id, Message(topic="control", data="done")) + await Runs.Stream.publish(run_id, Message(type="control", data="done")) await Runs.set_status(run_id, "success") log_run(worker_id, run_id, "succeeded", **run_stats(run_info)) except InvalidFormatException as error: await Runs.Stream.publish( - run_id, Message(topic="error", data=str(error)) + run_id, Message(type="message", data=str(error)) ) log_run(worker_id, run_id, "failed") raise RunError(str(error)) @@ -154,7 +179,6 @@ async def worker(worker_id: int): "failed", **{"error": error, **run_stats(run_info)}, ) - traceback.print_exc() await RUNS_QUEUE.put(run_id) # Re-queue for retry diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index 3d4a349..8f80b02 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -6,7 +6,7 @@ from collections import defaultdict from datetime import datetime from itertools import islice -from typing import AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import uuid4 from agent_workflow_server.generated.models.run_create_stateless import ( @@ -172,6 +172,29 @@ def search(search_request: RunSearchRequest) -> List[ApiRun]: islice(islice(runs, search_request.offset, None), search_request.limit) ) + @staticmethod + async def resume(run_id: str, input: Dict[str, Any]) -> ApiRun: + # TODO: implement resume + + # run = DB.get_run(run_id) + # if run is None: + # raise Exception("Run not found") + # if run["status"] != "interrupted": + # raise Exception("Run is not in interrupted state") + + # new_run = _make_run(run_create) + # run_info = RunInfo( + # run_id=new_run["run_id"], + # attempts=0, + # ) + # DB.create_run(new_run) + # DB.create_run_info(run_info) + + # await RUNS_QUEUE.put(new_run["run_id"]) + # return _to_api_model(new_run) + + ... + @staticmethod async def set_status(run_id: str, status: RunStatus): run = DB.get_run(run_id) diff --git a/src/agent_workflow_server/services/stream.py b/src/agent_workflow_server/services/stream.py index c51de27..01a0708 100644 --- a/src/agent_workflow_server/services/stream.py +++ b/src/agent_workflow_server/services/stream.py @@ -10,9 +10,15 @@ async def stream_run(run: Run) -> AsyncGenerator[Message, None]: - agent = get_agent_info(run["agent_id"]).agent - async for event in agent.astream(input=run["input"], config=run["config"]): - yield Message( - topic="message", - data=event, - ) + agent_info = get_agent_info(run["agent_id"]) + agent = agent_info.agent + if agent_info.manifest.specs.capabilities.interrupts: + # TODO: This will only work for langgraph + if run.get("config") is None: + run["config"] = {} + if run["config"].get("configurable") is None: + run["config"]["configurable"] = {} + run["config"]["configurable"].setdefault("thread_id", run["thread_id"]) + + async for message in agent.astream(input=run["input"], config=run["config"]): + yield message diff --git a/src/agent_workflow_server/services/validation.py b/src/agent_workflow_server/services/validation.py index d9d8d3f..1af24d4 100644 --- a/src/agent_workflow_server/services/validation.py +++ b/src/agent_workflow_server/services/validation.py @@ -34,6 +34,8 @@ def validate_against_schema( def get_agent_schemas(agent_id: str): """Get input, output and config schemas for an agent""" agent_info = AGENTS.get(agent_id) + if not agent_info: + raise ValueError(f"Agent {agent_id} not found") specs = agent_info.manifest.specs return {"input": specs.input, "output": specs.output, "config": specs.config} diff --git a/src/agent_workflow_server/storage/storage.py b/src/agent_workflow_server/storage/storage.py index 31f5738..7f5713b 100644 --- a/src/agent_workflow_server/storage/storage.py +++ b/src/agent_workflow_server/storage/storage.py @@ -9,13 +9,11 @@ from dotenv import load_dotenv -from agent_workflow_server.logger.custom_logger import CustomLoggerHandler +import agent_workflow_server.logging.logger # noqa: F401 from .models import Run, RunInfo from .service import DBOperations -logging.basicConfig(level=logging.INFO, handlers=[CustomLoggerHandler]) - logger = logging.getLogger(__name__) load_dotenv() diff --git a/src/agent_workflow_server/utils/tools.py b/src/agent_workflow_server/utils/tools.py index 2528961..388ed5c 100644 --- a/src/agent_workflow_server/utils/tools.py +++ b/src/agent_workflow_server/utils/tools.py @@ -3,6 +3,7 @@ import uuid from enum import Enum +from typing import Any from pydantic import BaseModel @@ -15,13 +16,17 @@ def is_valid_uuid(val): return False -def serialize_to_dict(v): +def make_serializable(v: Any): if isinstance(v, list): - return [serialize_to_dict(vv) for vv in v] + return [make_serializable(vv) for vv in v] elif isinstance(v, dict): - return {kk: serialize_to_dict(vv) for kk, vv in v.items()} - elif isinstance(v, BaseModel): + return {kk: make_serializable(vv) for kk, vv in v.items()} + elif ( + isinstance(v, BaseModel) and hasattr(v, "model_dump") and callable(v.model_dump) + ): return v.model_dump() + elif isinstance(v, BaseModel) and hasattr(v, "dict") and callable(v.dict): + return v.dict() elif isinstance(v, Enum): return v.value else: From b775ec5827ce171d198d2f8a48aa7a8ae1333b7c Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Mon, 31 Mar 2025 18:27:21 +0200 Subject: [PATCH 3/9] feat: support interrupts (2/2) --- .../agents/adapters/langgraph.py | 24 ++++++++--- .../agents/adapters/llamaindex.py | 5 ++- src/agent_workflow_server/agents/base.py | 10 ++--- .../apis/stateless_runs.py | 8 +++- src/agent_workflow_server/services/queue.py | 18 ++++++--- src/agent_workflow_server/services/runs.py | 40 +++++++++---------- src/agent_workflow_server/services/stream.py | 10 +---- src/agent_workflow_server/storage/models.py | 10 +++++ src/agent_workflow_server/storage/service.py | 4 +- 9 files changed, 76 insertions(+), 53 deletions(-) diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py index 6dc11a3..fd7ce26 100644 --- a/src/agent_workflow_server/agents/adapters/langgraph.py +++ b/src/agent_workflow_server/agents/adapters/langgraph.py @@ -6,10 +6,11 @@ from langchain_core.runnables import RunnableConfig from langgraph.constants import INTERRUPT from langgraph.graph.graph import CompiledGraph, Graph +from langgraph.types import Command from agent_workflow_server.agents.base import BaseAdapter, BaseAgent from agent_workflow_server.services.message import Message -from agent_workflow_server.storage.models import Config +from agent_workflow_server.storage.models import Run class LangGraphAdapter(BaseAdapter): @@ -25,16 +26,27 @@ class LangGraphAgent(BaseAgent): def __init__(self, agent: CompiledGraph): self.agent = agent - async def astream(self, input: dict, config: Optional[Config]): + async def astream(self, run: Run): + input = run["input"] + config = run["config"] + if config is None: + config = {} + configurable = config["configurable"] + if configurable is None: + configurable = {} + configurable.setdefault("thread_id", run["thread_id"]) + + # If there's an interrupt answer, ovverride the input + if run.get("interrupt") is not None and run["interrupt"].get("user_data"): + input = Command(resume=run["interrupt"]["user_data"]) + async for event in self.agent.astream( input=input, config=RunnableConfig( - configurable=config["configurable"], + configurable=configurable, tags=config["tags"], recursion_limit=config["recursion_limit"], - ) - if config - else None, + ), ): for k, v in event.items(): if k == INTERRUPT: diff --git a/src/agent_workflow_server/agents/adapters/llamaindex.py b/src/agent_workflow_server/agents/adapters/llamaindex.py index b96684c..036f70d 100644 --- a/src/agent_workflow_server/agents/adapters/llamaindex.py +++ b/src/agent_workflow_server/agents/adapters/llamaindex.py @@ -8,7 +8,7 @@ from agent_workflow_server.agents.base import BaseAdapter, BaseAgent from agent_workflow_server.services.message import Message -from agent_workflow_server.storage.models import Config +from agent_workflow_server.storage.models import Run class LlamaIndexAdapter(BaseAdapter): @@ -26,7 +26,8 @@ class LlamaIndexAgent(BaseAgent): def __init__(self, agent: Workflow): self.agent = agent - async def astream(self, input: dict, config: Optional[Config]): + async def astream(self, run: Run): + input = run["input"] handler = self.agent.run(**input) async for event in handler.stream_events(): yield Message( diff --git a/src/agent_workflow_server/agents/base.py b/src/agent_workflow_server/agents/base.py index 2576b51..d7c549b 100644 --- a/src/agent_workflow_server/agents/base.py +++ b/src/agent_workflow_server/agents/base.py @@ -2,18 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, Optional from agent_workflow_server.services.message import Message -from agent_workflow_server.storage.models import Config +from agent_workflow_server.storage.models import Run class BaseAgent(ABC): @abstractmethod - async def astream( - self, input: Optional[Dict[str, Any]], config: Optional[Config] - ) -> AsyncGenerator[Message, None]: - """Invokes the agent with the given input and configuration and streams (returns) `Message`s asynchronously. + async def astream(self, run: Run) -> AsyncGenerator[Message, None]: + """Invokes the agent with the given `Run` and streams (returns) `Message`s asynchronously. The last `Message` includes the final result.""" pass diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py index 463c52d..5a6ae34 100644 --- a/src/agent_workflow_server/apis/stateless_runs.py +++ b/src/agent_workflow_server/apis/stateless_runs.py @@ -253,7 +253,13 @@ async def resume_stateless_run( body: Dict[str, Any] = Body(None, description=""), ) -> RunStateless: """Provide the needed input to a run to resume its execution. Can only be called for runs that are in the interrupted state Schema of the provided input must match with the schema specified in the agent specs under interrupts for the interrupt type the agent generated for this specific interruption.""" - return await Runs.resume(run_id, body) + try: + return await Runs.resume(run_id, body) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) @router.post( diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py index ef321df..f1342da 100644 --- a/src/agent_workflow_server/services/queue.py +++ b/src/agent_workflow_server/services/queue.py @@ -11,7 +11,7 @@ InvalidFormatException, validate_output, ) -from agent_workflow_server.storage.models import RunInfo +from agent_workflow_server.storage.models import Interrupt, RunInfo from agent_workflow_server.storage.storage import DB from agent_workflow_server.utils.tools import make_serializable @@ -104,7 +104,6 @@ async def worker(worker_id: int): await Runs.Stream.publish(run_id, message) last_message = message if last_message.type == "interrupt": - await Runs.set_status(run_id, "interrupted") log_run( worker_id, run_id, @@ -122,7 +121,7 @@ async def worker(worker_id: int): run_info["ended_at"] = ended_at run_info["exec_s"] = ended_at - started_at - run_info["queue_s"] = started_at - run["created_at"].timestamp() + run_info["queue_s"] = started_at - run_info["queued_at"].timestamp() DB.update_run_info(run_id, run_info) @@ -136,7 +135,14 @@ async def worker(worker_id: int): ) DB.add_run_output(run_id, last_message.data) await Runs.Stream.publish(run_id, Message(type="control", data="done")) - await Runs.set_status(run_id, "success") + if last_message.type == "interrupt": + interrupt = Interrupt( + event=last_message.event, ai_data=last_message.data + ) + DB.update_run(run_id, {"interrupt": interrupt}) + await Runs.set_status(run_id, "interrupted") + else: + await Runs.set_status(run_id, "success") log_run(worker_id, run_id, "succeeded", **run_stats(run_info)) except InvalidFormatException as error: @@ -152,7 +158,7 @@ async def worker(worker_id: int): { "ended_at": ended_at, "exec_s": ended_at - started_at, - "queue_s": (started_at - run["created_at"].timestamp()), + "queue_s": (started_at - run_info["queued_at"].timestamp()), } ) @@ -166,7 +172,7 @@ async def worker(worker_id: int): { "ended_at": ended_at, "exec_s": ended_at - started_at, - "queue_s": (started_at - run["created_at"].timestamp()), + "queue_s": (started_at - run_info["queued_at"].timestamp()), } ) diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index 8f80b02..45c1b59 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -122,6 +122,7 @@ async def put(run_create: ApiRunCreate) -> ApiRun: new_run = _make_run(run_create) run_info = RunInfo( run_id=new_run["run_id"], + queued_at=datetime.now(), attempts=0, ) DB.create_run(new_run) @@ -173,27 +174,24 @@ def search(search_request: RunSearchRequest) -> List[ApiRun]: ) @staticmethod - async def resume(run_id: str, input: Dict[str, Any]) -> ApiRun: - # TODO: implement resume - - # run = DB.get_run(run_id) - # if run is None: - # raise Exception("Run not found") - # if run["status"] != "interrupted": - # raise Exception("Run is not in interrupted state") - - # new_run = _make_run(run_create) - # run_info = RunInfo( - # run_id=new_run["run_id"], - # attempts=0, - # ) - # DB.create_run(new_run) - # DB.create_run_info(run_info) - - # await RUNS_QUEUE.put(new_run["run_id"]) - # return _to_api_model(new_run) - - ... + async def resume(run_id: str, user_input: Dict[str, Any]) -> ApiRun: + run = DB.get_run(run_id) + if run is None: + raise ValueError("Run not found") + if run["status"] != "interrupted": + raise ValueError("Run is not in interrupted state") + if run["interrupt"] is None: + raise ValueError(f"No interrupt found for run {run_id}") + + interrupt = run["interrupt"] + interrupt["user_data"] = user_input + + DB.update_run(run_id, {"interrupt": interrupt}) + DB.update_run_info(run_id, {"attempts": 0, "queued_at": datetime.now()}) + updated = DB.update_run_status(run_id, "pending") + + await RUNS_QUEUE.put(updated["run_id"]) + return _to_api_model(updated) @staticmethod async def set_status(run_id: str, status: RunStatus): diff --git a/src/agent_workflow_server/services/stream.py b/src/agent_workflow_server/services/stream.py index 01a0708..d803799 100644 --- a/src/agent_workflow_server/services/stream.py +++ b/src/agent_workflow_server/services/stream.py @@ -12,13 +12,5 @@ async def stream_run(run: Run) -> AsyncGenerator[Message, None]: agent_info = get_agent_info(run["agent_id"]) agent = agent_info.agent - if agent_info.manifest.specs.capabilities.interrupts: - # TODO: This will only work for langgraph - if run.get("config") is None: - run["config"] = {} - if run["config"].get("configurable") is None: - run["config"]["configurable"] = {} - run["config"]["configurable"].setdefault("thread_id", run["thread_id"]) - - async for message in agent.astream(input=run["input"], config=run["config"]): + async for message in agent.astream(run=run): yield message diff --git a/src/agent_workflow_server/storage/models.py b/src/agent_workflow_server/storage/models.py index ee0cc95..b281e02 100644 --- a/src/agent_workflow_server/storage/models.py +++ b/src/agent_workflow_server/storage/models.py @@ -7,6 +7,14 @@ RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"] +class Interrupt(TypedDict): + """Definition for an Interrupt message""" + + event: str + ai_data: Any + user_data: Optional[Any] + + class Config(TypedDict): tags: Optional[List[str]] recursion_limit: Optional[int] @@ -25,12 +33,14 @@ class Run(TypedDict): created_at: datetime updated_at: datetime status: RunStatus + interrupt: Optional[Interrupt] # last interrupt (if any) class RunInfo(TypedDict): """Definition of statistics information about a Run""" run_id: str + queued_at: datetime attempts: Optional[int] started_at: Optional[datetime] ended_at: Optional[datetime] diff --git a/src/agent_workflow_server/storage/service.py b/src/agent_workflow_server/storage/service.py index 1e82861..e08f581 100644 --- a/src/agent_workflow_server/storage/service.py +++ b/src/agent_workflow_server/storage/service.py @@ -74,9 +74,9 @@ def get_run_status(self, run_id: str) -> Optional[RunStatus]: run = self.get_run(run_id) return run.get("status") if run else None - def update_run_status(self, run_id: str, status: RunStatus) -> None: + def update_run_status(self, run_id: str, status: RunStatus) -> Optional[Run]: """Update the status of a Run""" - self.update_run(run_id, {"status": status}) + return self.update_run(run_id, {"status": status}) def add_run_output(self, run_id: str, output: Any) -> None: """Add the output of a Run""" From ea6482765821a889e6f9bde7c209bbb9f31a0584 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Tue, 1 Apr 2025 16:34:05 +0200 Subject: [PATCH 4/9] fix: serialization --- src/agent_workflow_server/agents/load.py | 4 +++- src/agent_workflow_server/utils/tools.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index b7e4ed7..112dda2 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -127,7 +127,9 @@ def _resolve_agent(name: str, path: str) -> AgentInfo: if manifest: break else: - raise ImportError("Failed to load agent manifest") + raise ImportError( + f"Failed to load agent manifest from any of the paths: {manifest_paths}" + ) try: schema = generate_agent_oapi(manifest, name) diff --git a/src/agent_workflow_server/utils/tools.py b/src/agent_workflow_server/utils/tools.py index 388ed5c..db42aac 100644 --- a/src/agent_workflow_server/utils/tools.py +++ b/src/agent_workflow_server/utils/tools.py @@ -24,7 +24,7 @@ def make_serializable(v: Any): elif ( isinstance(v, BaseModel) and hasattr(v, "model_dump") and callable(v.model_dump) ): - return v.model_dump() + return v.model_dump(mode="json") elif isinstance(v, BaseModel) and hasattr(v, "dict") and callable(v.dict): return v.dict() elif isinstance(v, Enum): From e77136e52d771a9a79a81e90b59b600cb6c0bab5 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Tue, 1 Apr 2025 16:39:05 +0200 Subject: [PATCH 5/9] fix: reset correct validation --- src/agent_workflow_server/services/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent_workflow_server/services/validation.py b/src/agent_workflow_server/services/validation.py index 0d86263..1af24d4 100644 --- a/src/agent_workflow_server/services/validation.py +++ b/src/agent_workflow_server/services/validation.py @@ -47,7 +47,7 @@ def validate_output(run_id, agent_id: str, output: Any) -> None: validate_against_schema( instance=output, - schema=schemas["output"].get("properties", schemas["output"]), + schema=schemas["output"], error_prefix=f"Output validation failed for run {run_id}", ) From 8f13330b1ddf2b0ac8ea6474906948fc111bbab1 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Tue, 1 Apr 2025 18:08:27 +0200 Subject: [PATCH 6/9] chore: add tests --- .env.example | 1 + tests/mock.py | 24 +++++++++++++++++++++--- tests/test_runs.py | 45 ++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/.env.example b/.env.example index c5bafb6..7208eb4 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,5 @@ ### SERVER-SPECIFIC ENV ### +LOG_LEVEL=INFO API_HOST=127.0.0.1 API_PORT=8000 AGENTS_REF='{"agent_uuid": "agent_module_name:agent_var"}' diff --git a/tests/mock.py b/tests/mock.py index 6f3c7d2..f9a6466 100644 --- a/tests/mock.py +++ b/tests/mock.py @@ -2,8 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +from typing import AsyncGenerator from agent_workflow_server.agents.base import BaseAdapter, BaseAgent +from agent_workflow_server.services.message import Message +from agent_workflow_server.storage.models import Run # Make sure that this and the one in the .env.test file are the same MOCK_AGENT_ID = "3f1e2549-5799-4321-91ae-2a4881d55526" @@ -11,6 +14,10 @@ MOCK_RUN_INPUT = {"message": "What's the color of the sky?"} MOCK_RUN_OUTPUT = {"message": "The color of the sky is blue"} +MOCK_RUN_INPUT_INTERRUPT = {"message": "Please interrupt"} +MOCK_RUN_EVENT_INTERRUPT = "__mock_interrupt__" +MOCK_RUN_OUTPUT_INTERRUPT = {"message": "How can I help you?"} + class MockAgentImpl: ... @@ -19,9 +26,20 @@ class MockAgent(BaseAgent): def __init__(self, agent: MockAgentImpl): self.agent = agent - async def astream(self, input: dict, config: dict): - await asyncio.sleep(3) - yield MOCK_RUN_OUTPUT + async def astream(self, run: Run) -> AsyncGenerator[Message, None]: + if run["input"] == MOCK_RUN_INPUT or ( + run.get("interrupt") is not None + and run["interrupt"].get("user_data") == MOCK_RUN_INPUT + ): + await asyncio.sleep(3) + yield Message(type="message", data=MOCK_RUN_OUTPUT) + return + if run["input"] == MOCK_RUN_INPUT_INTERRUPT: + yield Message( + type="interrupt", + event=MOCK_RUN_EVENT_INTERRUPT, + data=MOCK_RUN_OUTPUT_INTERRUPT, + ) class MockAdapter(BaseAdapter): diff --git a/tests/test_runs.py b/tests/test_runs.py index d92d620..5999842 100644 --- a/tests/test_runs.py +++ b/tests/test_runs.py @@ -9,24 +9,54 @@ from pytest_mock import MockerFixture from agent_workflow_server.agents.load import load_agents +from agent_workflow_server.generated.models.config import Config from agent_workflow_server.generated.models.run_search_request import ( RunSearchRequest, ) from agent_workflow_server.services.queue import start_workers from agent_workflow_server.services.runs import ApiRun, ApiRunCreate, Runs +from agent_workflow_server.storage.models import RunStatus from agent_workflow_server.storage.storage import DB from tests.mock import ( MOCK_AGENT_ID, MOCK_RUN_INPUT, + MOCK_RUN_INPUT_INTERRUPT, MOCK_RUN_OUTPUT, + MOCK_RUN_OUTPUT_INTERRUPT, MockAdapter, ) -run_create_mock = ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT, config={}) - @pytest.mark.asyncio -async def test_invoke(mocker: MockerFixture): +@pytest.mark.parametrize( + "run_create_mock, expected_status, expected_output", + [ + ( + ApiRunCreate( + agent_id=MOCK_AGENT_ID, + input=MOCK_RUN_INPUT, + config=Config( + tags=["test"], + recursion_limit=3, + configurable={"mock-key": "mock-value"}, + ), + ), + "success", + MOCK_RUN_OUTPUT, + ), + ( + ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT_INTERRUPT), + "interrupted", + MOCK_RUN_OUTPUT_INTERRUPT, + ), + ], +) +async def test_invoke( + mocker: MockerFixture, + run_create_mock: ApiRunCreate, + expected_status: RunStatus, + expected_output: dict, +): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) try: @@ -46,9 +76,12 @@ async def test_invoke(mocker: MockerFixture): else: assert True - assert run.status == "success" + assert run.status == expected_status assert run.run_id == new_run.run_id - assert output == MOCK_RUN_OUTPUT + assert run.agent_id == run_create_mock.agent_id + assert run.creation.agent_id == run_create_mock.agent_id + assert run.creation.config == run_create_mock.config + assert output == expected_output finally: worker_task.cancel() try: @@ -62,6 +95,8 @@ async def test_invoke(mocker: MockerFixture): async def test_invoke_timeout(mocker: MockerFixture, timeout: float): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) + run_create_mock = ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT) + try: load_agents() From 11aac1622fd90c8ba947b5147c29c4bebb7e494a Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Wed, 2 Apr 2025 13:39:04 +0200 Subject: [PATCH 7/9] fix: docs url Signed-off-by: Marco Trinelli (mtrinell) --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 77e6107..b195299 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,12 @@ The Agent Workflow Server (AgWS) enables participation in the Internet of Agents ## Getting Started -See [Agent Workflow Server Documentation](https://agntcy.github.io/workflow-srv/) +See [Agent Workflow Server Documentation](https://docs.agntcy.org/pages/agws/workflow_server) ## Contributing +See [Contributing](https://docs.agntcy.org/pages/agws/workflow_server#contributing) + Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. For detailed contributing guidelines, please see From 5a2f2f4bfff3e75feb22b0f7ac11a2bad99d3a87 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Wed, 2 Apr 2025 15:51:46 +0200 Subject: [PATCH 8/9] chore: fix as per pr comments Signed-off-by: Marco Trinelli (mtrinell) --- src/agent_workflow_server/agents/adapters/langgraph.py | 4 ++-- src/agent_workflow_server/services/runs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py index fd7ce26..ad18392 100644 --- a/src/agent_workflow_server/agents/adapters/langgraph.py +++ b/src/agent_workflow_server/agents/adapters/langgraph.py @@ -31,13 +31,13 @@ async def astream(self, run: Run): config = run["config"] if config is None: config = {} - configurable = config["configurable"] + configurable = config.get("configurable") if configurable is None: configurable = {} configurable.setdefault("thread_id", run["thread_id"]) # If there's an interrupt answer, ovverride the input - if run.get("interrupt") is not None and run["interrupt"].get("user_data"): + if "interrupt" in run and "user_data" in run["interrupt"]: input = Command(resume=run["interrupt"]["user_data"]) async for event in self.agent.astream( diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py index 45c1b59..7d6cf89 100644 --- a/src/agent_workflow_server/services/runs.py +++ b/src/agent_workflow_server/services/runs.py @@ -180,7 +180,7 @@ async def resume(run_id: str, user_input: Dict[str, Any]) -> ApiRun: raise ValueError("Run not found") if run["status"] != "interrupted": raise ValueError("Run is not in interrupted state") - if run["interrupt"] is None: + if run.get("interrupt") is None: raise ValueError(f"No interrupt found for run {run_id}") interrupt = run["interrupt"] From b7a3ffbf397c9ef1ca0cfe8bf1b38c1e3a5631f8 Mon Sep 17 00:00:00 2001 From: "Marco Trinelli (mtrinell)" Date: Wed, 2 Apr 2025 16:30:21 +0200 Subject: [PATCH 9/9] fix: better logging Signed-off-by: Marco Trinelli (mtrinell) --- src/agent_workflow_server/agents/load.py | 8 ++++++++ src/agent_workflow_server/services/queue.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py index 112dda2..0492e29 100644 --- a/src/agent_workflow_server/agents/load.py +++ b/src/agent_workflow_server/agents/load.py @@ -71,6 +71,14 @@ def _read_manifest(path: str) -> AgentACPDescriptor: def _resolve_agent(name: str, path: str) -> AgentInfo: + if ":" not in path: + raise ValueError( + f"""Invalid format for AGENTS_REF environment variable. \ +Value must be a module:var pair. \ +Example: "agent1_module:agent1_var" or "path/to/file.py:agent1_var" +Got: {path}""" + ) + module_or_file, export_symbol = path.split(":", 1) if not os.path.isfile(module_or_file): # It's a module (name), try to import it diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py index f1342da..75075f4 100644 --- a/src/agent_workflow_server/services/queue.py +++ b/src/agent_workflow_server/services/queue.py @@ -126,13 +126,13 @@ async def worker(worker_id: int): DB.update_run_info(run_id, run_info) try: - validate_output(run_id, run["agent_id"], last_message.data) log_run( worker_id, run_id, "got message", message_data=json.dumps(last_message.data), ) + validate_output(run_id, run["agent_id"], last_message.data) DB.add_run_output(run_id, last_message.data) await Runs.Stream.publish(run_id, Message(type="control", data="done")) if last_message.type == "interrupt":