Skip to content

Commit fc9afbb

Browse files
simonrosenbergDebug Agentclaude
authored
fix(security): first-message WebSocket auth to prevent token leakage (#2790)
Co-authored-by: Debug Agent <debug@example.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9fdb2f0 commit fc9afbb

3 files changed

Lines changed: 469 additions & 8 deletions

File tree

openhands-agent-server/openhands/agent_server/sockets.py

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22
WebSocket endpoints for OpenHands SDK.
33
44
These endpoints are separate from the main API routes to handle WebSocket-specific
5-
authentication. Browsers cannot send custom HTTP headers directly with WebSocket
6-
connections, so we support the `session_api_key` query param. For non-browser
7-
clients (e.g. Python/Node), we also support authenticating via headers.
5+
authentication. Three auth methods are supported (most to least preferred):
6+
7+
1. **First-message auth** (recommended): The client sends
8+
``{"type": "auth", "session_api_key": "..."}`` as the very first WebSocket
9+
frame after the connection opens. This keeps tokens out of URLs and
10+
therefore out of reverse-proxy / load-balancer access logs.
11+
2. Query parameter ``session_api_key`` — deprecated, kept for backwards compat.
12+
3. ``X-Session-API-Key`` header — for non-browser clients.
813
"""
914

15+
import asyncio
16+
import json
1017
import logging
1118
from dataclasses import dataclass
1219
from datetime import datetime
@@ -78,18 +85,102 @@ def _resolve_websocket_session_api_key(
7885
return None
7986

8087

88+
# Give clients 10 seconds to send auth frame after connection opens.
89+
# This balances security (don't hold connections indefinitely) with
90+
# accommodating slow networks and client startup time.
91+
_FIRST_MESSAGE_AUTH_TIMEOUT_SECONDS = 10
92+
93+
8194
async def _accept_authenticated_websocket(
    websocket: WebSocket,
    session_api_key: str | None,
) -> bool:
    """Authenticate and accept the socket, or close with an auth error.

    Authentication is attempted in the following order:

    1. Query parameter / header (legacy, deprecated). If a key is supplied
       this way it is decisive: a valid key accepts the socket, an invalid
       key closes it, and first-message auth is never attempted.
    2. First-message auth — the client sends
       ``{"type": "auth", "session_api_key": "..."}`` as the first frame.

    The WebSocket is always *accepted* before first-message auth is attempted
    because frames cannot be read until ``accept()`` has been called.

    Args:
        websocket: The incoming, not-yet-accepted WebSocket connection.
        session_api_key: Key taken from the ``session_api_key`` query
            parameter, or ``None`` when the client did not supply one.

    Returns:
        ``True`` if the socket was accepted and authenticated (or no keys are
        configured); ``False`` if it was closed with code 4001.
    """
    config = _get_config(websocket)
    resolved_key = _resolve_websocket_session_api_key(websocket, session_api_key)

    # No auth configured — accept unconditionally.
    if not config.session_api_keys:
        await websocket.accept()
        return True

    # Legacy path: key supplied via query param or header.
    if resolved_key is not None:
        if resolved_key in config.session_api_keys:
            logger.warning(
                "session_api_key passed via query param or header is deprecated. "
                "Use first-message auth instead."
            )
            await websocket.accept()
            return True
        logger.warning("WebSocket authentication failed: invalid API key")
        # Closing before accept() rejects the handshake itself, so the client
        # never gets an open socket.
        await websocket.close(code=4001, reason="Authentication failed")
        return False

    # First-message auth: we must accept() before reading frames because the
    # WebSocket protocol requires the handshake to complete first.
    await websocket.accept()
    failure = await _first_message_auth_failure(websocket, config.session_api_keys)
    if failure is not None:
        logger.warning(failure)
        # The peer may already be gone (disconnect case), so use the
        # best-effort close helper rather than websocket.close() directly.
        await _safe_close_websocket(
            websocket, code=4001, reason="Authentication failed"
        )
        return False

    logger.info("WebSocket authenticated via first-message auth")
    return True


async def _first_message_auth_failure(
    websocket: WebSocket,
    valid_keys,
) -> str | None:
    """Read the first frame and return the reason it fails auth, if any.

    Args:
        websocket: An already-accepted WebSocket.
        valid_keys: The configured collection of accepted session API keys.

    Returns:
        ``None`` when the first frame is a valid auth message, otherwise the
        log message describing why authentication failed.
    """
    try:
        raw = await asyncio.wait_for(
            websocket.receive_text(),
            timeout=_FIRST_MESSAGE_AUTH_TIMEOUT_SECONDS,
        )
    except TimeoutError:
        return "WebSocket first-message auth failed: timeout waiting for auth frame"
    except WebSocketDisconnect:
        return "WebSocket first-message auth failed: client disconnected"
    except KeyError:
        # A binary frame carries no "text" entry, so receive_text() fails on
        # the lookup. Treat it as a bad auth frame rather than a server error.
        return "WebSocket first-message auth failed: first frame is not text"

    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError, RecursionError):
        # TypeError: text payload was None; RecursionError: absurdly nested
        # JSON. Both come from an unauthenticated peer and must not escape.
        return "WebSocket first-message auth failed: malformed JSON"

    if not isinstance(data, dict):
        return "WebSocket first-message auth failed: payload is not a JSON object"
    if data.get("type") != "auth":
        return "WebSocket first-message auth failed: wrong message type"

    candidate = data.get("session_api_key")
    # Require a string before the membership test: a list/dict value would
    # raise TypeError if the configured keys are stored in a hash-based set.
    if not isinstance(candidate, str) or candidate not in valid_keys:
        return "WebSocket first-message auth failed: invalid API key"
    return None
94185

95186

@@ -329,9 +420,13 @@ async def _send_event(event: Event, websocket: WebSocket):
329420
logger.exception("error_sending_event: %r", event, stack_info=True)
330421

331422

332-
async def _safe_close_websocket(
    websocket: WebSocket,
    code: int = 1000,
    reason: str = "Connection closed",
):
    """Close ``websocket`` on a best-effort basis, never raising.

    Args:
        websocket: The connection to close.
        code: WebSocket close code sent to the peer (defaults to 1000).
        reason: Human-readable close reason sent to the peer.
    """
    try:
        await websocket.close(code=code, reason=reason)
    except Exception:
        # WebSocket may already be closed or in inconsistent state. Keep this
        # best-effort, but attach the exception so an unexpected failure can
        # still be diagnosed from debug logs.
        logger.debug("WebSocket close failed (may already be closed)", exc_info=True)

tests/agent_server/test_agent_server_wsproto.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import socket
88
import sys
99
import time
10+
from uuid import uuid4
1011

1112
import pytest
1213
import requests
1314
import websockets
15+
import websockets.exceptions
1416

1517

1618
def find_free_port():
@@ -155,3 +157,96 @@ async def test_agent_server_websocket_with_wsproto_header_auth(agent_server):
155157
{"role": "user", "content": "Hello from wsproto header auth test"}
156158
)
157159
)
160+
161+
162+
@pytest.mark.asyncio
async def test_agent_server_websocket_first_message_auth_accepted(agent_server):
    """A valid auth frame sent first keeps the connection open and usable.

    Connects with no key in the URL or headers, authenticates with the first
    frame, then sends a regular user message over the same socket. This runs
    against a real server so the handshake / auth-frame / normal-traffic
    transition is exercised end to end.
    """
    port = agent_server["port"]
    api_key = agent_server["api_key"]

    conversation_request = {
        "agent": {
            "kind": "Agent",
            "llm": {
                "usage_id": "test-llm",
                "model": "test-provider/test-model",
                "api_key": "test-key",
            },
            "tools": [],
        },
        "workspace": {"working_dir": "/tmp/test-workspace"},
    }
    created = requests.post(
        f"http://127.0.0.1:{port}/api/conversations",
        headers={"X-Session-API-Key": api_key},
        json=conversation_request,
    )
    assert created.status_code in (200, 201)
    conversation_id = created.json()["id"]

    # Deliberately no session_api_key in the URL and no auth header.
    ws_url = f"ws://127.0.0.1:{port}/sockets/events/{conversation_id}?resend_all=true"
    auth_frame = json.dumps({"type": "auth", "session_api_key": api_key})
    user_frame = json.dumps(
        {"role": "user", "content": "Hello after first-message auth"}
    )

    async with websockets.connect(ws_url, open_timeout=5) as ws:
        # The auth frame must be the very first message after the handshake.
        await ws.send(auth_frame)

        # An empty conversation may replay nothing, so a timeout is acceptable;
        # a closed connection would raise instead and fail the test.
        try:
            replayed = await asyncio.wait_for(ws.recv(), timeout=2)
        except TimeoutError:
            pass
        else:
            assert replayed is not None

        # A follow-up frame is ordinary traffic, showing the auth frame was
        # consumed by the auth step rather than the main message loop.
        await ws.send(user_frame)
212+
213+
214+
@pytest.mark.asyncio
async def test_agent_server_websocket_first_message_auth_rejected(agent_server):
    """A wrong key in the auth frame makes the server close with code 4001."""
    port = agent_server["port"]

    # Any conversation id will do: rejection happens before it is looked up.
    ws_url = f"ws://127.0.0.1:{port}/sockets/events/{uuid4()}"
    bad_auth_frame = json.dumps(
        {"type": "auth", "session_api_key": "definitely-wrong-key"}
    )

    async with websockets.connect(ws_url, open_timeout=5) as ws:
        await ws.send(bad_auth_frame)

        # Reading from a socket the server has closed raises ConnectionClosed.
        with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
            await asyncio.wait_for(ws.recv(), timeout=5)

        close_frame = exc_info.value.rcvd
        assert close_frame is not None
        assert close_frame.code == 4001
236+
237+
@pytest.mark.asyncio
async def test_agent_server_websocket_first_message_auth_malformed(agent_server):
    """A first frame that is not valid JSON makes the server close with 4001."""
    port = agent_server["port"]
    ws_url = f"ws://127.0.0.1:{port}/sockets/events/{uuid4()}"

    async with websockets.connect(ws_url, open_timeout=5) as ws:
        # Plain text that cannot be parsed as JSON.
        await ws.send("this is not json")

        with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
            await asyncio.wait_for(ws.recv(), timeout=5)

        close_frame = exc_info.value.rcvd
        assert close_frame is not None
        assert close_frame.code == 4001

0 commit comments

Comments
 (0)