Skip to content

Commit b510cfa

Browse files
replace _trace_manager variables with fastAPI dependency injection
1 parent cd6a0a2 commit b510cfa

9 files changed

Lines changed: 201 additions & 212 deletions

File tree

src/agentevals/api/app.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from ..utils.log_buffer import log_buffer
1717
from .debug_routes import debug_router
18-
from .debug_routes import set_trace_manager as set_debug_trace_manager
1918
from .routes import router
2019

2120
try:
@@ -42,11 +41,12 @@ async def lifespan(app: FastAPI):
4241
if log_buffer not in ae_logger.handlers:
4342
log_buffer.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
4443
ae_logger.addHandler(log_buffer)
45-
if _trace_manager:
46-
_trace_manager.start_cleanup_task()
44+
mgr = getattr(app.state, "trace_manager", None)
45+
if mgr:
46+
mgr.start_cleanup_task()
4747
yield
48-
if _trace_manager:
49-
await _trace_manager.shutdown()
48+
if mgr:
49+
await mgr.shutdown()
5050
ae_logger.removeHandler(log_buffer)
5151

5252

@@ -70,27 +70,27 @@ async def lifespan(app: FastAPI):
7070
app.include_router(debug_router, prefix="/api/debug")
7171

7272
_live_mode = os.getenv("AGENTEVALS_LIVE") == "1"
73-
_trace_manager = None
7473

7574
if _live_mode:
75+
from fastapi import Request as _Request
7676
from fastapi import WebSocket
7777

7878
from ..streaming.ws_server import StreamingTraceManager
79-
from .streaming_routes import set_trace_manager, streaming_router
79+
from .streaming_routes import streaming_router
8080

8181
app.include_router(streaming_router, prefix="/api/streaming")
82-
_trace_manager = StreamingTraceManager()
83-
set_trace_manager(_trace_manager)
84-
set_debug_trace_manager(_trace_manager)
82+
app.state.trace_manager = StreamingTraceManager()
8583

8684
@app.websocket("/ws/traces")
8785
async def websocket_endpoint(websocket: WebSocket):
88-
await _trace_manager.handle_connection(websocket)
86+
await websocket.app.state.trace_manager.handle_connection(websocket)
8987

9088
@app.get("/stream/ui-updates")
91-
async def ui_updates_stream():
89+
async def ui_updates_stream(request: _Request):
90+
mgr = request.app.state.trace_manager
91+
9292
async def event_generator():
93-
queue = _trace_manager.register_sse_client()
93+
queue = mgr.register_sse_client()
9494
try:
9595
while True:
9696
event = await queue.get()
@@ -100,7 +100,7 @@ async def event_generator():
100100
except asyncio.CancelledError:
101101
pass
102102
finally:
103-
_trace_manager.unregister_sse_client(queue)
103+
mgr.unregister_sse_client(queue)
104104

105105
return StreamingResponse(
106106
event_generator(),
@@ -112,10 +112,6 @@ async def event_generator():
112112
)
113113

114114

115-
def get_trace_manager():
116-
return _trace_manager
117-
118-
119115
_static_dir = Path(__file__).parent.parent / "_static"
120116
_has_ui = _static_dir.is_dir() and (_static_dir / "index.html").exists()
121117

src/agentevals/api/debug_routes.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
from datetime import UTC, datetime
1414
from typing import TYPE_CHECKING
1515

16-
from fastapi import APIRouter, HTTPException, UploadFile
16+
from fastapi import APIRouter, Depends, HTTPException, UploadFile
1717
from fastapi import File as FastAPIFile
1818
from fastapi.responses import StreamingResponse
1919
from pydantic import BaseModel
2020

2121
from agentevals import __version__
2222

2323
from ..utils.log_buffer import log_buffer
24+
from .dependencies import get_trace_manager, require_trace_manager
2425
from .models import DebugLoadData, SessionInfo, StandardResponse, WSSessionCompleteEvent, WSSessionStartedEvent
2526

2627
if TYPE_CHECKING:
@@ -30,13 +31,6 @@
3031

3132
debug_router = APIRouter()
3233

33-
_trace_manager: StreamingTraceManager | None = None
34-
35-
36-
def set_trace_manager(manager: StreamingTraceManager) -> None:
37-
global _trace_manager
38-
_trace_manager = manager
39-
4034

4135
class FrontendDiagnostics(BaseModel):
4236
user_description: str = ""
@@ -83,12 +77,12 @@ def _collect_environment() -> dict:
8377
}
8478

