diff --git a/.env.example b/.env.example
index c5bafb6..7208eb4 100644
--- a/.env.example
+++ b/.env.example
@@ -1,4 +1,5 @@
 ### SERVER-SPECIFIC ENV ###
+LOG_LEVEL=INFO
 API_HOST=127.0.0.1
 API_PORT=8000
 AGENTS_REF='{"agent_uuid": "agent_module_name:agent_var"}'
diff --git a/README.md b/README.md
index 77e6107..b195299 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,12 @@ The Agent Workflow Server (AgWS) enables participation in the Internet of Agents
 
 ## Getting Started
 
-See [Agent Workflow Server Documentation](https://agntcy.github.io/workflow-srv/)
+See [Agent Workflow Server Documentation](https://docs.agntcy.org/pages/agws/workflow_server)
 
 ## Contributing
 
+See [Contributing](https://docs.agntcy.org/pages/agws/workflow_server#contributing)
+
 Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. For detailed contributing guidelines, please see
diff --git a/src/agent_workflow_server/agents/adapters/langgraph.py b/src/agent_workflow_server/agents/adapters/langgraph.py
index 626f71c..ad18392 100644
--- a/src/agent_workflow_server/agents/adapters/langgraph.py
+++ b/src/agent_workflow_server/agents/adapters/langgraph.py
@@ -4,10 +4,13 @@
 from typing import Optional
 
 from langchain_core.runnables import RunnableConfig
+from langgraph.constants import INTERRUPT
 from langgraph.graph.graph import CompiledGraph, Graph
+from langgraph.types import Command
 
 from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
-from agent_workflow_server.storage.models import Config
+from agent_workflow_server.services.message import Message
+from agent_workflow_server.storage.models import Run
 
 
 class LangGraphAdapter(BaseAdapter):
@@ -23,16 +26,38 @@ class LangGraphAgent(BaseAgent):
     def __init__(self, agent: CompiledGraph):
         self.agent = agent
 
-    async def astream(self, input: dict, config: Optional[Config]):
+    async def astream(self, run: Run):
+        input = run["input"]
+        config = run["config"]
+        if config is None:
+            config = {}
+        configurable = config.get("configurable")
+        if configurable is None:
+            configurable = {}
+        configurable.setdefault("thread_id", run["thread_id"])
+
+        # If there's an interrupt answer, override the input
+        if "interrupt" in run and "user_data" in run["interrupt"]:
+            input = Command(resume=run["interrupt"]["user_data"])
+
         async for event in self.agent.astream(
             input=input,
             config=RunnableConfig(
-                configurable=config["configurable"],
+                configurable=configurable,
                 tags=config["tags"],
                 recursion_limit=config["recursion_limit"],
-            )
-            if config
-            else None,
-            stream_mode="values",
+            ),
         ):
-            yield event
+            for k, v in event.items():
+                if k == INTERRUPT:
+                    yield Message(
+                        type="interrupt",
+                        event=k,
+                        data=v[0].value,
+                    )
+                else:
+                    yield Message(
+                        type="message",
+                        event=k,
+                        data=v,
+                    )
diff --git a/src/agent_workflow_server/agents/adapters/llamaindex.py b/src/agent_workflow_server/agents/adapters/llamaindex.py
index 0483e4a..036f70d 100644
--- a/src/agent_workflow_server/agents/adapters/llamaindex.py
+++ b/src/agent_workflow_server/agents/adapters/llamaindex.py
@@ -7,7 +7,8 @@
 from llama_index.core.workflow import Workflow
 
 from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
-from agent_workflow_server.storage.models import Config
+from agent_workflow_server.services.message import Message
+from agent_workflow_server.storage.models import Run
 
 
 class LlamaIndexAdapter(BaseAdapter):
@@ -25,9 +26,16 @@ class LlamaIndexAgent(BaseAgent):
     def __init__(self, agent: Workflow):
         self.agent = agent
 
-    async def astream(self, input: dict, config: Config):
+    async def astream(self, run: Run):
+        input = run["input"]
         handler = self.agent.run(**input)
         async for event in handler.stream_events():
-            yield event
+            yield Message(
+                type="message",
+                data=event,
+            )
         final_result = await handler
-        yield final_result
+        yield Message(
+            type="message",
+            data=final_result,
+        )
diff --git a/src/agent_workflow_server/agents/base.py b/src/agent_workflow_server/agents/base.py
index 2f315f8..d7c549b 100644
--- a/src/agent_workflow_server/agents/base.py
+++ b/src/agent_workflow_server/agents/base.py
@@ -2,18 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, Dict, Optional
+from typing import Any, AsyncGenerator, Optional
 
-from agent_workflow_server.storage.models import Config
+from agent_workflow_server.services.message import Message
+from agent_workflow_server.storage.models import Run
 
 
 class BaseAgent(ABC):
     @abstractmethod
-    async def astream(
-        self, input: Optional[Dict[str, Any]], config: Optional[Config]
-    ) -> AsyncGenerator[Any, None]:
-        """Invokes the agent with the given input and configuration and streams (returns) events asynchronously.
-        The last event includes the final result."""
+    async def astream(self, run: Run) -> AsyncGenerator[Message, None]:
+        """Invokes the agent with the given `Run` and streams (returns) `Message`s asynchronously.
+        The last `Message` includes the final result."""
         pass
 
 
diff --git a/src/agent_workflow_server/agents/load.py b/src/agent_workflow_server/agents/load.py
index c58e366..0492e29 100644
--- a/src/agent_workflow_server/agents/load.py
+++ b/src/agent_workflow_server/agents/load.py
@@ -59,24 +59,40 @@ def _load_adapters() -> List[BaseAdapter]:
 def _read_manifest(path: str) -> AgentACPDescriptor:
     if os.path.isfile(path):
         with open(path, "r") as file:
-            manifest_data = json.load(file)
+            try:
+                manifest_data = json.load(file)
+            except json.JSONDecodeError as e:
+                raise ValueError(
+                    f"Invalid JSON format in manifest file: {path}. Error: {e}"
+                )
         # print full path
         logger.info(f"Loaded Agent Manifest from {os.path.abspath(path)}")
         return AgentACPDescriptor(**manifest_data)
 
 
 def _resolve_agent(name: str, path: str) -> AgentInfo:
+    if ":" not in path:
+        raise ValueError(
+            f"""Invalid format for AGENTS_REF environment variable. \
+Value must be a module:var pair. \
+Example: "agent1_module:agent1_var" or "path/to/file.py:agent1_var"
+Got: {path}"""
+        )
+
     module_or_file, export_symbol = path.split(":", 1)
     if not os.path.isfile(module_or_file):
         # It's a module (name), try to import it
         module_name = module_or_file
         try:
             module = importlib.import_module(module_name)
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                f"""Failed to load agent module {module_name}. \
-Check if it\'s installed and that module name in 'AGENTS_REF' env variable is correct."""
-            )
+        except ImportError as e:
+            if any(part in str(e) for part in module_name.split(".")):
+                raise ImportError(
+                    f"""Failed to load agent module {module_name}. \
+Check that it is installed and that the module name in 'AGENTS_REF' env variable is correct."""
+                ) from e
+            else:
+                raise e
     else:
         # It's a path to a file, try to load it as a module
         file = module_or_file
@@ -119,7 +135,9 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
         if manifest:
             break
     else:
-        raise ImportError("Failed to load agent manifest")
+        raise ImportError(
+            f"Failed to load agent manifest from any of the paths: {manifest_paths}"
+        )
 
     try:
         schema = generate_agent_oapi(manifest, name)
diff --git a/src/agent_workflow_server/apis/stateless_runs.py b/src/agent_workflow_server/apis/stateless_runs.py
index b244dc9..5a6ae34 100644
--- a/src/agent_workflow_server/apis/stateless_runs.py
+++ b/src/agent_workflow_server/apis/stateless_runs.py
@@ -23,6 +23,7 @@
 )
 from agent_workflow_server.generated.models.run_output import (
     RunError,
+    RunInterrupt,
     RunOutput,
     RunResult,
 )
