Skip to content

Add support for interrupts (human-in-the-loop) - langgraph #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 2, 2025
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
### SERVER-SPECIFIC ENV ###
LOG_LEVEL=INFO
API_HOST=127.0.0.1
API_PORT=8000
AGENTS_REF='{"agent_uuid": "agent_module_name:agent_var"}'
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ The Agent Workflow Server (AgWS) enables participation in the Internet of Agents

## Getting Started

See [Agent Workflow Server Documentation](https://agntcy.github.io/workflow-srv/)
See [Agent Workflow Server Documentation](https://docs.agntcy.org/pages/agws/workflow_server)

## Contributing

See [Contributing](https://docs.agntcy.org/pages/agws/workflow_server#contributing)

Contributions are what make the open source community such an amazing place to
learn, inspire, and create. Any contributions you make are **greatly
appreciated**. For detailed contributing guidelines, please see
Expand Down
41 changes: 33 additions & 8 deletions src/agent_workflow_server/agents/adapters/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from typing import Optional

from langchain_core.runnables import RunnableConfig
from langgraph.constants import INTERRUPT
from langgraph.graph.graph import CompiledGraph, Graph
from langgraph.types import Command

from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
from agent_workflow_server.storage.models import Config
from agent_workflow_server.services.message import Message
from agent_workflow_server.storage.models import Run


class LangGraphAdapter(BaseAdapter):
Expand All @@ -23,16 +26,38 @@ class LangGraphAgent(BaseAgent):
def __init__(self, agent: CompiledGraph):
self.agent = agent

async def astream(self, input: dict, config: Optional[Config]):
async def astream(self, run: Run):
input = run["input"]
config = run["config"]
if config is None:
config = {}
configurable = config.get("configurable")
if configurable is None:
configurable = {}
configurable.setdefault("thread_id", run["thread_id"])

# If there's an interrupt answer, ovverride the input
if "interrupt" in run and "user_data" in run["interrupt"]:
input = Command(resume=run["interrupt"]["user_data"])

async for event in self.agent.astream(
input=input,
config=RunnableConfig(
configurable=config["configurable"],
configurable=configurable,
tags=config["tags"],
recursion_limit=config["recursion_limit"],
)
if config
else None,
stream_mode="values",
),
):
yield event
for k, v in event.items():
if k == INTERRUPT:
yield Message(
type="interrupt",
event=k,
data=v[0].value,
)
else:
yield Message(
type="message",
event=k,
data=v,
)
16 changes: 12 additions & 4 deletions src/agent_workflow_server/agents/adapters/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from llama_index.core.workflow import Workflow

from agent_workflow_server.agents.base import BaseAdapter, BaseAgent
from agent_workflow_server.storage.models import Config
from agent_workflow_server.services.message import Message
from agent_workflow_server.storage.models import Run


class LlamaIndexAdapter(BaseAdapter):
Expand All @@ -25,9 +26,16 @@ class LlamaIndexAgent(BaseAgent):
def __init__(self, agent: Workflow):
self.agent = agent

async def astream(self, input: dict, config: Config):
async def astream(self, run: Run):
input = run["input"]
handler = self.agent.run(**input)
async for event in handler.stream_events():
yield event
yield Message(
type="message",
data=event,
)
final_result = await handler
yield final_result
yield Message(
type="message",
data=final_result,
)
13 changes: 6 additions & 7 deletions src/agent_workflow_server/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Optional

from agent_workflow_server.storage.models import Config
from agent_workflow_server.services.message import Message
from agent_workflow_server.storage.models import Run


class BaseAgent(ABC):
@abstractmethod
async def astream(
self, input: Optional[Dict[str, Any]], config: Optional[Config]
) -> AsyncGenerator[Any, None]:
"""Invokes the agent with the given input and configuration and streams (returns) events asynchronously.
The last event includes the final result."""
async def astream(self, run: Run) -> AsyncGenerator[Message, None]:
"""Invokes the agent with the given `Run` and streams (returns) `Message`s asynchronously.
The last `Message` includes the final result."""
pass


Expand Down
32 changes: 25 additions & 7 deletions src/agent_workflow_server/agents/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,40 @@ def _load_adapters() -> List[BaseAdapter]:
def _read_manifest(path: str) -> AgentACPDescriptor:
if os.path.isfile(path):
with open(path, "r") as file:
manifest_data = json.load(file)
try:
manifest_data = json.load(file)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid JSON format in manifest file: {path}. Error: {e}"
)
# print full path
logger.info(f"Loaded Agent Manifest from {os.path.abspath(path)}")
return AgentACPDescriptor(**manifest_data)


def _resolve_agent(name: str, path: str) -> AgentInfo:
if ":" not in path:
raise ValueError(
f"""Invalid format for AGENTS_REF environment variable. \
Value must be a module:var pair. \
Example: "agent1_module:agent1_var" or "path/to/file.py:agent1_var"
Got: {path}"""
)

module_or_file, export_symbol = path.split(":", 1)
if not os.path.isfile(module_or_file):
# It's a module (name), try to import it
module_name = module_or_file
try:
module = importlib.import_module(module_name)
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"""Failed to load agent module {module_name}. \
Check if it\'s installed and that module name in 'AGENTS_REF' env variable is correct."""
)
except ImportError as e:
if any(part in str(e) for part in module_name.split(".")):
raise ImportError(
f"""Failed to load agent module {module_name}. \
Check that it is installed and that the module name in 'AGENTS_REF' env variable is correct."""
) from e
else:
raise e
else:
# It's a path to a file, try to load it as a module
file = module_or_file
Expand Down Expand Up @@ -119,7 +135,9 @@ def _resolve_agent(name: str, path: str) -> AgentInfo:
if manifest:
break
else:
raise ImportError("Failed to load agent manifest")
raise ImportError(
f"Failed to load agent manifest from any of the paths: {manifest_paths}"
)