8579

86-
def _collect_sessions() -> list[dict]:
87-
if not _trace_manager:
80+
def _collect_sessions(manager: StreamingTraceManager | None) -> list[dict]:
81+
if not manager:
8882
return []
8983

9084
sessions_data = []
91-
for session in _trace_manager.sessions.values():
85+
for session in manager.sessions.values():
9286
sessions_data.append(
9387
{
9488
"session_id": session.session_id,
@@ -128,7 +122,10 @@ def _collect_temp_files(session_ids: set[str] | None = None) -> dict[str, str]:
128122

129123

130124
@debug_router.post("/bundle")
131-
async def create_debug_bundle(diagnostics: FrontendDiagnostics):
125+
async def create_debug_bundle(
126+
diagnostics: FrontendDiagnostics,
127+
manager: StreamingTraceManager | None = Depends(get_trace_manager),
128+
):
132129
timestamp = datetime.now(tz=UTC).strftime("%Y%m%d-%H%M%S")
133130
prefix = f"bug-report-{timestamp}"
134131

@@ -142,7 +139,7 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics):
142139
}
143140
zf.writestr(f"{prefix}/metadata.json", json.dumps(metadata, indent=2))
144141

145-
sessions = _collect_sessions()
142+
sessions = _collect_sessions(manager)
146143
for s in sessions:
147144
sid = s["session_id"]
148145
zf.writestr(
@@ -188,13 +185,10 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics):
188185

189186

190187
@debug_router.post("/load", response_model=StandardResponse[DebugLoadData])
191-
async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
192-
if not _trace_manager:
193-
raise HTTPException(
194-
status_code=400,
195-
detail="Live mode is not enabled. Start with: agentevals serve --dev",
196-
)
197-
188+
async def load_debug_bundle(
189+
file: UploadFile = FastAPIFile(...),
190+
manager: StreamingTraceManager = Depends(require_trace_manager),
191+
):
198192
content = await file.read()
199193
try:
200194
zf = zipfile.ZipFile(io.BytesIO(content))
@@ -236,9 +230,9 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
236230
metadata=meta.get("metadata", {}),
237231
)
238232

239-
_trace_manager.sessions[session.session_id] = session
233+
manager.sessions[session.session_id] = session
240234

241-
await _trace_manager.broadcast_to_ui(
235+
await manager.broadcast_to_ui(
242236
WSSessionStartedEvent(
243237
session=SessionInfo(
244238
session_id=session.session_id,
@@ -252,10 +246,10 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
252246
).model_dump(by_alias=True)
253247
)
254248

255-
invocations_data = await _trace_manager._extract_invocations(session)
256-
await _trace_manager._save_spans_to_temp_file(session)
249+
invocations_data = await manager._extract_invocations(session)
250+
await manager._save_spans_to_temp_file(session)
257251

258-
await _trace_manager.broadcast_to_ui(
252+
await manager.broadcast_to_ui(
259253
WSSessionCompleteEvent(
260254
session_id=session.session_id,
261255
invocations=invocations_data,

src/agentevals/api/dependencies.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""FastAPI dependency functions for shared services."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from fastapi import HTTPException, Request
8+
9+
if TYPE_CHECKING:
10+
from ..streaming.ws_server import StreamingTraceManager
11+
12+
13+
def get_trace_manager(request: Request) -> StreamingTraceManager | None:
14+
"""Return the StreamingTraceManager or None if live mode is off."""
15+
return getattr(request.app.state, "trace_manager", None)
16+
17+
18+
def require_trace_manager(request: Request) -> StreamingTraceManager:
19+
"""Return the StreamingTraceManager, raising 503 if live mode is off."""
20+
mgr = getattr(request.app.state, "trace_manager", None)
21+
if mgr is None:
22+
raise HTTPException(status_code=503, detail="Live mode not enabled")
23+
return mgr

src/agentevals/api/otlp_app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
from fastapi import FastAPI
1010

11-
from .otlp_routes import otlp_router, set_trace_manager
11+
from .otlp_routes import otlp_router
1212

1313

1414
@asynccontextmanager
1515
async def lifespan(app: FastAPI):
16-
from .app import get_trace_manager
16+
from .app import app as main_app
1717

18-
mgr = get_trace_manager()
18+
mgr = getattr(main_app.state, "trace_manager", None)
1919
if mgr:
20-
set_trace_manager(mgr)
20+
app.state.trace_manager = mgr
2121
yield
2222

2323

0 commit comments

Comments
 (0)