@@ -40,8 +41,6 @@
     validate_run_create as validate,
 )
 
-from ..utils.tools import serialize_to_dict
-
 router = APIRouter()
 
 
@@ -56,6 +55,11 @@ async def _validate_run_create(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
             detail=str(e),
         )
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=str(e),
+        )
 
 
 async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
@@ -70,10 +74,15 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
     if run is None:
         return Response(status_code=status.HTTP_404_NOT_FOUND)
     if run.status == "success" and run_output is not None:
+        return RunWaitResponseStateless(
+            run=run,
+            output=RunOutput(RunResult(type="result", values=run_output)),
+        )
+    elif run.status == "interrupted":
         return RunWaitResponseStateless(
             run=run,
             output=RunOutput(
-                RunResult(type="result", values=serialize_to_dict(run_output))
+                RunInterrupt(type="interrupt", interrupt={"default": run_output})
             ),
         )
     else:
@@ -244,7 +253,13 @@ async def resume_stateless_run(
     body: Dict[str, Any] = Body(None, description=""),
 ) -> RunStateless:
     """Provide the needed input to a run to resume its execution. Can only be called for runs that are in the interrupted state Schema of the provided input must match with the schema specified in the agent specs under interrupts for the interrupt type the agent generated for this specific interruption."""
-    raise HTTPException(status_code=500, detail="Not implemented")
+    try:
+        return await Runs.resume(run_id, body)
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
 
 
 @router.post(
diff --git a/src/agent_workflow_server/logger/custom_logger.py b/src/agent_workflow_server/logging/logger.py
similarity index 71%
rename from src/agent_workflow_server/logger/custom_logger.py
rename to src/agent_workflow_server/logging/logger.py
index 92c6f78..d6d3dde 100644
--- a/src/agent_workflow_server/logger/custom_logger.py
+++ b/src/agent_workflow_server/logging/logger.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import os
 
 from uvicorn.logging import ColourizedFormatter
 
@@ -12,3 +13,6 @@
 
 handler.setFormatter(colorformatter)
 CustomLoggerHandler = handler
+
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
+logging.basicConfig(level=LOG_LEVEL, handlers=[CustomLoggerHandler], force=True)
diff --git a/src/agent_workflow_server/main.py b/src/agent_workflow_server/main.py
index 791456a..5391fd3 100644
--- a/src/agent_workflow_server/main.py
+++ b/src/agent_workflow_server/main.py
@@ -12,6 +12,7 @@
 from dotenv import load_dotenv
 from fastapi import Depends, FastAPI
 
+import agent_workflow_server.logging.logger  # noqa: F401
 from agent_workflow_server.agents.load import load_agents
 from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter
 from agent_workflow_server.apis.agents import router as AgentsApiRouter
@@ -20,7 +21,6 @@
     setup_api_key_auth,
 )
 from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter
