Skip to content

Commit b0b884e

Browse files
authored
Refactor API endpoints and implement JWT authentication (#11)
- Removed the chat endpoint and integrated AG-UI WebSocket routing. - Added JWT authentication for conversation initialization, allowing secure WebSocket connections. - Updated requirements to include PyJWT for token management. - Cleaned up unused CopilotKit integration code and adjusted conversation service to persist messages in the database. This refactor enhances the security and organization of the API while streamlining the conversation management process.
1 parent 3392c70 commit b0b884e

File tree

11 files changed

+685
-125
lines changed

11 files changed

+685
-125
lines changed

app/api/v1/agui_ws.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""
2+
AG-UI WebSocket endpoint.
3+
4+
Implements the Agent-User Interaction protocol over WebSockets.
5+
Each connection is authenticated via a JWT query-param that embeds the
6+
``conversationId``. Messages are persisted through ``ChatService``.
7+
"""
8+
9+
import json
10+
import logging
11+
import uuid
12+
13+
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
14+
15+
from app.security.auth import verify_conversation_token
16+
from app.services.chat_service import chat_service
17+
18+
logger = logging.getLogger(__name__)
19+
20+
router = APIRouter()
21+
22+
23+
def _agui_event(event_type: str, **fields) -> str:
24+
"""Serialise a single AG-UI event as a JSON text frame."""
25+
return json.dumps({"type": event_type, **fields})
26+
27+
28+
@router.websocket("/ws")
29+
async def agui_websocket(
30+
websocket: WebSocket,
31+
token: str = Query(...),
32+
):
33+
conversation_id = verify_conversation_token(token)
34+
await websocket.accept()
35+
logger.info(f"[AG-UI] WebSocket connected conv_id={conversation_id}")
36+
37+
try:
38+
while True:
39+
raw = await websocket.receive_text()
40+
41+
try:
42+
data = json.loads(raw)
43+
except json.JSONDecodeError:
44+
await websocket.send_text(
45+
_agui_event("RUN_ERROR", message="Invalid JSON", code="BAD_REQUEST")
46+
)
47+
continue
48+
49+
user_message = data.get("message", "").strip()
50+
if not user_message:
51+
await websocket.send_text(
52+
_agui_event("RUN_ERROR", message="Empty message", code="BAD_REQUEST")
53+
)
54+
continue
55+
56+
wizard_state = data.get("wizard_state")
57+
run_id = str(uuid.uuid4())
58+
thread_id = str(conversation_id)
59+
60+
await websocket.send_text(
61+
_agui_event("RUN_STARTED", threadId=thread_id, runId=run_id)
62+
)
63+
64+
try:
65+
result = await chat_service.process_message(
66+
message=user_message,
67+
conversation_id=conversation_id,
68+
wizard_state=wizard_state,
69+
)
70+
71+
response_text = result.get("response", "")
72+
message_id = str(uuid.uuid4())
73+
74+
await websocket.send_text(
75+
_agui_event("TEXT_MESSAGE_START", messageId=message_id, role="assistant")
76+
)
77+
await websocket.send_text(
78+
_agui_event("TEXT_MESSAGE_CONTENT", messageId=message_id, delta=response_text)
79+
)
80+
await websocket.send_text(
81+
_agui_event("TEXT_MESSAGE_END", messageId=message_id)
82+
)
83+
84+
state_snapshot = _build_state_snapshot(result)
85+
if state_snapshot:
86+
await websocket.send_text(
87+
_agui_event("STATE_SNAPSHOT", snapshot=state_snapshot)
88+
)
89+
90+
await websocket.send_text(
91+
_agui_event("RUN_FINISHED", threadId=thread_id, runId=run_id)
92+
)
93+
94+
except Exception as exc:
95+
logger.error(f"[AG-UI] Error processing message: {exc}", exc_info=True)
96+
await websocket.send_text(
97+
_agui_event("RUN_ERROR", message=str(exc), code="INTERNAL_ERROR")
98+
)
99+
await websocket.send_text(
100+
_agui_event("RUN_FINISHED", threadId=thread_id, runId=run_id)
101+
)
102+
103+
except WebSocketDisconnect:
104+
logger.info(f"[AG-UI] WebSocket disconnected conv_id={conversation_id}")
105+
106+
107+
def _build_state_snapshot(result: dict) -> dict | None:
108+
"""Extract wizard / agent metadata into an AG-UI STATE_SNAPSHOT payload."""
109+
snapshot: dict = {}
110+
111+
if result.get("wizard_state") and result["wizard_state"] != "INACTIVE":
112+
snapshot["wizard_state"] = result.get("wizard_state")
113+
snapshot["current_question"] = result.get("current_question")
114+
snapshot["wizard_responses"] = result.get("wizard_responses", {})
115+
snapshot["awaiting_answer"] = result.get("awaiting_answer", False)
116+
snapshot["wizard_session_id"] = result.get("wizard_session_id")
117+
118+
snapshot["agent_used"] = result.get("agent_used", "unknown")
119+
120+
return snapshot if snapshot else None

app/api/v1/chat.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

app/api/v1/conversations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from app.db.config.database import get_async_session
1010
from app.db.models import Conversation
11+
from app.security.auth import create_conversation_token
1112

1213
router = APIRouter()
1314

@@ -22,6 +23,11 @@ class ConversationResponse(BaseModel):
2223
started_at: datetime
2324

2425

26+
class ConversationInitResponse(BaseModel):
27+
token: str
28+
conversationId: int
29+
30+
2531
@router.post("/conversations", response_model=ConversationResponse)
2632
async def create_conversation(
2733
conversation: ConversationCreate,
@@ -42,6 +48,25 @@ async def create_conversation(
4248
raise HTTPException(status_code=500, detail="Error creating conversation")
4349

4450

51+
@router.post("/conversations/init", response_model=ConversationInitResponse)
52+
async def init_conversation(
53+
payload: ConversationCreate = ConversationCreate(),
54+
session: AsyncSession = Depends(get_async_session),
55+
) -> ConversationInitResponse:
56+
"""Create a new conversation and return a signed JWT for the WebSocket."""
57+
try:
58+
new_conv = Conversation(email=payload.email)
59+
session.add(new_conv)
60+
await session.commit()
61+
await session.refresh(new_conv)
62+
63+
token = create_conversation_token(new_conv.id)
64+
return ConversationInitResponse(token=token, conversationId=new_conv.id)
65+
except Exception:
66+
await session.rollback()
67+
raise HTTPException(status_code=500, detail="Error creating conversation")
68+
69+
4570
@router.get("/conversations")
4671
async def get_conversations(
4772
session: AsyncSession = Depends(get_async_session)

app/api/v1/copilotkit_endpoint.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

app/main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,28 @@
1212
force=True,
1313
)
1414

15-
from app.api.v1.chat import router as chat_router
15+
from app.api.v1.agui_ws import router as agui_ws_router
1616
from app.api.v1.conversations import router as conversations_router
17-
from app.api.v1.copilotkit_endpoint import router as copilotkit_router
1817
from app.api.v1.documents import router as documents_router
1918
from app.api.v1.scoring import router as scoring_router
2019

2120
v1 = '/api/v1'
2221

2322
app = FastAPI(title="Chatbot Backend", version="1.0.0")
2423

25-
# Configurar CORS para permitir conexiones desde el frontend
2624
app.add_middleware(
2725
CORSMiddleware,
2826
allow_origins=["http://localhost:3000", "http://127.0.0.1:3000",
29-
"http://localhost:3001"], # Frontend URLs
27+
"http://localhost:3001"],
3028
allow_credentials=True,
3129
allow_methods=["*"],
3230
allow_headers=["*"],
3331
)
3432

3533
app.include_router(conversations_router, prefix=v1, tags=["Conversations"])
36-
app.include_router(chat_router, prefix=v1, tags=["Chat"])
34+
app.include_router(agui_ws_router, prefix=v1, tags=["AG-UI WebSocket"])
3735
app.include_router(documents_router, prefix=v1, tags=["Documents"])
3836
app.include_router(scoring_router, prefix=v1, tags=["Scoring"])
39-
app.include_router(copilotkit_router, prefix=v1, tags=["CopilotKit"])
4037

4138

4239
@app.get("/")

app/security/auth.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import os
22
import secrets
33
from dataclasses import dataclass
4+
from datetime import datetime, timedelta, timezone
45

6+
import jwt
57
from fastapi import Depends, HTTPException, status
68
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
79

810
bearer_scheme = HTTPBearer(auto_error=False)
911

12+
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "change-me-in-production")
13+
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
14+
JWT_EXPIRATION_MINUTES = int(os.getenv("JWT_EXPIRATION_MINUTES", "60"))
15+
1016

1117
@dataclass(frozen=True)
1218
class AuthUser:
@@ -76,3 +82,37 @@ async def require_admin_user(
7682
detail="Permisos insuficientes. Se requiere rol admin",
7783
)
7884
return user
85+
86+
87+
# ---------------------------------------------------------------------------
88+
# JWT helpers for WebSocket conversation tokens
89+
# ---------------------------------------------------------------------------
90+
91+
def create_conversation_token(conversation_id: int) -> str:
92+
"""Sign a short-lived JWT that embeds the conversation ID."""
93+
payload = {
94+
"sub": str(conversation_id),
95+
"exp": datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRATION_MINUTES),
96+
"iat": datetime.now(timezone.utc),
97+
}
98+
return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
99+
100+
101+
def verify_conversation_token(token: str) -> int:
102+
"""Validate a conversation JWT and return the integer conversation_id.
103+
104+
Raises ``HTTPException(401)`` on any validation failure.
105+
"""
106+
try:
107+
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
108+
return int(payload["sub"])
109+
except jwt.ExpiredSignatureError:
110+
raise HTTPException(
111+
status_code=status.HTTP_401_UNAUTHORIZED,
112+
detail="Token expirado",
113+
)
114+
except (jwt.InvalidTokenError, KeyError, ValueError) as exc:
115+
raise HTTPException(
116+
status_code=status.HTTP_401_UNAUTHORIZED,
117+
detail=f"Token inválido: {exc}",
118+
)

0 commit comments

Comments
 (0)