try:
schema = generate_agent_oapi(manifest, name)
Expand Down
23 changes: 19 additions & 4 deletions src/agent_workflow_server/apis/stateless_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from agent_workflow_server.generated.models.run_output import (
RunError,
RunInterrupt,
RunOutput,
RunResult,
)
Expand All @@ -40,8 +41,6 @@
validate_run_create as validate,
)

from ..utils.tools import serialize_to_dict

router = APIRouter()


Expand All @@ -56,6 +55,11 @@ async def _validate_run_create(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)


async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
Expand All @@ -70,10 +74,15 @@ async def _wait_and_return_run_output(run_id: str) -> RunWaitResponseStateless:
if run is None:
return Response(status_code=status.HTTP_404_NOT_FOUND)
if run.status == "success" and run_output is not None:
return RunWaitResponseStateless(
run=run,
output=RunOutput(RunResult(type="result", values=run_output)),
)
elif run.status == "interrupted":
return RunWaitResponseStateless(
run=run,
output=RunOutput(
RunResult(type="result", values=serialize_to_dict(run_output))
RunInterrupt(type="interrupt", interrupt={"default": run_output})
),
)
else:
Expand Down Expand Up @@ -244,7 +253,13 @@ async def resume_stateless_run(
body: Dict[str, Any] = Body(None, description=""),
) -> RunStateless:
"""Provide the needed input to a run to resume its execution. Can only be called for runs that are in the interrupted state Schema of the provided input must match with the schema specified in the agent specs under interrupts for the interrupt type the agent generated for this specific interruption."""
raise HTTPException(status_code=500, detail="Not implemented")
try:
return await Runs.resume(run_id, body)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)


@router.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import os

from uvicorn.logging import ColourizedFormatter

Expand All @@ -12,3 +13,6 @@
handler.setFormatter(colorformatter)

CustomLoggerHandler = handler

LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=LOG_LEVEL, handlers=[CustomLoggerHandler], force=True)
4 changes: 1 addition & 3 deletions src/agent_workflow_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dotenv import load_dotenv
from fastapi import Depends, FastAPI

import agent_workflow_server.logging.logger # noqa: F401
from agent_workflow_server.agents.load import load_agents
from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter
from agent_workflow_server.apis.agents import router as AgentsApiRouter
Expand All @@ -20,16 +21,13 @@
setup_api_key_auth,
)
from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter
from agent_workflow_server.logger.custom_logger import CustomLoggerHandler
from agent_workflow_server.services.queue import start_workers

load_dotenv()

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000

logging.basicConfig(level=logging.INFO, handlers=[CustomLoggerHandler], force=True)

logger = logging.getLogger(__name__)

app = FastAPI(
Expand Down
9 changes: 6 additions & 3 deletions src/agent_workflow_server/services/message.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from typing import Any, Literal, Optional

type MessageType = Literal["control", "message", "interrupt"]


class Message:
def __init__(self, topic: str, data: Any):
self.topic = topic
def __init__(self, type: MessageType, data: Any, event: Optional[str] = None):
self.type = type
self.data = data
self.event = event
Loading