-from agent_workflow_server.logger.custom_logger import CustomLoggerHandler
 from agent_workflow_server.services.queue import start_workers
 
 load_dotenv()
@@ -28,8 +28,6 @@
 DEFAULT_HOST = "127.0.0.1"
 DEFAULT_PORT = 8000
 
-logging.basicConfig(level=logging.INFO, handlers=[CustomLoggerHandler], force=True)
-
 logger = logging.getLogger(__name__)
 
 app = FastAPI(
diff --git a/src/agent_workflow_server/services/message.py b/src/agent_workflow_server/services/message.py
index 27f048c..d9c5269 100644
--- a/src/agent_workflow_server/services/message.py
+++ b/src/agent_workflow_server/services/message.py
@@ -1,10 +1,13 @@
 # Copyright AGNTCY Contributors (https://github.com/agntcy)
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any
+from typing import Any, Literal, Optional
+
+type MessageType = Literal["control", "message", "interrupt"]
 
 
 class Message:
-    def __init__(self, topic: str, data: Any):
-        self.topic = topic
+    def __init__(self, type: MessageType, data: Any, event: Optional[str] = None):
+        self.type = type
         self.data = data
+        self.event = event
diff --git a/src/agent_workflow_server/services/queue.py b/src/agent_workflow_server/services/queue.py
index ca1fd9e..75075f4 100644
--- a/src/agent_workflow_server/services/queue.py
+++ b/src/agent_workflow_server/services/queue.py
@@ -2,8 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
+import json
 import logging
-import traceback
 from datetime import datetime
 from typing import Literal
 
@@ -11,8 +11,9 @@
     InvalidFormatException,
     validate_output,
 )
