forked from agntcy/workflow-srv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
109 lines (89 loc) · 3.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import os
import pathlib
import signal
import sys
import uvicorn
from dotenv import find_dotenv, load_dotenv
from fastapi import Depends, FastAPI
from agent_workflow_server.agents.load import load_agents
from agent_workflow_server.apis.agents import public_router as PublicAgentsApiRouter
from agent_workflow_server.apis.agents import router as AgentsApiRouter
from agent_workflow_server.apis.authentication import (
authentication_with_api_key,
setup_api_key_auth,
)
from agent_workflow_server.apis.stateless_runs import router as StatelessRunsApiRouter
from agent_workflow_server.services.queue import start_workers
# Load variables from the nearest .env file, searching upward from the current
# working directory, so the os.getenv() fallbacks in parse_args() can see them.
load_dotenv(dotenv_path=find_dotenv(usecwd=True))

# Fallback bind address and port used by parse_args() when neither a CLI flag
# nor API_HOST / API_PORT is provided.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000

logger = logging.getLogger(__name__)

app = FastAPI(title="Agent Workflow Server", version="0.1")
setup_api_key_auth(app)

# Mount the API routers in a fixed order. Each entry pairs a router with whether
# the authentication_with_api_key dependency is attached to it; only the public
# agents router is mounted without that dependency.
for _router, _requires_api_key in (
    (AgentsApiRouter, True),
    (PublicAgentsApiRouter, False),
    (StatelessRunsApiRouter, True),
):
    app.include_router(
        router=_router,
        dependencies=(
            [Depends(authentication_with_api_key)] if _requires_api_key else None
        ),
    )
# Do not leave the loop variables behind as module attributes.
del _router, _requires_api_key
def signal_handler(sig, frame):
    """Log the received signal by name and terminate with exit status 0.

    Installed by start() for SIGINT and SIGTERM. Because it calls sys.exit(0),
    a SystemExit is raised in whatever code the main thread was running.

    Args:
        sig: Signal number delivered by the ``signal`` module.
        frame: Interrupted stack frame. It is unused but required by the
            handler signature.
    """
    signal_name = signal.Signals(sig).name
    logger.warning(f"Received {signal_name}. Exiting...")
    sys.exit(0)
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Each option falls back first to an environment variable and then to a
    built-in default.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``host``, ``port``, ``num_workers``,
        ``agent_manifest_path`` (list of ``pathlib.Path``), ``agents_ref`` and
        ``log_level`` (a level name string such as ``"INFO"`` unless
        overridden).
    """
    parser = argparse.ArgumentParser(description="Agent Workflow Server")
    parser.add_argument("--host", default=os.getenv("API_HOST", DEFAULT_HOST))
    parser.add_argument(
        "--port", type=int, default=int(os.getenv("API_PORT", DEFAULT_PORT))
    )
    parser.add_argument(
        "--num-workers", type=int, default=int(os.environ.get("NUM_WORKERS", 5))
    )
    # With action="append", argparse keeps a non-empty default and appends the
    # command-line values after it, so the fallback manifest would always be
    # loaded as well. Default to None and fill in the fallback below only when
    # the flag was not given at all.
    parser.add_argument(
        "--agent-manifest-path",
        action="append",
        type=pathlib.Path,
        default=None,
    )
    parser.add_argument("--agents-ref", default=os.getenv("AGENTS_REF", None))
    # Level name such as "info" or "DEBUG". start() upper-cases string values,
    # so the default must be a string rather than the int logging.INFO.
    parser.add_argument("--log-level", default=os.environ.get("LOG_LEVEL", "INFO"))

    args = parser.parse_args(argv)
    if args.agent_manifest_path is None:
        # Convert to Path so the list has the same element type as values
        # parsed from the command line.
        args.agent_manifest_path = [
            pathlib.Path(os.getenv("AGENT_MANIFEST_PATH", "manifest.json"))
        ]
    return args
def start():
    """Run the Agent Workflow Server until it is stopped.

    Steps:
    1. Parse options and configure logging.
    2. Install SIGINT/SIGTERM handlers.
    3. Load the agents and schedule the queue workers.
    4. Serve the FastAPI app with uvicorn on a newly created asyncio event loop.

    Raises:
        SystemExit: Re-raised with its original code (e.g. from
            signal_handler or argparse). Any other exception is logged with
            its traceback and converted into ``SystemExit(1)``.
    """
    try:
        args = parse_args()

        # Accept level names in any case ("debug") and numeric levels given
        # either as ints or as digit strings ("10").
        log_level = args.log_level
        if isinstance(log_level, str):
            log_level = int(log_level) if log_level.isdigit() else log_level.upper()
        logging.basicConfig(level=log_level)

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        load_agents(args.agents_ref, args.agent_manifest_path)

        # asyncio.get_event_loop() with no running loop is deprecated, so
        # create the loop explicitly and make it the current one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Keep a reference: the event loop holds only weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        _workers_task = loop.create_task(start_workers(args.num_workers))

        # use module import method to support reload argument
        config = uvicorn.Config(
            "agent_workflow_server.main:app",
            host=args.host,
            port=args.port,
            loop="asyncio",
        )
        server = uvicorn.Server(config)
        loop.run_until_complete(server.serve())
    except SystemExit as e:
        logger.warning(f"Agent Workflow Server exited with code: {e}")
        # Propagate so the process exit status reflects the real code instead
        # of always being 0 (e.g. argparse exits with 2 on invalid flags).
        raise
    except Exception as e:
        # logger.exception also records the traceback, which error() dropped.
        logger.exception(f"Exiting due to an unexpected error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    start()