-from agent_workflow_server.storage.models import RunInfo
+from agent_workflow_server.storage.models import Interrupt, RunInfo
 from agent_workflow_server.storage.storage import DB
+from agent_workflow_server.utils.tools import make_serializable
 
 from .message import Message
 from .runs import RUNS_QUEUE, Runs
@@ -43,13 +44,22 @@ async def start_workers(n_workers: int):
 def log_run(
     worker_id: int,
     run_id: str,
-    info: Literal["started", "succeeded", "failed", "exeeded attempts"],
+    info: Literal[
+        "got message",
+        "started",
+        "interrupted",
+        "succeeded",
+        "failed",
+        "exeeded attempts",
+    ],
     **kwargs,
 ):
     log_methods = {
+        "got message": logger.debug,
         "started": logger.info,
+        "interrupted": logger.info,
         "succeeded": logger.info,
-        "failed": logger.warning,
+        "failed": logger.exception,
         "exeeded attempts": logger.error,
     }
     log_message = f"(Worker {worker_id}) Background Run {run_id} {info}"
@@ -90,11 +100,20 @@ async def worker(worker_id: int):
                 stream = stream_run(run)
                 last_message = None
                 async for message in stream:
+                    message.data = make_serializable(message.data)
                     await Runs.Stream.publish(run_id, message)
                     last_message = message
+                    if last_message.type == "interrupt":
+                        log_run(
+                            worker_id,
+                            run_id,
+                            "interrupted",
+                            message_data=json.dumps(message.data),
+                        )
+                        break
             except Exception as error:
                 await Runs.Stream.publish(
-                    run_id, Message(topic="error", data=(str(error)))
+                    run_id, Message(type="message", data=(str(error)))
                 )
                 raise RunError(error)
@@ -102,21 +121,33 @@
             run_info["ended_at"] = ended_at
             run_info["exec_s"] = ended_at - started_at
-            run_info["queue_s"] = started_at - run["created_at"].timestamp()
+            run_info["queue_s"] = started_at - run_info["queued_at"].timestamp()
 
             DB.update_run_info(run_id, run_info)
 
             try:
+                log_run(
+                    worker_id,
+                    run_id,
+                    "got message",
+                    message_data=json.dumps(last_message.data),
+                )
                 validate_output(run_id, run["agent_id"], last_message.data)
-                DB.add_run_output(run_id, last_message.data)
-                await Runs.Stream.publish(run_id, Message(topic="control", data="done"))
-                await Runs.set_status(run_id, "success")
+                await Runs.Stream.publish(run_id, Message(type="control", data="done"))
+                if last_message.type == "interrupt":
+                    interrupt = Interrupt(
+                        event=last_message.event, ai_data=last_message.data
+                    )
+                    DB.update_run(run_id, {"interrupt": interrupt})
+                    await Runs.set_status(run_id, "interrupted")
+                else:
+                    await Runs.set_status(run_id, "success")
                 log_run(worker_id, run_id, "succeeded", **run_stats(run_info))
             except InvalidFormatException as error:
                 await Runs.Stream.publish(
-                    run_id, Message(topic="error", data=str(error))
+                    run_id, Message(type="message", data=str(error))
                 )
                 log_run(worker_id, run_id, "failed")
                 raise RunError(str(error))
@@ -127,7 +158,7 @@ async def worker(worker_id: int):
                 {
                     "ended_at": ended_at,
                     "exec_s": ended_at - started_at,
-                    "queue_s": (started_at - run["created_at"].timestamp()),
+                    "queue_s": (started_at - run_info["queued_at"].timestamp()),
                 }
             )
 
@@ -141,7 +172,7 @@ async def worker(worker_id: int):
                 {
                     "ended_at": ended_at,
                     "exec_s": ended_at - started_at,
-                    "queue_s": (started_at - run["created_at"].timestamp()),
+                    "queue_s": (started_at - run_info["queued_at"].timestamp()),
                 }
             )
 
@@ -154,7 +185,6 @@ async def worker(worker_id: int):
                 "failed",
                 **{"error": error, **run_stats(run_info)},
             )
-            traceback.print_exc()
 
             await RUNS_QUEUE.put(run_id)  # Re-queue for retry
diff --git a/src/agent_workflow_server/services/runs.py b/src/agent_workflow_server/services/runs.py
index 3d4a349..7d6cf89 100644
--- a/src/agent_workflow_server/services/runs.py
+++ b/src/agent_workflow_server/services/runs.py
@@ -6,7 +6,7 @@
 from collections import defaultdict
 from datetime import datetime
 from itertools import islice
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, Dict, List, Optional
 from uuid import uuid4
 
 from agent_workflow_server.generated.models.run_create_stateless import (
@@ -122,6 +122,7 @@ async def put(run_create: ApiRunCreate) -> ApiRun:
         new_run = _make_run(run_create)
         run_info = RunInfo(
             run_id=new_run["run_id"],
+            queued_at=datetime.now(),
             attempts=0,
         )
         DB.create_run(new_run)
@@ -172,6 +173,26 @@ def search(search_request: RunSearchRequest) -> List[ApiRun]:
             islice(islice(runs, search_request.offset, None), search_request.limit)
         )
 
+    @staticmethod
+    async def resume(run_id: str, user_input: Dict[str, Any]) -> ApiRun:
+        run = DB.get_run(run_id)
+        if run is None:
+            raise ValueError("Run not found")
+        if run["status"] != "interrupted":
+            raise ValueError("Run is not in interrupted state")
+        if run.get("interrupt") is None:
+            raise ValueError(f"No interrupt found for run {run_id}")
+
+        interrupt = run["interrupt"]
+        interrupt["user_data"] = user_input
+
+        DB.update_run(run_id, {"interrupt": interrupt})
+        DB.update_run_info(run_id, {"attempts": 0, "queued_at": datetime.now()})
+        updated = DB.update_run_status(run_id, "pending")
+
+        await RUNS_QUEUE.put(updated["run_id"])
+        return _to_api_model(updated)
+
     @staticmethod
     async def set_status(run_id: str, status: RunStatus):
         run = DB.get_run(run_id)
diff --git a/src/agent_workflow_server/services/stream.py b/src/agent_workflow_server/services/stream.py
index c51de27..d803799 100644
--- a/src/agent_workflow_server/services/stream.py
+++ b/src/agent_workflow_server/services/stream.py
@@ -10,9 +10,7 @@
 
 
 async def stream_run(run: Run) -> AsyncGenerator[Message, None]:
-    agent = get_agent_info(run["agent_id"]).agent
-    async for event in agent.astream(input=run["input"], config=run["config"]):
-        yield Message(
-            topic="message",
-            data=event,
-        )
+    agent_info = get_agent_info(run["agent_id"])
+    agent = agent_info.agent
+    async for message in agent.astream(run=run):
+        yield message
diff --git a/src/agent_workflow_server/services/validation.py b/src/agent_workflow_server/services/validation.py
index 164e98f..1af24d4 100644
--- a/src/agent_workflow_server/services/validation.py
+++ b/src/agent_workflow_server/services/validation.py
@@ -34,6 +34,8 @@ def validate_against_schema(
 
 def get_agent_schemas(agent_id: str):
     """Get input, output and config schemas for an agent"""
     agent_info = AGENTS.get(agent_id)
+    if not agent_info:
+        raise ValueError(f"Agent {agent_id} not found")
     specs = agent_info.manifest.specs
     return {"input": specs.input, "output": specs.output, "config": specs.config}
@@ -45,7 +47,7 @@ def validate_output(run_id, agent_id: str, output: Any) -> None:
     validate_against_schema(
         instance=output,
-        schema=schemas["output"].get("properties", schemas["output"]),
+        schema=schemas["output"],
         error_prefix=f"Output validation failed for run {run_id}",
     )
 
diff --git a/src/agent_workflow_server/storage/models.py b/src/agent_workflow_server/storage/models.py
index ee0cc95..b281e02 100644
--- a/src/agent_workflow_server/storage/models.py
+++ b/src/agent_workflow_server/storage/models.py
@@ -7,6 +7,14 @@
 RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"]
 
 
+class Interrupt(TypedDict):
+    """Definition for an Interrupt message"""
+
+    event: str
+    ai_data: Any
+    user_data: Optional[Any]
+
+
 class Config(TypedDict):
     tags: Optional[List[str]]
     recursion_limit: Optional[int]
@@ -25,12 +33,14 @@ class Run(TypedDict):
     created_at: datetime
     updated_at: datetime
     status: RunStatus
+    interrupt: Optional[Interrupt]  # last interrupt (if any)
 
 
 class RunInfo(TypedDict):
     """Definition of statistics information about a Run"""
 
     run_id: str
+    queued_at: datetime
     attempts: Optional[int]
     started_at: Optional[datetime]
     ended_at: Optional[datetime]
diff --git a/src/agent_workflow_server/storage/service.py b/src/agent_workflow_server/storage/service.py
index 1e82861..e08f581 100644
--- a/src/agent_workflow_server/storage/service.py
+++ b/src/agent_workflow_server/storage/service.py
@@ -74,9 +74,9 @@ def get_run_status(self, run_id: str) -> Optional[RunStatus]:
         run = self.get_run(run_id)
         return run.get("status") if run else None
 
-    def update_run_status(self, run_id: str, status: RunStatus) -> None:
+    def update_run_status(self, run_id: str, status: RunStatus) -> Optional[Run]:
         """Update the status of a Run"""
-        self.update_run(run_id, {"status": status})
+        return self.update_run(run_id, {"status": status})
 
     def add_run_output(self, run_id: str, output: Any) -> None:
         """Add the output of a Run"""
diff --git a/src/agent_workflow_server/storage/storage.py b/src/agent_workflow_server/storage/storage.py
index 31f5738..7f5713b 100644
--- a/src/agent_workflow_server/storage/storage.py
+++ b/src/agent_workflow_server/storage/storage.py
@@ -9,13 +9,11 @@
 from dotenv import load_dotenv
 
-from agent_workflow_server.logger.custom_logger import CustomLoggerHandler
+import agent_workflow_server.logging.logger  # noqa: F401
 
 from .models import Run, RunInfo
 from .service import DBOperations
 
-logging.basicConfig(level=logging.INFO, handlers=[CustomLoggerHandler])
-
 logger = logging.getLogger(__name__)
 
 load_dotenv()
diff --git a/src/agent_workflow_server/utils/tools.py b/src/agent_workflow_server/utils/tools.py
index 2528961..db42aac 100644
--- a/src/agent_workflow_server/utils/tools.py
+++ b/src/agent_workflow_server/utils/tools.py
@@ -3,6 +3,7 @@
 import uuid
 from enum import Enum
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -15,13 +16,17 @@ def is_valid_uuid(val):
         return False
 
 
-def serialize_to_dict(v):
+def make_serializable(v: Any):
     if isinstance(v, list):
-        return [serialize_to_dict(vv) for vv in v]
+        return [make_serializable(vv) for vv in v]
     elif isinstance(v, dict):
-        return {kk: serialize_to_dict(vv) for kk, vv in v.items()}
-    elif isinstance(v, BaseModel):
-        return v.model_dump()
+        return {kk: make_serializable(vv) for kk, vv in v.items()}
+    elif (
+        isinstance(v, BaseModel) and hasattr(v, "model_dump") and callable(v.model_dump)
+    ):
+        return v.model_dump(mode="json")
+    elif isinstance(v, BaseModel) and hasattr(v, "dict") and callable(v.dict):
+        return v.dict()
     elif isinstance(v, Enum):
         return v.value
     else:
diff --git a/tests/mock.py b/tests/mock.py
index 6f3c7d2..f9a6466 100644
--- a/tests/mock.py
+++ b/tests/mock.py
@@ -2,8 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
+from typing import AsyncGenerator
 
 from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
+from agent_workflow_server.services.message import Message
+from agent_workflow_server.storage.models import Run
 
 # Make sure that this and the one in the .env.test file are the same
 MOCK_AGENT_ID = "3f1e2549-5799-4321-91ae-2a4881d55526"
@@ -11,6 +14,10 @@
 MOCK_RUN_INPUT = {"message": "What's the color of the sky?"}
 MOCK_RUN_OUTPUT = {"message": "The color of the sky is blue"}
 
+MOCK_RUN_INPUT_INTERRUPT = {"message": "Please interrupt"}
+MOCK_RUN_EVENT_INTERRUPT = "__mock_interrupt__"
+MOCK_RUN_OUTPUT_INTERRUPT = {"message": "How can I help you?"}
help you?"} + class MockAgentImpl: ... @@ -19,9 +26,20 @@ class MockAgent(BaseAgent): def __init__(self, agent: MockAgentImpl): self.agent = agent - async def astream(self, input: dict, config: dict): - await asyncio.sleep(3) - yield MOCK_RUN_OUTPUT + async def astream(self, run: Run) -> AsyncGenerator[Message, None]: + if run["input"] == MOCK_RUN_INPUT or ( + run.get("interrupt") is not None + and run["interrupt"].get("user_data") == MOCK_RUN_INPUT + ): + await asyncio.sleep(3) + yield Message(type="message", data=MOCK_RUN_OUTPUT) + return + if run["input"] == MOCK_RUN_INPUT_INTERRUPT: + yield Message( + type="interrupt", + event=MOCK_RUN_EVENT_INTERRUPT, + data=MOCK_RUN_OUTPUT_INTERRUPT, + ) class MockAdapter(BaseAdapter): diff --git a/tests/test_runs.py b/tests/test_runs.py index d92d620..5999842 100644 --- a/tests/test_runs.py +++ b/tests/test_runs.py @@ -9,24 +9,54 @@ from pytest_mock import MockerFixture from agent_workflow_server.agents.load import load_agents +from agent_workflow_server.generated.models.config import Config from agent_workflow_server.generated.models.run_search_request import ( RunSearchRequest, ) from agent_workflow_server.services.queue import start_workers from agent_workflow_server.services.runs import ApiRun, ApiRunCreate, Runs +from agent_workflow_server.storage.models import RunStatus from agent_workflow_server.storage.storage import DB from tests.mock import ( MOCK_AGENT_ID, MOCK_RUN_INPUT, + MOCK_RUN_INPUT_INTERRUPT, MOCK_RUN_OUTPUT, + MOCK_RUN_OUTPUT_INTERRUPT, MockAdapter, ) -run_create_mock = ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT, config={}) - @pytest.mark.asyncio -async def test_invoke(mocker: MockerFixture): +@pytest.mark.parametrize( + "run_create_mock, expected_status, expected_output", + [ + ( + ApiRunCreate( + agent_id=MOCK_AGENT_ID, + input=MOCK_RUN_INPUT, + config=Config( + tags=["test"], + recursion_limit=3, + configurable={"mock-key": "mock-value"}, + ), + ), + "success", + MOCK_RUN_OUTPUT, + ), + ( + ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT_INTERRUPT), + "interrupted", + MOCK_RUN_OUTPUT_INTERRUPT, + ), + ], +) +async def test_invoke( + mocker: MockerFixture, + run_create_mock: ApiRunCreate, + expected_status: RunStatus, + expected_output: dict, +): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) try: @@ -46,9 +76,12 @@ async def test_invoke(mocker: MockerFixture): else: assert True - assert run.status == "success" + assert run.status == expected_status assert run.run_id == new_run.run_id - assert output == MOCK_RUN_OUTPUT + assert run.agent_id == run_create_mock.agent_id + assert run.creation.agent_id == run_create_mock.agent_id + assert run.creation.config == run_create_mock.config + assert output == expected_output finally: worker_task.cancel() try: @@ -62,6 +95,8 @@ async def test_invoke(mocker: MockerFixture): async def test_invoke_timeout(mocker: MockerFixture, timeout: float): mocker.patch("agent_workflow_server.agents.load.ADAPTERS", [MockAdapter()]) + run_create_mock = ApiRunCreate(agent_id=MOCK_AGENT_ID, input=MOCK_RUN_INPUT) + try: load_agents()